Module molcrawl.protein_sequence.dataset.prepare_proteingym
ProteinGym Dataset Preparation Script for GPT-2 / BERT Fine-tuning
Converts ProteinGym v1.3 DMS substitution assay CSVs into the training_ready_hf_dataset format compatible with the existing protein_sequence training pipeline.
Pipeline
- Read every <assay>.csv from the substitutions directory.
- Collect mutated_sequence (mutant variants) AND target_seq (wild-type) as training sequences.
- Deduplicate by exact sequence string.
- Shuffle and apply an 80 / 10 / 10 train / valid / test split.
- Tokenise with EsmSequenceTokenizer (character-level, vocab_size=33).
- Concatenate all token sequences into one long stream (with EOS between sequences), then chunk into fixed-length blocks of context_length (see the sketch after this list).
- Save as a HuggingFace DatasetDict to output_dir/training_ready_hf_dataset/.
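The concatenate-and-chunk step uses two module-internal helpers, _concatenate_texts and _create_chunks, whose definitions are not shown on this page. A minimal sketch of what they might look like, written for HuggingFace Dataset.map() with batched=True; the labels column mirroring input_ids is an assumption for causal-LM training, not taken from the module:

def _concatenate_texts(examples, *, eos_token_id):
    # Hypothetical sketch: join every tokenised sequence into one long stream,
    # inserting EOS between consecutive sequences.
    stream = []
    for ids in examples["input_ids"]:
        stream.extend(ids)
        stream.append(eos_token_id)
    return {"input_ids": [stream]}

def _create_chunks(examples, *, context_length):
    # Hypothetical sketch: split the stream into fixed-length blocks,
    # dropping the trailing remainder shorter than context_length.
    stream = [tok for ids in examples["input_ids"] for tok in ids]
    n_blocks = len(stream) // context_length
    chunks = [stream[i * context_length : (i + 1) * context_length] for i in range(n_blocks)]
    return {"input_ids": chunks, "labels": [list(c) for c in chunks]}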
Usage (standalone):

    LEARNING_SOURCE_DIR=learning_source_20260311 \
    python -m molcrawl.protein_sequence.dataset.prepare_proteingym \
        assets/configs/protein_sequence.yaml
The output is saved to:
$LEARNING_SOURCE_DIR/protein_sequence/proteingym/training_ready_hf_dataset/
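Once prepared, the dataset can be loaded back with the standard HuggingFace datasets API; the path construction below simply mirrors the output location above:

import os
from datasets import load_from_disk

dataset_path = os.path.join(
    os.environ["LEARNING_SOURCE_DIR"],
    "protein_sequence", "proteingym", "training_ready_hf_dataset",
)
ds = load_from_disk(dataset_path)  # DatasetDict with "train", "valid", "test" splits
print(ds["train"][0]["input_ids"][:10])  # first tokens of the first fixed-length block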
Functions
def prepare_proteingym(source_dir: str | pathlib.Path,
output_dir: str | pathlib.Path,
*,
context_length: int = 1024,
train_ratio: float = 0.8,
seed: int = 42,
num_proc: int = 4) -> str
def prepare_proteingym(
    source_dir: Union[str, Path],
    output_dir: Union[str, Path],
    *,
    context_length: int = 1024,
    train_ratio: float = 0.8,
    seed: int = 42,
    num_proc: int = 4,
) -> str:
    """
    Load ProteinGym DMS substitution CSV files from *source_dir*, build a
    language-model training dataset, and save it to *output_dir*.

    Args:
        source_dir: Directory containing individual assay ``*.csv`` files
            (produced by ``download_proteingym()``).
        output_dir: Parent directory; the dataset is written under
            ``output_dir/training_ready_hf_dataset/``.
        context_length: Token block length (default 1024).
        train_ratio: Fraction of sequences used for training (rest split
            50/50 between validation and test).
        seed: Random seed for reproducible shuffling.
        num_proc: Parallel workers for HuggingFace Dataset.map() operations.

    Returns:
        Absolute path string of the saved dataset directory.
    """
    import pandas as pd

    source_dir = Path(source_dir)
    output_dir = Path(output_dir)

    if not source_dir.exists():
        raise FileNotFoundError(
            f"ProteinGym source directory not found: {source_dir}\n"
            "Run the download step first:\n"
            "    python -m molcrawl.preparation.preparation_script_protein_sequence"
            " assets/configs/protein_sequence.yaml --datasets proteingym --download-only"
        )

    # ------------------------------------------------------------------
    # 1. Collect sequences from all assay CSV files
    # ------------------------------------------------------------------
    csv_files = sorted(source_dir.glob("*.csv"))
    if not csv_files:
        raise FileNotFoundError(f"No CSV files found in {source_dir}")
    logger.info("Found %d assay CSV files in %s", len(csv_files), source_dir)

    sequences: List[str] = []
    for csv_path in csv_files:
        try:
            df = pd.read_csv(
                csv_path,
                usecols=lambda c: c in {"mutated_sequence", "target_seq"},
            )
        except Exception as exc:
            logger.warning("Skipping %s: %s", csv_path.name, exc)
            continue
        if "mutated_sequence" in df.columns:
            sequences.extend(df["mutated_sequence"].dropna().tolist())
        if "target_seq" in df.columns:
            sequences.extend(df["target_seq"].dropna().tolist())

    logger.info("Total sequences collected (before dedup): %d", len(sequences))

    # ------------------------------------------------------------------
    # 2. Deduplicate
    # ------------------------------------------------------------------
    sequences = list(dict.fromkeys(s for s in sequences if isinstance(s, str) and s.strip()))
    logger.info("Unique sequences after deduplication: %d", len(sequences))

    # ------------------------------------------------------------------
    # 3. Random 80 / 10 / 10 split
    # ------------------------------------------------------------------
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(sequences))
    n_train = int(len(idx) * train_ratio)
    n_val = int(len(idx) * (1 - train_ratio) / 2)

    train_seqs = [sequences[i] for i in idx[:n_train]]
    val_seqs = [sequences[i] for i in idx[n_train : n_train + n_val]]
    test_seqs = [sequences[i] for i in idx[n_train + n_val :]]

    logger.info(
        "Split — train: %d, valid: %d, test: %d",
        len(train_seqs), len(val_seqs), len(test_seqs),
    )

    raw_split = DatasetDict(
        {
            "train": Dataset.from_dict({"sequence": train_seqs}),
            "valid": Dataset.from_dict({"sequence": val_seqs}),
            "test": Dataset.from_dict({"sequence": test_seqs}),
        }
    )

    # ------------------------------------------------------------------
    # 4. Tokenise
    # ------------------------------------------------------------------
    logger.info("Initialising EsmSequenceTokenizer...")
    tokenizer = EsmSequenceTokenizer()
    eos_token_id = tokenizer.eos_token_id
    logger.info("vocab_size=%d eos_token_id=%d", tokenizer.vocab_size, eos_token_id)

    logger.info("Tokenising sequences...")

    def _tokenize_batch(examples):
        results = []
        for seq in examples["sequence"]:
            encoded = tokenizer(str(seq), add_special_tokens=True)
            results.append(encoded["input_ids"])
        return {"input_ids": results}

    tokenized = raw_split.map(
        _tokenize_batch,
        batched=True,
        batch_size=1000,
        remove_columns=["sequence"],
        num_proc=num_proc,
        desc="Tokenising",
    )

    # ------------------------------------------------------------------
    # 5. Concatenate into a single stream, then chunk
    # ------------------------------------------------------------------
    logger.info("Concatenating and chunking to length %d...", context_length)

    concatenated = tokenized.map(
        partial(_concatenate_texts, eos_token_id=eos_token_id),
        batched=True,
        batch_size=-1,
        remove_columns=tokenized["train"].column_names,
        desc="Concatenating",
    )

    chunked = concatenated.map(
        partial(_create_chunks, context_length=context_length),
        batched=True,
        batch_size=-1,
        desc="Chunking",
    )

    # ------------------------------------------------------------------
    # 6. Save
    # ------------------------------------------------------------------
    output_path = output_dir / "training_ready_hf_dataset"
    output_path.mkdir(parents=True, exist_ok=True)
    logger.info("Saving dataset to %s", output_path)
    chunked.save_to_disk(str(output_path))

    logger.info("Done! Dataset statistics:")
    for split_name in chunked:
        logger.info(
            "  %s: %d chunks of %d tokens",
            split_name,
            len(chunked[split_name]),
            context_length,
        )

    return str(output_path)

Load ProteinGym DMS substitution CSV files from source_dir, build a language-model training dataset, and save it to output_dir.
Args
source_dir - Directory containing individual assay *.csv files (produced by download_proteingym()).
output_dir - Parent directory; the dataset is written under output_dir/training_ready_hf_dataset/.
context_length - Token block length (default 1024).
train_ratio - Fraction of sequences used for training (rest split 50/50 between validation and test).
seed - Random seed for reproducible shuffling.
num_proc - Parallel workers for HuggingFace Dataset.map() operations.
Returns
Absolute path string of the saved dataset directory.
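For use from other code rather than the CLI entry point, a call looks like the following sketch; the directory paths are illustrative placeholders, not paths defined by the module:

from molcrawl.protein_sequence.dataset.prepare_proteingym import prepare_proteingym

saved_path = prepare_proteingym(
    source_dir="data/proteingym/substitutions",  # placeholder: directory of assay *.csv files
    output_dir="data/proteingym",                # dataset lands in data/proteingym/training_ready_hf_dataset/
    context_length=1024,
    train_ratio=0.8,
    seed=42,
    num_proc=4,
)
print(saved_path)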