Module molcrawl.protein_sequence.dataset.prepare_gpt2
Functions
def concatenate_texts(examples: Dict[str, List[List[int]]], eos_token_id: int) -> Dict[str, List[int]]
def concatenate_texts(examples: Dict[str, List[List[int]]], eos_token_id: int) -> Dict[str, List[int]]:
    """Flatten a batch of tokenized sequences into one token stream, appending eos_token_id after each sequence."""
    concatenated_ids: List[int] = []
    for input_ids in examples["input_ids"]:
        concatenated_ids.extend(input_ids + [eos_token_id])
    return {"input_ids": concatenated_ids}
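A minimal usage sketch for concatenate_texts; the batch of token ids and the EOS id of 2 below are illustrative values, not taken from the real EsmSequenceTokenizer.

from molcrawl.protein_sequence.dataset.prepare_gpt2 import concatenate_texts

# Hypothetical batch of two already-tokenized sequences
batch = {"input_ids": [[5, 6, 7], [8, 9]]}

# Each sequence is followed by the (assumed) EOS id 2, then everything is flattened
flat = concatenate_texts(batch, eos_token_id=2)
print(flat["input_ids"])  # [5, 6, 7, 2, 8, 9, 2]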
def create_chunks(examples: Dict[str, List[int]], context_length: int) -> Dict[str, List[List[int]]]
def create_chunks(examples: Dict[str, List[int]], context_length: int) -> Dict[str, List[List[int]]]:
    """Split one concatenated token stream into fixed-length chunks, dropping the trailing remainder."""
    concatenated_ids: List[int] = examples["input_ids"]

    # Calculate the total number of chunks
    total_length = len(concatenated_ids)
    num_chunks = total_length // context_length

    # Truncate the concatenated_ids to a multiple of context_length
    total_length = num_chunks * context_length
    concatenated_ids = concatenated_ids[:total_length]

    # Split into chunks
    input_ids = [concatenated_ids[i : i + context_length] for i in range(0, total_length, context_length)]
    return {"input_ids": input_ids}
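A small worked example of the chunking behaviour, using made-up token ids: with ten tokens and context_length=4, the last two tokens are dropped because they do not fill a full chunk.

from molcrawl.protein_sequence.dataset.prepare_gpt2 import create_chunks

# Illustrative concatenated stream of ten token ids
stream = {"input_ids": list(range(10))}

chunks = create_chunks(stream, context_length=4)
print(chunks["input_ids"])  # [[0, 1, 2, 3], [4, 5, 6, 7]] -- the trailing [8, 9] is discarded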
def tokenize_batch_dataset(path_output: Path, context_length: int, number_sample: int) -> None
def tokenize_batch_dataset(path_output: Path, context_length: int, number_sample: int) -> None:
    """Tokenize raw sequence files, concatenate and chunk them, and save a training-ready HF DatasetDict."""
    from datasets import DatasetDict, load_dataset

    from molcrawl.protein_sequence.dataset.tokenizer import EsmSequenceTokenizer

    raw_dir: Path = Path(path_output) / "raw_files"
    raw_files: List[Path] = sorted(raw_dir.glob("*.raw")) + sorted(raw_dir.glob("*.txt"))
    if not raw_files:
        raise FileNotFoundError(
            f"No raw data files found in {raw_dir}. "
            "Expected *.raw or *.txt files. Check symlinks or rerun the preparation step."
        )

    # Pass the file list explicitly so load_dataset does not infer the format from file extensions
    data = (
        load_dataset(
            "text",
            data_files={"train": [str(p) for p in raw_files]},
            split="train",
        )
        .shuffle()
        .select(range(number_sample))
    )
    raw_datasets = data.train_test_split(test_size=0.2)
    valid_test_split = raw_datasets["test"].train_test_split(test_size=0.5)
    raw_datasets = DatasetDict(
        {"train": raw_datasets["train"], "valid": valid_test_split["train"], "test": valid_test_split["test"]}
    )

    tokenizer: EsmSequenceTokenizer = EsmSequenceTokenizer()
    tokenized_datasets = raw_datasets.map(
        partial(tokenize_function, tokenizer=tokenizer),
        batched=True,
        remove_columns=["text"],
    )
    concatenated_dataset = tokenized_datasets.map(
        partial(concatenate_texts, eos_token_id=tokenizer.eos_token_id),
        batched=True,
        batch_size=-1,
    )
    chunked_dataset = concatenated_dataset.map(
        partial(create_chunks, context_length=context_length),
        batched=True,
        batch_size=-1,
    )

    path_dataset: str = str(path_output / "training_ready_hf_dataset")
    print(f"Saving dataset to: {path_dataset}. Match this path to the train_gpt2_config.py->dataset_dir parameter.")
    chunked_dataset.save_to_disk(path_dataset)
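An end-to-end call sketch; the output directory, context length, and sample count are illustrative, and it assumes <path_output>/raw_files/ already contains *.raw or *.txt sequence files.

from pathlib import Path

from molcrawl.protein_sequence.dataset.prepare_gpt2 import tokenize_batch_dataset

# Writes the train/valid/test DatasetDict to data/protein_gpt2/training_ready_hf_dataset
tokenize_batch_dataset(
    path_output=Path("data/protein_gpt2"),
    context_length=1024,
    number_sample=100_000,
)

The saved dataset can be reloaded later with datasets.load_from_disk on the printed path.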
def tokenize_function(examples: Dict[str, List[str]], tokenizer: EsmSequenceTokenizer) -> Dict[str, List[List[int]]]
def tokenize_function(examples: Dict[str, List[str]], tokenizer: EsmSequenceTokenizer) -> Dict[str, List[List[int]]]:
    """Tokenize a batch of raw sequence strings without adding special tokens (EOS is appended later)."""
    return {
        "input_ids": tokenizer(
            examples["text"],
            truncation=False,
            add_special_tokens=False,  # We'll add special tokens manually
        )["input_ids"]
    }
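A sketch of how tokenize_function is typically applied through datasets.map, mirroring the call inside tokenize_batch_dataset; the two short protein sequences are made up.

from functools import partial

from datasets import Dataset

from molcrawl.protein_sequence.dataset.prepare_gpt2 import tokenize_function
from molcrawl.protein_sequence.dataset.tokenizer import EsmSequenceTokenizer

tokenizer = EsmSequenceTokenizer()

# Two made-up protein sequences
ds = Dataset.from_dict({"text": ["MKTAYIAKQR", "GAVLILVV"]})

tokenized = ds.map(
    partial(tokenize_function, tokenizer=tokenizer),
    batched=True,
    remove_columns=["text"],
)
print(tokenized[0]["input_ids"])  # token ids only, no special tokens added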