Module molcrawl.evaluation.bert.proteingym_data_preparation

Preprocessing script for the BERT version of the ProteinGym dataset.

Preprocesses the ProteinGym dataset for the BERT model and saves it in a format suitable for evaluation.

Functions

def get_log_dir() -> pathlib.Path
Expand source code
def get_log_dir() -> Path:
    """Return the ProteinGym preprocessing log directory, creating it if missing.

    The directory is ``<learning source dir>/protein_sequence/logs``. The base
    directory is the string returned by ``check_learning_source_dir()``.

    Returns:
        Path: The log directory, which exists after this call.
    """
    base_dir = Path(check_learning_source_dir())
    logs_path = base_dir.joinpath("protein_sequence", "logs")
    logs_path.mkdir(parents=True, exist_ok=True)
    return logs_path

Get the output destination of the ProteinGym preprocessing log and create it if it does not exist.

def main() -> None
Expand source code
def main() -> None:
    """Command-line entry point for BERT ProteinGym preprocessing.

    Behavior depends on the flags:

    * ``--sample_only``: write a tiny built-in three-variant dataset and stop.
    * ``--download``: fetch the official substitution archive and reference
      CSV, extract the archive, and load it.
    * otherwise: load the CSV files already present in ``--data_dir``.

    The loaded assays are then preprocessed and saved into ``--output_dir``.
    Any exception is logged and re-raised.

    Raises:
        ValueError: If ``--data_dir`` does not exist (non-download mode).
    """
    default_dir = "./bert_proteingym_data"
    parser = argparse.ArgumentParser(description="BERT ProteinGym data preprocessing")
    parser.add_argument("--output_dir", type=str, default=default_dir, help="Output directory for processed data")
    parser.add_argument("--download", action="store_true", help="Download ProteinGym data")
    parser.add_argument("--max_variants_per_assay", type=int, default=1000, help="Maximum variants per assay")
    parser.add_argument("--sample_only", action="store_true", help="Create sample dataset only")
    parser.add_argument("--data_dir", type=str, default=default_dir, help="Directory containing ProteinGym data")
    args = parser.parse_args()

    processor = BERTProteinGymDataProcessor(args.output_dir)

    try:
        if args.sample_only:
            logger.info("Creating sample dataset")
            # Every sample variant shares this unmutated reference sequence.
            wild_type = "ALKGDLSGLTQVKSGQDKGLTRVKDDLSVLTQVKSGQDKGLT"
            sample_records: List[Dict[str, Any]] = [
                {
                    "mutant": "A1V",
                    "mutated_sequence": "VLKGDLSGLTQVKSGQDKGLTRVKDDLSVLTQVKSGQDKGLT",
                    "target_seq": wild_type,
                    "DMS_score": 0.85,
                    "assay_name": "SAMPLE_PROTEIN",
                },
                {
                    "mutant": "L2P",
                    "mutated_sequence": "APKGDLSGLTQVKSGQDKGLTRVKDDLSVLTQVKSGQDKGLT",
                    "target_seq": wild_type,
                    "DMS_score": 0.15,
                    "assay_name": "SAMPLE_PROTEIN",
                },
                {
                    "mutant": "WT",
                    "mutated_sequence": wild_type,
                    "target_seq": wild_type,
                    "DMS_score": 1.0,
                    "assay_name": "SAMPLE_PROTEIN",
                },
            ]
            processor.save_bert_ready_data(pd.DataFrame(sample_records), "bert_proteingym_sample.csv")
            logger.info("Sample dataset created")
            return

        if args.download:
            logger.info("Downloading ProteinGym data")
            urls = processor.PROTEINGYM_URLS
            # Fetch both files first, then unpack the archive and read it.
            archive_path = processor.download_file(urls["substitutions"])
            reference_path = processor.download_file(urls["reference_substitutions"])
            datasets: Dict[str, pd.DataFrame] = processor.load_proteingym_data(
                processor.extract_zip(archive_path),
                reference_path,
            )
        else:
            source_dir = args.data_dir
            if not Path(source_dir).exists():
                raise ValueError(f"Data directory not found: {source_dir}")
            datasets = processor.load_proteingym_data(source_dir)

        processed = processor.preprocess_for_bert(
            datasets,
            max_variants_per_assay=args.max_variants_per_assay,
        )
        processor.save_bert_ready_data(processed)
        logger.info("BERT ProteinGym data preprocessing completed successfully")

    except Exception as err:
        logger.error(f"Data preprocessing failed: {err}")
        raise

