Module molcrawl.evaluation.bert.proteingym_data_preparation
Preprocessing script for the BERT version of the ProteinGym dataset.
Preprocesses the ProteinGym dataset for the BERT model and saves it in a format suitable for evaluation.
Functions
def get_log_dir() -> pathlib.Path
def get_log_dir() -> Path:
    """Get the output destination of the ProteinGym preprocessing log and create it if it does not exist."""
    learning_source_dir: str = check_learning_source_dir()
    log_dir: Path = Path(learning_source_dir) / "protein_sequence" / "logs"
    log_dir.mkdir(parents=True, exist_ok=True)
    return log_dir

Get the output destination of the ProteinGym preprocessing log and create it if it does not exist.
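A minimal usage sketch; the base path comes from check_learning_source_dir(), a helper defined elsewhere in the package, so the printed location below is illustrative only:

from molcrawl.evaluation.bert.proteingym_data_preparation import get_log_dir

# Creates <learning_source_dir>/protein_sequence/logs on first call
log_dir = get_log_dir()
print(log_dir)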
def main() -> None
def main() -> None:
    parser = argparse.ArgumentParser(description="BERT ProteinGym data preprocessing")
    parser.add_argument("--output_dir", type=str, default="./bert_proteingym_data", help="Output directory for processed data")
    parser.add_argument("--download", action="store_true", help="Download ProteinGym data")
    parser.add_argument("--max_variants_per_assay", type=int, default=1000, help="Maximum variants per assay")
    parser.add_argument("--sample_only", action="store_true", help="Create sample dataset only")
    parser.add_argument("--data_dir", type=str, default="./bert_proteingym_data", help="Directory containing ProteinGym data")
    args = parser.parse_args()

    processor = BERTProteinGymDataProcessor(args.output_dir)

    try:
        if args.sample_only:
            # Create a small hand-written sample dataset
            logger.info("Creating sample dataset")
            sample_data: List[Dict[str, Any]] = [
                {
                    "mutant": "A1V",
                    "mutated_sequence": "VLKGDLSGLTQVKSGQDKGLTRVKDDLSVLTQVKSGQDKGLT",
                    "target_seq": "ALKGDLSGLTQVKSGQDKGLTRVKDDLSVLTQVKSGQDKGLT",
                    "DMS_score": 0.85,
                    "assay_name": "SAMPLE_PROTEIN",
                },
                {
                    "mutant": "L2P",
                    "mutated_sequence": "APKGDLSGLTQVKSGQDKGLTRVKDDLSVLTQVKSGQDKGLT",
                    "target_seq": "ALKGDLSGLTQVKSGQDKGLTRVKDDLSVLTQVKSGQDKGLT",
                    "DMS_score": 0.15,
                    "assay_name": "SAMPLE_PROTEIN",
                },
                {
                    "mutant": "WT",
                    "mutated_sequence": "ALKGDLSGLTQVKSGQDKGLTRVKDDLSVLTQVKSGQDKGLT",
                    "target_seq": "ALKGDLSGLTQVKSGQDKGLTRVKDDLSVLTQVKSGQDKGLT",
                    "DMS_score": 1.0,
                    "assay_name": "SAMPLE_PROTEIN",
                },
            ]
            df_sample: pd.DataFrame = pd.DataFrame(sample_data)
            processor.save_bert_ready_data(df_sample, "bert_proteingym_sample.csv")
            logger.info("Sample dataset created")
            return

        if args.download:
            # Download data
            logger.info("Downloading ProteinGym data")
            # Download substitution mutation data
            zip_file = processor.download_file(processor.PROTEINGYM_URLS["substitutions"])
            # Download reference data
            ref_file = processor.download_file(processor.PROTEINGYM_URLS["reference_substitutions"])
            # Unzip
            extract_dir = processor.extract_zip(zip_file)
            # Read data
            datasets: Dict[str, pd.DataFrame] = processor.load_proteingym_data(extract_dir, ref_file)
        else:
            # Use existing data
            data_dir = args.data_dir
            if not Path(data_dir).exists():
                raise ValueError(f"Data directory not found: {data_dir}")
            datasets = processor.load_proteingym_data(data_dir)

        # Preprocess for BERT
        processed_df: pd.DataFrame = processor.preprocess_for_bert(
            datasets,
            max_variants_per_assay=args.max_variants_per_assay,
        )

        # Save the processed dataset
        processor.save_bert_ready_data(processed_df)
        logger.info("BERT ProteinGym data preprocessing completed successfully")
    except Exception as e:
        logger.error(f"Data preprocessing failed: {e}")
        raise
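For reference, the --download branch of main() reduces to the following programmatic pipeline (a sketch; it assumes network access to the ProteinGym mirror listed in PROTEINGYM_URLS):

from molcrawl.evaluation.bert.proteingym_data_preparation import BERTProteinGymDataProcessor

processor = BERTProteinGymDataProcessor("./bert_proteingym_data")

# Download the DMS substitution archive and the reference CSV
zip_file = processor.download_file(processor.PROTEINGYM_URLS["substitutions"])
ref_file = processor.download_file(processor.PROTEINGYM_URLS["reference_substitutions"])

# Extract, load, preprocess, and save
datasets = processor.load_proteingym_data(processor.extract_zip(zip_file), ref_file)
processed = processor.preprocess_for_bert(datasets, max_variants_per_assay=1000)
processor.save_bert_ready_data(processed)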
Classes
class BERTProteinGymDataProcessor(output_dir: str = './bert_proteingym_data')
class BERTProteinGymDataProcessor:
    """Preprocessing class for the ProteinGym dataset for BERT"""

    # Official URLs of the ProteinGym v1.3 dataset
    PROTEINGYM_URLS = {
        # DMS (Deep Mutational Scanning) data - for the main evaluation
        "substitutions": "https://marks.hms.harvard.edu/proteingym/ProteinGym_v1.3/DMS_ProteinGym_substitutions.zip",
        "reference_substitutions": "https://marks.hms.harvard.edu/proteingym/ProteinGym_v1.3/DMS_substitutions.csv",
    }

    def __init__(self, output_dir: str = "./bert_proteingym_data") -> None:
        """
        Initialize the processor.

        Args:
            output_dir (str): Output directory
        """
        self.output_dir: Path = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Log directory has already been created by get_log_dir

    def download_file(self, url: str, filename: Optional[str] = None) -> str:
        """
        Download a file.

        Args:
            url (str): Download URL
            filename (str): Save file name (generated automatically if None)

        Returns:
            str: Path of the downloaded file
        """
        if filename is None:
            filename = Path(urlparse(url).path).name
        filepath = self.output_dir / filename

        logger.info(f"Downloading {url} to {filepath}")
        response = requests.get(url, stream=True)
        response.raise_for_status()

        total_size = int(response.headers.get("content-length", 0))
        with (
            open(filepath, "wb") as f,
            tqdm(
                desc=filename,
                total=total_size,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as progress_bar,
        ):
            for chunk in response.iter_content(chunk_size=8192):
                size = f.write(chunk)
                progress_bar.update(size)

        logger.info(f"Downloaded: {filepath}")
        return str(filepath)

    def extract_zip(self, zip_path: str) -> str:
        """
        Extract a ZIP file.

        Args:
            zip_path (str): ZIP file path

        Returns:
            str: Path of the extracted directory
        """
        extract_dir = self.output_dir / Path(zip_path).stem
        extract_dir.mkdir(exist_ok=True)

        logger.info(f"Extracting {zip_path} to {extract_dir}")
        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_dir)

        logger.info(f"Extracted to: {extract_dir}")
        return str(extract_dir)

    def load_proteingym_data(self, data_dir: str, reference_csv: Optional[str] = None) -> Dict[str, pd.DataFrame]:
        """
        Load ProteinGym data.

        Args:
            data_dir (str): Data directory
            reference_csv (str): Reference CSV file (optional)

        Returns:
            dict: Dictionary of datasets
        """
        logger.info(f"Loading ProteinGym data from {data_dir}")
        data_dir_path: Path = Path(data_dir)
        datasets: Dict[str, pd.DataFrame] = {}

        # Check the ProteinGym directory structure
        dms_dir = None
        possible_dirs = [
            data_dir_path / "DMS_ProteinGym_substitutions" / "DMS_ProteinGym_substitutions",
            data_dir_path / "DMS_ProteinGym_substitutions",
            data_dir_path,
        ]
        for possible_dir in possible_dirs:
            if possible_dir.exists() and any(possible_dir.glob("*.csv")):
                dms_dir = possible_dir
                logger.info(f"Found ProteinGym data in: {dms_dir}")
                break

        if not dms_dir:
            logger.error(f"No ProteinGym CSV files found in {data_dir} or subdirectories")
            return datasets

        # Read the DMS CSV files
        csv_files = list(dms_dir.glob("*.csv"))
        logger.info(f"Found {len(csv_files)} CSV files")
        for csv_file in csv_files:
            try:
                df = pd.read_csv(csv_file)
                dataset_name = csv_file.stem
                datasets[dataset_name] = df
                logger.info(f"Loaded {dataset_name}: {len(df)} variants")
            except Exception as e:
                logger.warning(f"Failed to load {csv_file}: {e}")

        # Load reference information (optional)
        if reference_csv and os.path.exists(reference_csv):
            try:
                ref_df = pd.read_csv(reference_csv)
                datasets["reference"] = ref_df
                logger.info(f"Loaded reference data: {len(ref_df)} assays")
            except Exception as e:
                logger.warning(f"Failed to load reference CSV: {e}")

        return datasets

    def preprocess_for_bert(
        self,
        datasets: Dict[str, pd.DataFrame],
        max_variants_per_assay: int = 1000,
    ) -> pd.DataFrame:
        """
        Preprocess data for BERT evaluation.

        Args:
            datasets (dict): Dictionary of datasets
            max_variants_per_assay (int): Maximum number of variants per assay

        Returns:
            pd.DataFrame: Preprocessed data
        """
        logger.info("Starting BERT-specific preprocessing")
        all_variants: List[pd.DataFrame] = []

        for dataset_name, df in datasets.items():
            if dataset_name == "reference":
                continue

            logger.info(f"Processing dataset: {dataset_name}")

            # Check required columns
            required_columns = ["mutated_sequence", "DMS_score"]
            missing_columns = [col for col in required_columns if col not in df.columns]
            if missing_columns:
                logger.warning(f"Missing columns in {dataset_name}: {missing_columns}")
                logger.info(f"Available columns: {list(df.columns)}")
                continue

            # Basic filtering of the data
            df_clean = df.dropna(subset=required_columns).copy()

            # Check the validity of DMS_score
            df_clean = df_clean[(df_clean["DMS_score"].notna()) & (np.isfinite(df_clean["DMS_score"]))].copy()

            # Check the validity of the sequence length
            df_clean = df_clean[
                (df_clean["mutated_sequence"].str.len() > 10)
                & (df_clean["mutated_sequence"].str.len() < 2000)  # Consider BERT input-length limitations
            ].copy()

            # Sample per assay
            if len(df_clean) > max_variants_per_assay:
                df_clean = df_clean.sample(n=max_variants_per_assay, random_state=42)
                logger.info(f"Sampled {max_variants_per_assay} variants from {dataset_name}")

            # Add the dataset name
            df_clean["assay_name"] = dataset_name

            # Infer the wild-type sequence (if necessary)
            if "target_seq" not in df_clean.columns:
                df_clean = self._infer_wildtype_sequences(df_clean)

            all_variants.append(df_clean)
            logger.info(f"Processed {len(df_clean)} variants from {dataset_name}")

        if not all_variants:
            raise ValueError("No valid datasets found")

        # Combine all data
        combined_df: pd.DataFrame = pd.concat(all_variants, ignore_index=True)

        # Final cleaning
        combined_df = self._final_cleaning(combined_df)

        logger.info(f"Total processed variants: {len(combined_df)}")
        return combined_df

    def _infer_wildtype_sequences(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Infer wild-type sequences.

        Args:
            df (pd.DataFrame): Variant data

        Returns:
            pd.DataFrame: Data with wild-type sequences added
        """
        logger.info("Inferring wildtype sequences")

        def infer_wildtype(mutated_seq: str, mutant_info: Any) -> str:
            """Infer the wild type from a single mutation"""
            if pd.isna(mutant_info) or mutant_info == "WT":
                return mutated_seq

            # Only single mutations are supported (e.g. A1V)
            if isinstance(mutant_info, str) and len(mutant_info) >= 3:
                try:
                    orig_aa = mutant_info[0]
                    pos = int(mutant_info[1:-1]) - 1  # Convert to 0-indexed
                    mut_aa = mutant_info[-1]

                    # Revert the mutated position to the wild-type residue
                    wt_seq = list(mutated_seq)
                    if 0 <= pos < len(wt_seq) and wt_seq[pos] == mut_aa:
                        wt_seq[pos] = orig_aa
                        return "".join(wt_seq)
                except Exception:
                    pass
            return mutated_seq

        # If there is a mutant column
        if "mutant" in df.columns:
            df["target_seq"] = df.apply(
                lambda row: infer_wildtype(row["mutated_sequence"], row.get("mutant", "")),
                axis=1,
            )
        else:
            # If there is no mutant information, use mutated_sequence as is
            df["target_seq"] = df["mutated_sequence"]
            logger.warning("No mutant information found, using mutated_sequence as target_seq")

        return df

    def _final_cleaning(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Final data cleaning.

        Args:
            df (pd.DataFrame): Preprocessed data

        Returns:
            pd.DataFrame: Cleaned data
        """
        logger.info("Final data cleaning")

        # Remove duplicates
        before_dedup = len(df)
        df = df.drop_duplicates(subset=["mutated_sequence"], keep="first")
        after_dedup = len(df)
        logger.info(f"Removed {before_dedup - after_dedup} duplicate sequences")

        valid_aa: set = set("ACDEFGHIKLMNPQRSTVWY")

        def is_valid_protein_sequence(seq: Any) -> bool:
            """Check whether the protein sequence is valid"""
            if not isinstance(seq, str):
                return False
            return all(aa in valid_aa for aa in seq.upper())

        valid_seq_mask = df["mutated_sequence"].apply(is_valid_protein_sequence)
        df = df[valid_seq_mask].copy()
        logger.info(f"Retained {len(df)} variants with valid amino acid sequences")

        # Log DMS_score statistics
        logger.info(f"DMS_score statistics:\n{df['DMS_score'].describe()}")

        return df

    def save_bert_ready_data(self, df: pd.DataFrame, filename: str = "bert_proteingym_dataset.csv") -> None:
        """
        Save BERT-ready data.

        Args:
            df (pd.DataFrame): Preprocessed data
            filename (str): Save file name
        """
        filepath = self.output_dir / filename

        # Save in CSV format
        df.to_csv(filepath, index=False)
        logger.info(f"Saved BERT-ready dataset: {filepath}")

        # Also save in JSON format (with metadata)
        json_data: Dict[str, Any] = {
            "metadata": {
                "total_variants": len(df),
                "unique_assays": df["assay_name"].nunique() if "assay_name" in df.columns else 1,
                "dms_score_range": [float(df["DMS_score"].min()), float(df["DMS_score"].max())],
                "avg_sequence_length": float(df["mutated_sequence"].str.len().mean()),
                "processing_date": datetime.now().isoformat(),
            },
            "data": df.to_dict("records")[:100],  # Include only the first 100 records in the JSON
        }
        json_filepath = filepath.with_suffix(".json")
        with open(json_filepath, "w") as f:
            json.dump(json_data, f, indent=2)
        logger.info(f"Saved metadata and sample data: {json_filepath}")

        # Also create a statistics report
        self._create_statistics_report(df)

    def _create_statistics_report(self, df: pd.DataFrame) -> None:
        """
        Create a statistics report.

        Args:
            df (pd.DataFrame): Dataset
        """
        report_file = self.output_dir / "bert_proteingym_statistics.txt"
        with open(report_file, "w") as f:
            f.write("BERT ProteinGym Dataset Statistics Report\n")
            f.write("=" * 50 + "\n\n")
            f.write(f"Total variants: {len(df)}\n")
            f.write(f"Unique assays: {df['assay_name'].nunique() if 'assay_name' in df.columns else 'N/A'}\n\n")
            f.write("DMS Score Distribution:\n")
            f.write(f"{df['DMS_score'].describe()}\n\n")
            f.write("Sequence Length Distribution:\n")
            f.write(f"{df['mutated_sequence'].str.len().describe()}\n\n")
            if "assay_name" in df.columns:
                f.write("Top 10 Assays by Variant Count:\n")
                for assay, count in df["assay_name"].value_counts().head(10).items():
                    f.write(f"  {assay}: {count}\n")
                f.write("\n")
            f.write("Sample Sequences (first 5):\n")
            for i, seq in enumerate(df["mutated_sequence"].head(5)):
                f.write(f"  {i + 1}: {seq[:50]}{'...' if len(seq) > 50 else ''}\n")

        logger.info(f"Statistics report saved: {report_file}")

Preprocessing class for the ProteinGym dataset for BERT.
Initialize the processor.
Args
output_dir (str): Output directory
Class variables
var PROTEINGYM_URLS
Official download URLs for the ProteinGym v1.3 dataset (the DMS substitution archive and the reference CSV).
Methods
def download_file(self, url: str, filename: str | None = None) -> str
def download_file(self, url: str, filename: Optional[str] = None) -> str:
    """
    Download a file.

    Args:
        url (str): Download URL
        filename (str): Save file name (generated automatically if None)

    Returns:
        str: Path of the downloaded file
    """
    if filename is None:
        filename = Path(urlparse(url).path).name
    filepath = self.output_dir / filename

    logger.info(f"Downloading {url} to {filepath}")
    response = requests.get(url, stream=True)
    response.raise_for_status()

    total_size = int(response.headers.get("content-length", 0))
    with (
        open(filepath, "wb") as f,
        tqdm(
            desc=filename,
            total=total_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as progress_bar,
    ):
        for chunk in response.iter_content(chunk_size=8192):
            size = f.write(chunk)
            progress_bar.update(size)

    logger.info(f"Downloaded: {filepath}")
    return str(filepath)

Download a file.
Args
url (str): Download URL
filename (str, optional): Save file name (generated automatically if None)
Returns
str: Path of the downloaded file
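A standalone download sketch (URL taken from PROTEINGYM_URLS; requires network access to the ProteinGym mirror):

from molcrawl.evaluation.bert.proteingym_data_preparation import BERTProteinGymDataProcessor

processor = BERTProteinGymDataProcessor("./bert_proteingym_data")
csv_path = processor.download_file(processor.PROTEINGYM_URLS["reference_substitutions"])
# filename defaults to the basename of the URL path
print(csv_path)  # ./bert_proteingym_data/DMS_substitutions.csv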
def extract_zip(self, zip_path: str) -> str
def extract_zip(self, zip_path: str) -> str:
    """
    Extract a ZIP file.

    Args:
        zip_path (str): ZIP file path

    Returns:
        str: Path of the extracted directory
    """
    extract_dir = self.output_dir / Path(zip_path).stem
    extract_dir.mkdir(exist_ok=True)

    logger.info(f"Extracting {zip_path} to {extract_dir}")
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extract_dir)

    logger.info(f"Extracted to: {extract_dir}")
    return str(extract_dir)

Extract a ZIP file.
Args
zip_path (str): ZIP file path
Returns
str: Path of the extracted directory
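Typically chained directly after download_file (continuing the sketch above); the archive is extracted into a directory named after the ZIP file's stem inside output_dir:

zip_path = processor.download_file(processor.PROTEINGYM_URLS["substitutions"])
extract_dir = processor.extract_zip(zip_path)
print(extract_dir)  # ./bert_proteingym_data/DMS_ProteinGym_substitutions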
def load_proteingym_data(self, data_dir: str, reference_csv: str | None = None) -> Dict[str, pandas.core.frame.DataFrame]
def load_proteingym_data(self, data_dir: str, reference_csv: Optional[str] = None) -> Dict[str, pd.DataFrame]:
    """
    Load ProteinGym data.

    Args:
        data_dir (str): Data directory
        reference_csv (str): Reference CSV file (optional)

    Returns:
        dict: Dictionary of datasets
    """
    logger.info(f"Loading ProteinGym data from {data_dir}")
    data_dir_path: Path = Path(data_dir)
    datasets: Dict[str, pd.DataFrame] = {}

    # Check the ProteinGym directory structure
    dms_dir = None
    possible_dirs = [
        data_dir_path / "DMS_ProteinGym_substitutions" / "DMS_ProteinGym_substitutions",
        data_dir_path / "DMS_ProteinGym_substitutions",
        data_dir_path,
    ]
    for possible_dir in possible_dirs:
        if possible_dir.exists() and any(possible_dir.glob("*.csv")):
            dms_dir = possible_dir
            logger.info(f"Found ProteinGym data in: {dms_dir}")
            break

    if not dms_dir:
        logger.error(f"No ProteinGym CSV files found in {data_dir} or subdirectories")
        return datasets

    # Read the DMS CSV files
    csv_files = list(dms_dir.glob("*.csv"))
    logger.info(f"Found {len(csv_files)} CSV files")
    for csv_file in csv_files:
        try:
            df = pd.read_csv(csv_file)
            dataset_name = csv_file.stem
            datasets[dataset_name] = df
            logger.info(f"Loaded {dataset_name}: {len(df)} variants")
        except Exception as e:
            logger.warning(f"Failed to load {csv_file}: {e}")

    # Load reference information (optional)
    if reference_csv and os.path.exists(reference_csv):
        try:
            ref_df = pd.read_csv(reference_csv)
            datasets["reference"] = ref_df
            logger.info(f"Loaded reference data: {len(ref_df)} assays")
        except Exception as e:
            logger.warning(f"Failed to load reference CSV: {e}")

    return datasets

Load ProteinGym data.
Args
data_dir (str): Data directory
reference_csv (str, optional): Reference CSV file
Returns
dict: Dictionary of datasets, keyed by assay name
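A sketch of loading from an already-extracted directory (paths continue the example above); when a reference CSV is supplied it is stored under the reserved key "reference", which preprocess_for_bert later skips:

datasets = processor.load_proteingym_data(
    "./bert_proteingym_data/DMS_ProteinGym_substitutions",
    reference_csv="./bert_proteingym_data/DMS_substitutions.csv",
)
for name, df in datasets.items():
    print(name, len(df))  # one entry per assay CSV, plus "reference" if loaded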
def preprocess_for_bert(self, datasets: Dict[str, pandas.core.frame.DataFrame], max_variants_per_assay: int = 1000) -> pandas.core.frame.DataFrame
def preprocess_for_bert(
    self,
    datasets: Dict[str, pd.DataFrame],
    max_variants_per_assay: int = 1000,
) -> pd.DataFrame:
    """
    Preprocess data for BERT evaluation.

    Args:
        datasets (dict): Dictionary of datasets
        max_variants_per_assay (int): Maximum number of variants per assay

    Returns:
        pd.DataFrame: Preprocessed data
    """
    logger.info("Starting BERT-specific preprocessing")
    all_variants: List[pd.DataFrame] = []

    for dataset_name, df in datasets.items():
        if dataset_name == "reference":
            continue

        logger.info(f"Processing dataset: {dataset_name}")

        # Check required columns
        required_columns = ["mutated_sequence", "DMS_score"]
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            logger.warning(f"Missing columns in {dataset_name}: {missing_columns}")
            logger.info(f"Available columns: {list(df.columns)}")
            continue

        # Basic filtering of the data
        df_clean = df.dropna(subset=required_columns).copy()

        # Check the validity of DMS_score
        df_clean = df_clean[(df_clean["DMS_score"].notna()) & (np.isfinite(df_clean["DMS_score"]))].copy()

        # Check the validity of the sequence length
        df_clean = df_clean[
            (df_clean["mutated_sequence"].str.len() > 10)
            & (df_clean["mutated_sequence"].str.len() < 2000)  # Consider BERT input-length limitations
        ].copy()

        # Sample per assay
        if len(df_clean) > max_variants_per_assay:
            df_clean = df_clean.sample(n=max_variants_per_assay, random_state=42)
            logger.info(f"Sampled {max_variants_per_assay} variants from {dataset_name}")

        # Add the dataset name
        df_clean["assay_name"] = dataset_name

        # Infer the wild-type sequence (if necessary)
        if "target_seq" not in df_clean.columns:
            df_clean = self._infer_wildtype_sequences(df_clean)

        all_variants.append(df_clean)
        logger.info(f"Processed {len(df_clean)} variants from {dataset_name}")

    if not all_variants:
        raise ValueError("No valid datasets found")

    # Combine all data
    combined_df: pd.DataFrame = pd.concat(all_variants, ignore_index=True)

    # Final cleaning
    combined_df = self._final_cleaning(combined_df)

    logger.info(f"Total processed variants: {len(combined_df)}")
    return combined_df

Preprocess data for BERT evaluation.
Args
datasets (dict): Dictionary of datasets, as returned by load_proteingym_data
max_variants_per_assay (int): Maximum number of variants per assay
Returns
pd.DataFrame: Preprocessed data
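Continuing the example above: rows must have a mutated_sequence and a finite DMS_score, sequence length strictly between 10 and 2000 residues, and each assay is capped at max_variants_per_assay rows by seeded sampling (random_state=42):

processed = processor.preprocess_for_bert(datasets, max_variants_per_assay=500)
print(processed.columns.tolist())  # includes assay_name and target_seq
print(processed["assay_name"].value_counts().head())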
def save_bert_ready_data(self, df: pandas.core.frame.DataFrame, filename: str = 'bert_proteingym_dataset.csv') -> None
def save_bert_ready_data(self, df: pd.DataFrame, filename: str = "bert_proteingym_dataset.csv") -> None:
    """
    Save BERT-ready data.

    Args:
        df (pd.DataFrame): Preprocessed data
        filename (str): Save file name
    """
    filepath = self.output_dir / filename

    # Save in CSV format
    df.to_csv(filepath, index=False)
    logger.info(f"Saved BERT-ready dataset: {filepath}")

    # Also save in JSON format (with metadata)
    json_data: Dict[str, Any] = {
        "metadata": {
            "total_variants": len(df),
            "unique_assays": df["assay_name"].nunique() if "assay_name" in df.columns else 1,
            "dms_score_range": [float(df["DMS_score"].min()), float(df["DMS_score"].max())],
            "avg_sequence_length": float(df["mutated_sequence"].str.len().mean()),
            "processing_date": datetime.now().isoformat(),
        },
        "data": df.to_dict("records")[:100],  # Include only the first 100 records in the JSON
    }
    json_filepath = filepath.with_suffix(".json")
    with open(json_filepath, "w") as f:
        json.dump(json_data, f, indent=2)
    logger.info(f"Saved metadata and sample data: {json_filepath}")

    # Also create a statistics report
    self._create_statistics_report(df)

Save BERT-ready data.
Args
df (pd.DataFrame): Preprocessed data
filename (str): Save file name
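Saving fans out into three artifacts under output_dir (continuing the example above; filenames follow the defaults shown in the source):

processor.save_bert_ready_data(processed)
# Writes, under output_dir:
#   bert_proteingym_dataset.csv     - the full dataset
#   bert_proteingym_dataset.json    - metadata plus the first 100 records
#   bert_proteingym_statistics.txt  - human-readable statistics report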