Module molcrawl.rna.dataset.celltype.prepare_celltype
Cell Type Annotation Dataset Preparation Script for GPT-2 / BERT Fine-tuning
Downloads the Geneformer cell type annotation example dataset from the Genecorpus-30M repository (ctheodoris/Genecorpus-30M on HuggingFace) and prepares it for language-model fine-tuning in the RNA tokenizer's encoding.
The dataset contains human single-cell transcriptomes already pre-tokenized as rank-value encodings (same format as the CellxGene pretraining corpus). No re-tokenization is required — only splitting and chunking.
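For intuition, a rank-value encoding orders each cell's genes by expression and records the genes' token IDs in that order, so a transcriptome becomes a plain integer sequence. A minimal illustration (the gene names, expression values, and vocabulary IDs below are invented):

    # Illustrative only: the real encoding is produced by Geneformer's
    # TranscriptomeTokenizer over the full gene vocabulary.
    expression = {"GENE_A": 5.2, "GENE_B": 9.1, "GENE_C": 0.3}  # fake values
    token_id = {"GENE_A": 17, "GENE_B": 42, "GENE_C": 8}        # fake vocab

    # Rank genes by descending expression, then map each gene to its token ID.
    ranked = sorted(expression, key=expression.get, reverse=True)
    input_ids = [token_id[gene] for gene in ranked]
    print(input_ids)  # [42, 17, 8]: highest-expressed gene first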
Pipeline
- Download cell_type_train_data.dataset from ctheodoris/Genecorpus-30M (example_input_files/cell_classification/cell_type_annotation/) via the HuggingFace Hub.
- Extract input_ids (gene token IDs, already tokenized by TranscriptomeTokenizer).
- Shuffle and apply an 80 / 10 / 10 train / valid / test split.
- Concatenate all token sequences into one long stream (with EOS between sequences), then chunk into fixed-length blocks of context_length (see the sketch after this list).
- Save as a HuggingFace DatasetDict to RNA_CELLTYPE_DIR/training_ready_hf_dataset/.
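The concatenation and chunking in the last two steps are implemented by two private batched map() helpers, _concatenate_texts and _create_chunks, whose source is not shown on this page. A minimal sketch of batch functions with the described behavior (the actual implementations in the module may differ, e.g. in how a trailing partial block is handled):

    from typing import Dict, List

    def _concatenate_texts(batch: Dict[str, List[List[int]]],
                           eos_token_id: int) -> Dict[str, List[List[int]]]:
        # Join every sequence in the batch into one stream, inserting an EOS
        # token between consecutive cells. Returns a single row.
        stream: List[int] = []
        for ids in batch["input_ids"]:
            stream.extend(ids)
            stream.append(eos_token_id)
        return {"input_ids": [stream]}

    def _create_chunks(batch: Dict[str, List[List[int]]],
                       context_length: int) -> Dict[str, List[List[int]]]:
        # Split the concatenated stream into fixed-length blocks; here the
        # trailing remainder shorter than context_length is dropped (an
        # assumption, not confirmed by the documented source).
        stream = batch["input_ids"][0]
        n_blocks = len(stream) // context_length
        return {
            "input_ids": [
                stream[i * context_length:(i + 1) * context_length]
                for i in range(n_blocks)
            ]
        }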
Usage (standalone):

    LEARNING_SOURCE_DIR=learning_source_20260311 \
    python -m molcrawl.rna.dataset.celltype.prepare_celltype
The output is saved to:

    $LEARNING_SOURCE_DIR/rna/celltype/training_ready_hf_dataset/
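The saved dataset can be inspected like any dataset written with save_to_disk; for example (path taken from the usage above):

    from datasets import load_from_disk

    ds = load_from_disk(
        "learning_source_20260311/rna/celltype/training_ready_hf_dataset"
    )
    print(ds)  # DatasetDict with train / valid / test splits
    print(len(ds["train"][0]["input_ids"]))  # 1024 with the default context_length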
Functions
def download_celltype(save_path: str | pathlib.Path) -> pathlib.Path
    def download_celltype(save_path: Union[str, Path]) -> Path:
        """
        Download the Geneformer cell type annotation dataset from HuggingFace Hub.

        The Arrow dataset directory is placed at
        ``save_path/cell_type_train_data.dataset/``. If it already exists and is
        non-empty the download is skipped.

        Args:
            save_path: Local directory into which the dataset is downloaded.

        Returns:
            Path to the Arrow dataset directory.
        """
        from huggingface_hub import snapshot_download

        save_path = Path(save_path)
        save_path.mkdir(parents=True, exist_ok=True)

        dataset_dir = save_path / _DOWNLOADED_DIRNAME
        if dataset_dir.exists() and any(dataset_dir.iterdir()):
            logger.info("Cell type dataset already downloaded: %s", dataset_dir)
            return dataset_dir

        logger.info(
            "Downloading cell type annotation dataset from %s (this may take a while) ...",
            _HF_REPO_ID,
        )
        cache_dir = save_path / "_hf_cache"
        snapshot_download(
            repo_id=_HF_REPO_ID,
            repo_type="dataset",
            allow_patterns=[f"{_HF_DATASET_SUBPATH}/**"],
            local_dir=str(cache_dir),
        )

        # Move the nested downloaded directory to the expected flat location.
        downloaded_nested = cache_dir / Path(_HF_DATASET_SUBPATH)
        if downloaded_nested.exists():
            shutil.copytree(str(downloaded_nested), str(dataset_dir))
            shutil.rmtree(str(cache_dir))
            logger.info("Downloaded cell type dataset to %s", dataset_dir)
        else:
            raise FileNotFoundError(
                f"Expected dataset not found after download: {downloaded_nested}\n"
                "The repository layout may have changed; check "
                f"https://huggingface.co/datasets/{_HF_REPO_ID}"
            )

        return dataset_dir

Download the Geneformer cell type annotation dataset from HuggingFace Hub.
The Arrow dataset directory is placed at save_path/cell_type_train_data.dataset/. If it already exists and is non-empty the download is skipped.

Args

save_path: Local directory into which the dataset is downloaded.
Returns
Path to the Arrow dataset directory.
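For example (the target directory below is an arbitrary placeholder):

    from pathlib import Path

    from molcrawl.rna.dataset.celltype.prepare_celltype import download_celltype

    arrow_dir = download_celltype(Path("data/rna/celltype"))
    print(arrow_dir)  # .../data/rna/celltype/cell_type_train_data.dataset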
def prepare_celltype(source_dir: str | pathlib.Path,
output_dir: str | pathlib.Path,
*,
context_length: int = 1024,
train_ratio: float = 0.8,
seed: int = 42,
num_proc: int = 4,
eos_token_id: int = 0) -> str
    def prepare_celltype(
        source_dir: Union[str, Path],
        output_dir: Union[str, Path],
        *,
        context_length: int = 1024,
        train_ratio: float = 0.8,
        seed: int = 42,
        num_proc: int = 4,
        eos_token_id: int = 0,
    ) -> str:
        """
        Load the pre-tokenized Geneformer cell type dataset from *source_dir*,
        build a language-model training dataset, and save it to *output_dir*.

        The dataset is already tokenized (gene rank-value encodings produced by
        TranscriptomeTokenizer), so only splitting, concatenation, and chunking
        are performed.

        Args:
            source_dir: Directory containing the downloaded
                ``cell_type_train_data.dataset/`` Arrow directory.
            output_dir: Parent directory; the dataset is written under
                ``output_dir/training_ready_hf_dataset/``.
            context_length: Token block length (default 1024).
            train_ratio: Fraction used for training (rest split 50/50 between
                validation and test).
            seed: Random seed for reproducible shuffling.
            num_proc: Parallel workers for HuggingFace Dataset.map() operations.
            eos_token_id: Token ID appended between sequences (default 0,
                consistent with RNA BERT config ``eos_token = 0``).

        Returns:
            Absolute path string of the saved ``training_ready_hf_dataset/``
            directory.
        """
        from datasets import load_from_disk

        source_dir = Path(source_dir)
        output_dir = Path(output_dir)

        dataset_arrow = source_dir / _DOWNLOADED_DIRNAME
        if not dataset_arrow.exists():
            raise FileNotFoundError(
                f"Cell type dataset not found: {dataset_arrow}\n"
                "Run the download step first:\n"
                " python -m molcrawl.preparation.preparation_script_rna "
                "assets/configs/rna.yaml --datasets celltype"
            )

        # ------------------------------------------------------------------
        # 1. Load pre-tokenized HuggingFace Arrow dataset
        # ------------------------------------------------------------------
        logger.info("Loading cell type dataset from %s", dataset_arrow)
        raw_dataset = load_from_disk(str(dataset_arrow))

        # Unwrap DatasetDict if needed
        if hasattr(raw_dataset, "keys"):
            split_name = next(iter(raw_dataset.keys()))
            raw_dataset = raw_dataset[split_name]

        total = len(raw_dataset)
        logger.info("Loaded %d single-cell transcriptomes", total)

        # ------------------------------------------------------------------
        # 2. Extract input_ids (gene token IDs; labels discarded for LM
        #    fine-tuning)
        # ------------------------------------------------------------------
        all_ids: List[List[int]] = list(raw_dataset["input_ids"])

        # ------------------------------------------------------------------
        # 3. Random 80 / 10 / 10 split
        # ------------------------------------------------------------------
        rng = np.random.default_rng(seed)
        idx = rng.permutation(total)
        n_train = int(total * train_ratio)
        n_val = int(total * (1 - train_ratio) / 2)

        train_ids = [all_ids[i] for i in idx[:n_train]]
        val_ids = [all_ids[i] for i in idx[n_train : n_train + n_val]]
        test_ids = [all_ids[i] for i in idx[n_train + n_val :]]
        logger.info(
            "Split — train: %d, valid: %d, test: %d",
            len(train_ids), len(val_ids), len(test_ids),
        )

        raw_split = DatasetDict(
            {
                "train": Dataset.from_dict({"input_ids": train_ids}),
                "valid": Dataset.from_dict({"input_ids": val_ids}),
                "test": Dataset.from_dict({"input_ids": test_ids}),
            }
        )

        # ------------------------------------------------------------------
        # 4. Concatenate all sequences + chunk into fixed-length blocks
        # ------------------------------------------------------------------
        _concat_fn = partial(_concatenate_texts, eos_token_id=eos_token_id)
        _chunk_fn = partial(_create_chunks, context_length=context_length)

        output_dir.mkdir(parents=True, exist_ok=True)
        dataset_output = output_dir / "training_ready_hf_dataset"

        split_datasets = {}
        for sname, split_ds in raw_split.items():
            logger.info("Processing split '%s' (%d cells)...", sname, len(split_ds))
            flat = split_ds.map(
                _concat_fn,
                batched=True,
                batch_size=len(split_ds),
                remove_columns=split_ds.column_names,
                num_proc=1,  # full-batch concat requires a single worker
            )
            chunked = flat.map(
                _chunk_fn,
                batched=True,
                batch_size=1,
                remove_columns=flat.column_names,
                num_proc=1,
            )
            logger.info(" → %d chunks of length %d", len(chunked), context_length)
            split_datasets[sname] = chunked

        final_dataset = DatasetDict(split_datasets)
        logger.info("Saving chunked dataset to %s", dataset_output)
        final_dataset.save_to_disk(str(dataset_output))
        logger.info(
            "Done. Saved %d / %d / %d chunks (train / valid / test).",
            len(split_datasets["train"]),
            len(split_datasets["valid"]),
            len(split_datasets["test"]),
        )

        return str(dataset_output)

Load the pre-tokenized Geneformer cell type dataset from source_dir, build a language-model training dataset, and save it to output_dir.
The dataset is already tokenized (gene rank-value encodings produced by TranscriptomeTokenizer), so only splitting, concatenation, and chunking are performed.
Args
source_dir: Directory containing the downloaded cell_type_train_data.dataset/ Arrow directory.
output_dir: Parent directory; the dataset is written under output_dir/training_ready_hf_dataset/.
context_length: Token block length (default 1024).
train_ratio: Fraction used for training (rest split 50/50 between validation and test).
seed: Random seed for reproducible shuffling.
num_proc: Parallel workers for HuggingFace Dataset.map() operations.
eos_token_id: Token ID appended between sequences (default 0, consistent with RNA BERT config eos_token = 0).
Returns
Absolute path string of the saved training_ready_hf_dataset/ directory.
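A typical programmatic run chains the two functions; the directory below is a placeholder. Note that prepare_celltype expects the parent directory containing cell_type_train_data.dataset/, i.e. the same path passed to download_celltype:

    from molcrawl.rna.dataset.celltype.prepare_celltype import (
        download_celltype,
        prepare_celltype,
    )

    download_celltype("data/rna/celltype")
    out_path = prepare_celltype(
        source_dir="data/rna/celltype",  # contains cell_type_train_data.dataset/
        output_dir="data/rna/celltype",  # gains training_ready_hf_dataset/
        context_length=1024,
        seed=42,
    )
    print(out_path)

With the default train_ratio of 0.8, the split sizes follow from the source: n_train = int(0.8 * total) and n_val = int(0.1 * total), with the test split taking the remainder, so any rounding leftovers land in test.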