Classes

class BERTProteinGymDataProcessor (output_dir: str = './bert_proteingym_data')
Expand source code
class BERTProteinGymDataProcessor:
    """Preprocessing class for the ProteinGym dataset for BERT.

    Typical flow:

    * ``download_file`` / ``extract_zip`` fetch and unpack the official files.
    * ``load_proteingym_data`` reads one DataFrame per assay CSV.
    * ``preprocess_for_bert`` filters, samples and combines the assays.
    * ``save_bert_ready_data`` writes a CSV, a JSON metadata file and a
      plain-text statistics report into ``output_dir``.
    """

    # Official URL of ProteinGym v1.3 dataset
    PROTEINGYM_URLS = {
        # DMS (Deep Mutational Scanning) data - for main evaluation
        "substitutions": "https://marks.hms.harvard.edu/proteingym/ProteinGym_v1.3/DMS_ProteinGym_substitutions.zip",
        # Reference CSV; load_proteingym_data stores it under datasets["reference"]
        "reference_substitutions": "https://marks.hms.harvard.edu/proteingym/ProteinGym_v1.3/DMS_substitutions.csv",
    }

    def __init__(self, output_dir: str = "./bert_proteingym_data") -> None:
        """
        Initialize the processor and create the output directory if needed.

        Args:
            output_dir (str): Directory for downloads, extracted data and results
        """
        self.output_dir: Path = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Log directory setup is handled separately by get_log_dir();
        # nothing log-related is created here.

    def download_file(self, url: str, filename: Optional[str] = None, timeout: float = 60.0) -> str:
        """
        Download a file into the output directory, streaming it in chunks.

        Args:
            url (str): Download URL
            filename (str): Save file name. If None, the last component of the
                URL path is used.
            timeout (float): Connect/read timeout in seconds passed to
                ``requests``. It bounds each network wait, not the total
                download time.

        Returns:
            str: path of the downloaded file

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        if filename is None:
            filename = Path(urlparse(url).path).name

        filepath = self.output_dir / filename
        # Stream into a sibling ".part" file and rename only on success, so an
        # interrupted download never leaves a truncated file at ``filepath``.
        partial_path = filepath.with_name(filepath.name + ".part")

        logger.info(f"Downloading {url} to {filepath}")

        try:
            # ``with`` releases the streamed connection even on error. The
            # timeout keeps a stalled server from hanging the script, because
            # requests applies no timeout by default.
            with requests.get(url, stream=True, timeout=timeout) as response:
                response.raise_for_status()

                total_size = int(response.headers.get("content-length", 0))

                with (
                    open(partial_path, "wb") as f,
                    tqdm(
                        desc=filename,
                        total=total_size,
                        unit="iB",
                        unit_scale=True,
                        unit_divisor=1024,
                    ) as progress_bar,
                ):
                    for chunk in response.iter_content(chunk_size=8192):
                        progress_bar.update(f.write(chunk))
        except BaseException:
            partial_path.unlink(missing_ok=True)
            raise

        # Overwrites an existing file, like the previous direct "wb" write did.
        partial_path.replace(filepath)

        logger.info(f"Downloaded: {filepath}")
        return str(filepath)

    def extract_zip(self, zip_path: str) -> str:
        """
        Extract a ZIP archive into ``<output_dir>/<archive stem>``.

        Args:
            zip_path (str): ZIP file path

        Returns:
            str: Path of the directory the archive was extracted into
        """
        extract_dir = self.output_dir / Path(zip_path).stem
        extract_dir.mkdir(exist_ok=True)

        logger.info(f"Extracting {zip_path} to {extract_dir}")

        with zipfile.ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_dir)

        logger.info(f"Extracted to: {extract_dir}")
        return str(extract_dir)

    def load_proteingym_data(self, data_dir: str, reference_csv: Optional[str] = None) -> Dict[str, pd.DataFrame]:
        """
        Load ProteinGym DMS CSV files, one DataFrame per assay.

        CSVs are looked for, in this order, in
        ``<data_dir>/DMS_ProteinGym_substitutions/DMS_ProteinGym_substitutions``,
        ``<data_dir>/DMS_ProteinGym_substitutions`` and ``<data_dir>``. The first
        directory that contains any ``*.csv`` file is used.

        Args:
            data_dir (str): data directory
            reference_csv (str): Reference CSV file (optional). Stored under the
                key ``"reference"`` when it exists and parses.

        Returns:
            dict: Mapping of CSV file stem to DataFrame. Empty if no CSV
            directory was found. Files that fail to parse are skipped with a
            warning.
        """
        logger.info(f"Loading ProteinGym data from {data_dir}")

        data_dir_path: Path = Path(data_dir)
        datasets: Dict[str, pd.DataFrame] = {}

        # Check ProteinGym structure (nested archive layout first)
        possible_dirs = [
            data_dir_path / "DMS_ProteinGym_substitutions" / "DMS_ProteinGym_substitutions",
            data_dir_path / "DMS_ProteinGym_substitutions",
            data_dir_path,
        ]
        dms_dir: Optional[Path] = next(
            (d for d in possible_dirs if d.exists() and any(d.glob("*.csv"))),
            None,
        )

        if dms_dir is None:
            logger.error(f"No ProteinGym CSV files found in {data_dir} or subdirectories")
            return datasets
        logger.info(f"Found ProteinGym data in: {dms_dir}")

        # Path.glob makes no ordering guarantee. Sorting keeps the assay order
        # deterministic, and with it the combined dataset and which duplicate
        # survives deduplication.
        csv_files = sorted(dms_dir.glob("*.csv"))
        logger.info(f"Found {len(csv_files)} CSV files")

        for csv_file in csv_files:
            try:
                df = pd.read_csv(csv_file)
                dataset_name = csv_file.stem
                datasets[dataset_name] = df
                logger.info(f"Loaded {dataset_name}: {len(df)} variants")
            except Exception as e:
                logger.warning(f"Failed to load {csv_file}: {e}")

        # Load reference information (optional)
        if reference_csv and os.path.exists(reference_csv):
            try:
                ref_df = pd.read_csv(reference_csv)
                datasets["reference"] = ref_df
                logger.info(f"Loaded reference data: {len(ref_df)} assays")
            except Exception as e:
                logger.warning(f"Failed to load reference CSV: {e}")

        return datasets

    def preprocess_for_bert(
        self,
        datasets: Dict[str, pd.DataFrame],
        max_variants_per_assay: Optional[int] = 1000,
    ) -> pd.DataFrame:
        """
        Filter, sample and combine all assays into one DataFrame for BERT evaluation.

        For each assay (the ``"reference"`` entry is skipped), the method:

        * drops rows missing ``mutated_sequence`` or ``DMS_score``;
        * drops rows whose score is non-numeric or non-finite;
        * keeps sequences with 10 < length < 2000;
        * samples at most ``max_variants_per_assay`` rows (seed 42);
        * tags each row with ``assay_name``;
        * infers ``target_seq`` when that column is absent.

        Args:
            datasets (dict): Mapping of assay name to DataFrame
            max_variants_per_assay (int | None): Maximum number of variants per
                assay. None keeps every variant.

        Returns:
            pd.DataFrame: Preprocessed data

        Raises:
            ValueError: If no assay has the required columns.
        """
        logger.info("Starting BERT-specific preprocessing")

        required_columns = ["mutated_sequence", "DMS_score"]
        all_variants: List[pd.DataFrame] = []

        for dataset_name, df in datasets.items():
            if dataset_name == "reference":
                continue

            logger.info(f"Processing dataset: {dataset_name}")

            missing_columns = [col for col in required_columns if col not in df.columns]
            if missing_columns:
                logger.warning(f"Missing columns in {dataset_name}: {missing_columns}")
                logger.info(f"Available columns: {list(df.columns)}")
                continue

            # Basic filtering of data
            df_clean = df.dropna(subset=required_columns).copy()

            # Coerce scores to numbers first. np.isfinite raises TypeError on
            # object-dtype columns (e.g. scores read as text), and unparsable
            # entries become NaN, which the finiteness mask then drops.
            scores = pd.to_numeric(df_clean["DMS_score"], errors="coerce")
            df_clean["DMS_score"] = scores
            df_clean = df_clean[np.isfinite(scores)].copy()

            # Check the validity of sequence length (upper bound: BERT limitations)
            seq_len = df_clean["mutated_sequence"].str.len()
            df_clean = df_clean[(seq_len > 10) & (seq_len < 2000)].copy()

            # Sampling per assay
            if max_variants_per_assay is not None and len(df_clean) > max_variants_per_assay:
                df_clean = df_clean.sample(n=max_variants_per_assay, random_state=42)
                logger.info(f"Sampled {max_variants_per_assay} variants from {dataset_name}")

            df_clean["assay_name"] = dataset_name

            # Infer wild type sequence (if necessary)
            if "target_seq" not in df_clean.columns:
                df_clean = self._infer_wildtype_sequences(df_clean)

            all_variants.append(df_clean)
            logger.info(f"Processed {len(df_clean)} variants from {dataset_name}")

        if not all_variants:
            raise ValueError("No valid datasets found")

        combined_df: pd.DataFrame = pd.concat(all_variants, ignore_index=True)
        combined_df = self._final_cleaning(combined_df)

        logger.info(f"Total processed variants: {len(combined_df)}")
        return combined_df

    def _infer_wildtype_sequences(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Add a ``target_seq`` column holding the inferred wild-type sequence.

        When a ``mutant`` column exists, each listed substitution is reverted in
        ``mutated_sequence``. Otherwise ``mutated_sequence`` is copied as is.
        ``df`` is modified in place and also returned.

        Args:
            df (pd.DataFrame): variant data

        Returns:
            pd.DataFrame: data with wild type sequences added
        """
        logger.info("Inferring wildtype sequences")

        def infer_wildtype(mutated_seq: str, mutant_info: Any) -> str:
            """Revert every substitution in ``mutant_info`` within ``mutated_seq``.

            ``mutant_info`` is "WT", missing, or one or more substitutions like
            "A1V" joined by ":" (e.g. "A1V:L2P"), with 1-based positions. If any
            substitution cannot be parsed, or the sequence does not carry the
            mutant residue at that position, the mutated sequence is returned
            unchanged (best effort).
            """
            if not isinstance(mutant_info, str) or mutant_info.strip() in ("", "WT"):
                return mutated_seq

            residues = list(mutated_seq)
            for token in mutant_info.split(":"):
                token = token.strip()
                if len(token) < 3:
                    return mutated_seq
                orig_aa, pos_text, mut_aa = token[0], token[1:-1], token[-1]
                if not pos_text.isdecimal():
                    return mutated_seq
                pos = int(pos_text) - 1  # convert to 0-indexed
                if not 0 <= pos < len(residues) or residues[pos] != mut_aa:
                    return mutated_seq
                residues[pos] = orig_aa
            return "".join(residues)

        if "mutant" in df.columns:
            # A comprehension over the two columns avoids building one Series
            # per row, as DataFrame.apply(axis=1) would.
            df["target_seq"] = [
                infer_wildtype(seq, mutant) for seq, mutant in zip(df["mutated_sequence"], df["mutant"])
            ]
        else:
            # If there is no mutant information, use mutated_sequence as is
            df["target_seq"] = df["mutated_sequence"]
            logger.warning("No mutant information found, using mutated_sequence as target_seq")

        return df

    def _final_cleaning(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Deduplicate variants and keep only canonical amino-acid sequences.

        Args:
            df (pd.DataFrame): Preprocessed data

        Returns:
            pd.DataFrame: Cleaned data
        """
        logger.info("Final data cleaning")

        # Deduplicate within each assay only. Different assays may contain the
        # same mutated sequence, and a global dedup would drop those rows from
        # every assay except the first one encountered.
        dedup_subset = ["assay_name", "mutated_sequence"] if "assay_name" in df.columns else ["mutated_sequence"]
        before_dedup = len(df)
        df = df.drop_duplicates(subset=dedup_subset, keep="first")
        logger.info(f"Removed {before_dedup - len(df)} duplicate sequences")

        valid_aa = frozenset("ACDEFGHIKLMNPQRSTVWY")

        def is_valid_protein_sequence(seq: Any) -> bool:
            """True for strings made only of the 20 canonical residues (case-insensitive)."""
            return isinstance(seq, str) and all(aa in valid_aa for aa in seq.upper())

        df = df[df["mutated_sequence"].apply(is_valid_protein_sequence)].copy()
        logger.info(f"Retained {len(df)} variants with valid amino acid sequences")

        # Log DMS_score statistics
        logger.info(f"DMS_score statistics:\n{df['DMS_score'].describe()}")

        return df

    def save_bert_ready_data(self, df: pd.DataFrame, filename: str = "bert_proteingym_dataset.csv") -> None:
        """
        Save the data as CSV, plus a JSON metadata file and a statistics report.

        The JSON file sits next to the CSV (same name, ``.json`` suffix). It
        holds summary metadata and the first 100 records.

        Args:
            df (pd.DataFrame): Preprocessed data
            filename (str): Save file name
        """
        filepath = self.output_dir / filename

        df.to_csv(filepath, index=False)
        logger.info(f"Saved BERT-ready dataset: {filepath}")

        json_data: Dict[str, Any] = {
            "metadata": {
                "total_variants": len(df),
                "unique_assays": df["assay_name"].nunique() if "assay_name" in df.columns else 1,
                "dms_score_range": [float(df["DMS_score"].min()), float(df["DMS_score"].max())],
                "avg_sequence_length": float(df["mutated_sequence"].str.len().mean()),
                "processing_date": datetime.now().isoformat(),
            },
            # Convert only the rows that are kept, not the whole frame.
            "data": df.head(100).to_dict("records"),
        }

        json_filepath = filepath.with_suffix(".json")
        with open(json_filepath, "w", encoding="utf-8") as f:
            json.dump(json_data, f, indent=2)
        logger.info(f"Saved metadata and sample data: {json_filepath}")

        self._create_statistics_report(df)

    def _create_statistics_report(self, df: pd.DataFrame) -> None:
        """
        Write ``bert_proteingym_statistics.txt`` into the output directory.

        Any existing report is overwritten on every call.

        Args:
            df (pd.DataFrame): Dataset
        """
        report_file = self.output_dir / "bert_proteingym_statistics.txt"

        with open(report_file, "w", encoding="utf-8") as f:
            f.write("BERT ProteinGym Dataset Statistics Report\n")
            f.write("=" * 50 + "\n\n")

            f.write(f"Total variants: {len(df)}\n")
            f.write(f"Unique assays: {df['assay_name'].nunique() if 'assay_name' in df.columns else 'N/A'}\n\n")

            f.write("DMS Score Distribution:\n")
            f.write(f"{df['DMS_score'].describe()}\n\n")

            f.write("Sequence Length Distribution:\n")
            f.write(f"{df['mutated_sequence'].str.len().describe()}\n\n")

            if "assay_name" in df.columns:
                f.write("Top 10 Assays by Variant Count:\n")
                for assay, count in df["assay_name"].value_counts().head(10).items():
                    f.write(f"  {assay}: {count}\n")
                f.write("\n")

            f.write("Sample Sequences (first 5):\n")
            for i, seq in enumerate(df["mutated_sequence"].head(5)):
                f.write(f"  {i + 1}: {seq[:50]}{'...' if len(seq) > 50 else ''}\n")

        logger.info(f"Statistics report saved: {report_file}")

Preprocessing class for the ProteinGym dataset for BERT.

Initialize the processor and create the output directory if it does not exist.

Args

output_dir : str
Output directory

Class variables

var PROTEINGYM_URLS

Methods

def download_file(self, url: str, filename: str | None = None) -> str
Expand source code
def download_file(self, url: str, filename: Optional[str] = None, timeout: float = 60.0) -> str:
    """
    Download a file into the output directory, streaming it in chunks.

    Args:
        url (str): Download URL
        filename (str): Save file name. If None, the last component of the
            URL path is used.
        timeout (float): Connect/read timeout in seconds passed to
            ``requests``. It bounds each network wait, not the total download
            time.

    Returns:
        str: path of the downloaded file

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    if filename is None:
        filename = Path(urlparse(url).path).name

    filepath = self.output_dir / filename
    # Stream into a sibling ".part" file and rename only on success, so an
    # interrupted download never leaves a truncated file at ``filepath``.
    partial_path = filepath.with_name(filepath.name + ".part")

    logger.info(f"Downloading {url} to {filepath}")

    try:
        # ``with`` releases the streamed connection even on error. The timeout
        # keeps a stalled server from hanging the script, because requests
        # applies no timeout by default.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()

            total_size = int(response.headers.get("content-length", 0))

            with (
                open(partial_path, "wb") as f,
                tqdm(
                    desc=filename,
                    total=total_size,
                    unit="iB",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as progress_bar,
            ):
                for chunk in response.iter_content(chunk_size=8192):
                    progress_bar.update(f.write(chunk))
    except BaseException:
        partial_path.unlink(missing_ok=True)
        raise

    # Overwrites an existing file, like the previous direct "wb" write did.
    partial_path.replace(filepath)

    logger.info(f"Downloaded: {filepath}")
    return str(filepath)

Download file

Args

url : str
Download URL
filename : str | None
Save file name (if None, it is derived automatically from the URL path)

Returns

str
path of the downloaded file
def extract_zip(self, zip_path: str) -> str
Expand source code
def extract_zip(self, zip_path: str) -> str:
    """
    Unpack a ZIP archive into the output directory.

    The archive goes into ``<output_dir>/<archive stem>``. That directory is
    created if needed, and its parent must already exist.

    Args:
        zip_path (str): Path of the ZIP file to unpack

    Returns:
        str: Path of the directory the archive was extracted into
    """
    target_dir = self.output_dir.joinpath(Path(zip_path).stem)
    target_dir.mkdir(exist_ok=True)
    logger.info(f"Extracting {zip_path} to {target_dir}")

    with zipfile.ZipFile(zip_path, "r") as archive:
        archive.extractall(target_dir)

    logger.info(f"Extracted to: {target_dir}")
    return str(target_dir)

Extract the ZIP file

Args

zip_path : str
ZIP file path

Returns

str
Path of the extracted directory
def load_proteingym_data(self, data_dir: str, reference_csv: str | None = None) -> Dict[str, pandas.core.frame.DataFrame]
Expand source code
def load_proteingym_data(self, data_dir: str, reference_csv: Optional[str] = None) -> Dict[str, pd.DataFrame]:
    """
    Load ProteinGym DMS CSV files, one DataFrame per assay.

    CSVs are looked for, in this order, in
    ``<data_dir>/DMS_ProteinGym_substitutions/DMS_ProteinGym_substitutions``,
    ``<data_dir>/DMS_ProteinGym_substitutions`` and ``<data_dir>``. The first
    directory that contains any ``*.csv`` file is used.

    Args:
        data_dir (str): data directory
        reference_csv (str): Reference CSV file (optional). Stored under the key
            ``"reference"`` when it exists and parses.

    Returns:
        dict: Mapping of CSV file stem to DataFrame. Empty if no CSV directory
        was found. Files that fail to parse are skipped with a warning.
    """
    logger.info(f"Loading ProteinGym data from {data_dir}")

    data_dir_path: Path = Path(data_dir)
    datasets: Dict[str, pd.DataFrame] = {}

    # Check ProteinGym structure (nested archive layout first)
    possible_dirs = [
        data_dir_path / "DMS_ProteinGym_substitutions" / "DMS_ProteinGym_substitutions",
        data_dir_path / "DMS_ProteinGym_substitutions",
        data_dir_path,
    ]
    dms_dir: Optional[Path] = next(
        (d for d in possible_dirs if d.exists() and any(d.glob("*.csv"))),
        None,
    )

    if dms_dir is None:
        logger.error(f"No ProteinGym CSV files found in {data_dir} or subdirectories")
        return datasets
    logger.info(f"Found ProteinGym data in: {dms_dir}")

    # Path.glob makes no ordering guarantee. Sorting keeps the assay order
    # deterministic, and with it the combined dataset and which duplicate
    # survives deduplication.
    csv_files = sorted(dms_dir.glob("*.csv"))
    logger.info(f"Found {len(csv_files)} CSV files")

    for csv_file in csv_files:
        try:
            df = pd.read_csv(csv_file)
            dataset_name = csv_file.stem
            datasets[dataset_name] = df
            logger.info(f"Loaded {dataset_name}: {len(df)} variants")
        except Exception as e:
            logger.warning(f"Failed to load {csv_file}: {e}")

    # Load reference information (optional)
    if reference_csv and os.path.exists(reference_csv):
        try:
            ref_df = pd.read_csv(reference_csv)
            datasets["reference"] = ref_df
            logger.info(f"Loaded reference data: {len(ref_df)} assays")
        except Exception as e:
            logger.warning(f"Failed to load reference CSV: {e}")

    return datasets

Load ProteinGym data

Args

data_dir : str
data directory
reference_csv : str
Reference CSV file (optional)

Returns

dict
dataset dictionary
def preprocess_for_bert(self,
datasets: Dict[str, pandas.core.frame.DataFrame],
max_variants_per_assay: int = 1000) -> pandas.core.frame.DataFrame
Expand source code
def preprocess_for_bert(
    self,
    datasets: Dict[str, pd.DataFrame],
    max_variants_per_assay: Optional[int] = 1000,
) -> pd.DataFrame:
    """
    Filter, sample and combine all assays into one DataFrame for BERT evaluation.

    For each assay (the ``"reference"`` entry is skipped), the method:

    * drops rows missing ``mutated_sequence`` or ``DMS_score``;
    * drops rows whose score is non-numeric or non-finite;
    * keeps sequences with 10 < length < 2000;
    * samples at most ``max_variants_per_assay`` rows (seed 42);
    * tags each row with ``assay_name``;
    * infers ``target_seq`` when that column is absent.

    Args:
        datasets (dict): Mapping of assay name to DataFrame
        max_variants_per_assay (int | None): Maximum number of variants per
            assay. None keeps every variant.

    Returns:
        pd.DataFrame: Preprocessed data

    Raises:
        ValueError: If no assay has the required columns.
    """
    logger.info("Starting BERT-specific preprocessing")

    required_columns = ["mutated_sequence", "DMS_score"]
    all_variants: List[pd.DataFrame] = []

    for dataset_name, df in datasets.items():
        if dataset_name == "reference":
            continue

        logger.info(f"Processing dataset: {dataset_name}")

        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            logger.warning(f"Missing columns in {dataset_name}: {missing_columns}")
            logger.info(f"Available columns: {list(df.columns)}")
            continue

        # Basic filtering of data
        df_clean = df.dropna(subset=required_columns).copy()

        # Coerce scores to numbers first. np.isfinite raises TypeError on
        # object-dtype columns (e.g. scores read as text), and unparsable
        # entries become NaN, which the finiteness mask then drops.
        scores = pd.to_numeric(df_clean["DMS_score"], errors="coerce")
        df_clean["DMS_score"] = scores
        df_clean = df_clean[np.isfinite(scores)].copy()

        # Check the validity of sequence length (upper bound: BERT limitations)
        seq_len = df_clean["mutated_sequence"].str.len()
        df_clean = df_clean[(seq_len > 10) & (seq_len < 2000)].copy()

        # Sampling per assay
        if max_variants_per_assay is not None and len(df_clean) > max_variants_per_assay:
            df_clean = df_clean.sample(n=max_variants_per_assay, random_state=42)
            logger.info(f"Sampled {max_variants_per_assay} variants from {dataset_name}")

        df_clean["assay_name"] = dataset_name

        # Infer wild type sequence (if necessary)
        if "target_seq" not in df_clean.columns:
            df_clean = self._infer_wildtype_sequences(df_clean)

        all_variants.append(df_clean)
        logger.info(f"Processed {len(df_clean)} variants from {dataset_name}")

    if not all_variants:
        raise ValueError("No valid datasets found")

    combined_df: pd.DataFrame = pd.concat(all_variants, ignore_index=True)
    combined_df = self._final_cleaning(combined_df)

    logger.info(f"Total processed variants: {len(combined_df)}")
    return combined_df

Preprocess data for BERT evaluation

Args

datasets : dict
dataset dictionary
max_variants_per_assay : int
Maximum number of variants per assay

Returns

pd.DataFrame
Preprocessed data
def save_bert_ready_data(self,
df: pandas.core.frame.DataFrame,
filename: str = 'bert_proteingym_dataset.csv') -> None
Expand source code
def save_bert_ready_data(self, df: pd.DataFrame, filename: str = "bert_proteingym_dataset.csv") -> None:
    """
    Save the data as CSV, plus a JSON metadata file and a statistics report.

    The JSON file sits next to the CSV (same name, ``.json`` suffix). It holds
    summary metadata and the first 100 records.

    Args:
        df (pd.DataFrame): Preprocessed data
        filename (str): Save file name
    """
    filepath = self.output_dir / filename

    df.to_csv(filepath, index=False)
    logger.info(f"Saved BERT-ready dataset: {filepath}")

    json_data: Dict[str, Any] = {
        "metadata": {
            "total_variants": len(df),
            "unique_assays": df["assay_name"].nunique() if "assay_name" in df.columns else 1,
            "dms_score_range": [float(df["DMS_score"].min()), float(df["DMS_score"].max())],
            "avg_sequence_length": float(df["mutated_sequence"].str.len().mean()),
            "processing_date": datetime.now().isoformat(),
        },
        # Convert only the rows that are kept, not the whole frame.
        "data": df.head(100).to_dict("records"),
    }

    json_filepath = filepath.with_suffix(".json")
    with open(json_filepath, "w", encoding="utf-8") as f:
        json.dump(json_data, f, indent=2)
    logger.info(f"Saved metadata and sample data: {json_filepath}")

    self._create_statistics_report(df)

Save the data in BERT-ready form: a CSV file, a JSON file with metadata and sample records, and a statistics report.

Args

df : pd.DataFrame
Preprocessed data
filename : str
Save file name