def tokenize_batch_dataset(parquet_path, context_length, number_sample):
    from functools import partial
    from pathlib import Path

    import numpy as np
    from datasets import DatasetDict

    from molcrawl.molecule_nat_lang.utils.tokenizer import MoleculeNatLangTokenizer

    # read_dataset is a module-level helper expected to return a mapping
    # of split name -> Dataset.
    tokenize_dataset = DatasetDict(read_dataset(parquet_path))

    # Normalize the validation split name to "valid".
    if "validation" in tokenize_dataset and "valid" not in tokenize_dataset:
        tokenize_dataset["valid"] = tokenize_dataset["validation"]
        del tokenize_dataset["validation"]
    elif "valid" not in tokenize_dataset and "validation" not in tokenize_dataset:
        raise KeyError("Neither 'valid' nor 'validation' split found in dataset")

    # Subsample the splits in an 80/10/10 ratio of number_sample.
    # np.random.choice with replace=False raises if the requested count
    # exceeds the split size.
    tokenize_dataset["train"] = tokenize_dataset["train"].select(
        np.random.choice(len(tokenize_dataset["train"]), int(number_sample * 0.8), replace=False)
    )
    tokenize_dataset["valid"] = tokenize_dataset["valid"].select(
        np.random.choice(len(tokenize_dataset["valid"]), int(number_sample * 0.1), replace=False)
    )
    tokenize_dataset["test"] = tokenize_dataset["test"].select(
        np.random.choice(len(tokenize_dataset["test"]), int(number_sample * 0.1), replace=False)
    )

    tokenizer = MoleculeNatLangTokenizer()

    # Concatenate every example of a split into one token stream, separated
    # by the EOS token; batch_size=-1 maps over the whole split as a single
    # batch, and the original columns are dropped.
    concatenated_dataset = tokenize_dataset.map(
        partial(concatenate_texts, eos_token_id=tokenizer.tokenizer.eos_token_id),
        batched=True,
        batch_size=-1,
        remove_columns=tokenize_dataset["train"].column_names,
    )

    # Re-slice the concatenated stream into fixed-length chunks of
    # context_length tokens.
    chunked_dataset = concatenated_dataset.map(
        partial(create_chunks, context_length=context_length),
        batched=True,
        batch_size=-1,
    )

    path_dataset = str(Path(parquet_path).parent / "training_ready_hf_dataset")
    print(f"Saving dataset to: {path_dataset}. Match this path to the train_gpt2_config.py->dataset_dir parameter.")
    chunked_dataset.save_to_disk(path_dataset)
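
The helpers concatenate_texts and create_chunks are referenced from the surrounding module but not shown here. The following is a minimal sketch of what they plausibly do, assuming the batched datasets.map contract (dict of lists in, dict of lists out) and that examples carry a pre-tokenized "input_ids" column; the column name and the drop-the-remainder chunking policy are assumptions, not the module's confirmed behavior.

def concatenate_texts(batch, eos_token_id):
    # Join every example's token ids into one stream, separated by EOS.
    # (Assumed behavior: the real helper may tokenize raw text instead.)
    stream = []
    for ids in batch["input_ids"]:
        stream.extend(ids)
        stream.append(eos_token_id)
    return {"input_ids": [stream]}

def create_chunks(batch, context_length):
    # Slice each stream into fixed-length windows of context_length tokens,
    # dropping the trailing remainder shorter than context_length (assumed).
    chunks = []
    for stream in batch["input_ids"]:
        for start in range(0, len(stream) - context_length + 1, context_length):
            chunks.append(stream[start : start + context_length])
    return {"input_ids": chunks}

Because both maps run with batched=True and batch_size=-1, each helper sees an entire split at once and may return a different number of rows than it received, which is what allows one concatenated stream to fan back out into many fixed-length chunks.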
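
A hypothetical invocation, for illustration only; the parquet path and sample count are placeholders, not values from the module.

# Illustrative call: the path and count below are placeholders.
tokenize_batch_dataset(
    parquet_path="data/molecule_nat_lang/corpus.parquet",
    context_length=1024,
    number_sample=100_000,  # -> 80k train / 10k valid / 10k test after sampling
)
# The prepared dataset is written next to the parquet file under
# "training_ready_hf_dataset"; point dataset_dir in train_gpt2_config.py there.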