Module molcrawl.dnabert2.main

DNABERT-2 Training Script for Genome Sequence Data

DNABERT-2 is a BERT-based model specialized for DNA sequence analysis. Main improvements: - BPE (Byte Pair Encoding) tokenization (no k-mer tokenization required) - A more efficient attention mechanism - An architecture that accounts for the unique characteristics of DNA

Reference: DNABERT-2: Efficient Foundation Model and Benchmark for Multi-Species Genome (https://github.com/MAGICS-LAB/DNABERT_2)

Classes

class DNADatasetLoader (data_dir, split='train', test_size=0.1)
Expand source code
class DNADatasetLoader:
    """
    Loader for DNA sequence datasets.

    Reads every ``*.arrow`` file (Arrow IPC *stream* format) in ``data_dir``,
    concatenates them into a single HuggingFace ``Dataset`` and selects the
    requested split for DNABERT-2 training.

    Args:
        data_dir: Directory containing the ``*.arrow`` files.
        split: ``"train"`` or one of ``"valid"``, ``"val"``, ``"test"``.
            When ``test_size <= 0`` no split is made and the full dataset is
            returned regardless of this value.
        test_size: Fraction held out for validation when splitting
            (passed to ``Dataset.train_test_split`` with ``seed=42``).

    Raises:
        FileNotFoundError: If no Arrow data could be loaded from ``data_dir``.
        ValueError: If ``split`` is not recognised and a split must be made.
    """

    # Split names that all map onto the held-out partition.
    _EVAL_SPLITS = ("valid", "val", "test")

    def __init__(self, data_dir, split="train", test_size=0.1):
        self.data_dir = data_dir
        self.split = split
        self.test_size = test_size

        print(f"πŸ“‚ Loading DNA dataset from {data_dir}")

        self.dataset = self._load_arrow_dataset(data_dir)
        self.data = self._select_split(self.dataset, split, test_size)

        print(f"βœ… Loaded {len(self.data)} samples for {split}")

        self._print_sample(self.data)

    @staticmethod
    def _find_arrow_files(data_dir):
        """Return the ``*.arrow`` files in ``data_dir`` in deterministic (sorted) order."""
        # Path.glob yields entries in filesystem order, which differs between
        # machines; sorting keeps concatenation order, and therefore the
        # seeded train/valid split, reproducible.
        return sorted(Path(data_dir).glob("*.arrow"))

    @staticmethod
    def _read_arrow_stream(arrow_file):
        """Read one Arrow IPC stream file; return None (after logging) if it fails."""
        try:
            print(f"πŸ“– Reading {arrow_file.name}...")
            with pa.memory_map(str(arrow_file), "r") as mmap:
                with pa.ipc.open_stream(mmap) as reader:
                    table = reader.read_all()
                    print(f"πŸ”’ Read table: {len(table)} rows")
                    return table
        except Exception as e:
            # Best-effort: an unreadable shard is reported and skipped.
            print(f"❌ Failed to read {arrow_file.name}: {e}")
            return None

    @classmethod
    def _load_arrow_dataset(cls, data_dir):
        """Load and concatenate all readable Arrow files into a HuggingFace Dataset.

        Raises:
            FileNotFoundError: wrapping any failure (no files, no readable
                files, conversion error) so callers see a single error type.
        """
        try:
            arrow_files = cls._find_arrow_files(data_dir)
            if not arrow_files:
                raise FileNotFoundError(f"No .arrow files found in {data_dir}")

            print(f"πŸ“ Found {len(arrow_files)} arrow files")
            tables = [
                table
                for table in map(cls._read_arrow_stream, arrow_files)
                if table is not None
            ]
            if not tables:
                raise ValueError("No arrow files could be read successfully")

            combined_table = pa.concat_tables(tables)
            print(f"πŸ“Š Combined {len(tables)} tables: {len(combined_table)} total rows")
            df = combined_table.to_pandas()

            # Convert numpy arrays to lists for HuggingFace compatibility
            if "token" in df.columns:
                df["token"] = df["token"].apply(lambda x: x.tolist() if hasattr(x, "tolist") else x)

            dataset = Dataset.from_pandas(df)
            print("βœ… Created HuggingFace Dataset")
            print(f"πŸ” Dataset columns: {dataset.column_names}")
            return dataset

        except Exception as e:
            print(f"❌ Arrow loading failed: {e}")
            raise FileNotFoundError(f"Could not load data from {data_dir}") from e

    @classmethod
    def _select_split(cls, dataset, split, test_size):
        """Return the partition of ``dataset`` that corresponds to ``split``.

        A dict-like dataset with a ``"train"`` key is indexed directly.
        Otherwise the dataset is split with a fixed seed when
        ``test_size > 0``, or returned whole when it is not.

        Raises:
            ValueError: if ``split`` is unrecognised where a choice is needed.
        """

        def _unknown():
            return ValueError(
                f"Unknown split {split!r}; expected 'train', 'valid', 'val' or 'test'"
            )

        if isinstance(dataset, dict) and "train" in dataset:
            if split == "train":
                return dataset["train"]
            if split in cls._EVAL_SPLITS:
                return dataset.get("valid", dataset.get("test", dataset["train"]))
            raise _unknown()

        if test_size > 0:
            # Fixed seed so separate train/valid loaders agree on the partition.
            split_dataset = dataset.train_test_split(test_size=test_size, seed=42)
            if split == "train":
                return split_dataset["train"]
            if split in cls._EVAL_SPLITS:
                return split_dataset["test"]
            raise _unknown()

        return dataset

    @staticmethod
    def _print_sample(data):
        """Print the keys and value types of the first example, if any."""
        if len(data) > 0:
            sample = data[0]
            print("Sample keys:", list(sample.keys()))
            for key, value in sample.items():
                if isinstance(value, list):
                    print(f"  {key}: {type(value)} of length {len(value)}")
                else:
                    print(f"  {key}: {type(value)} = {value}")

    def get_dataset(self):
        """Return the HuggingFace Dataset object"""
        return self.data

Loader for DNA sequence datasets

Loads the existing genome_sequence dataset and performs preprocessing for DNABERT-2.

Methods

def get_dataset(self)
Expand source code
def get_dataset(self):
    """Hand back the dataset chosen for this loader's split (the ``data`` attribute)."""
    selected = self.data
    return selected

Return the HuggingFace Dataset object