Module molcrawl.genome_sequence.dataset.sentence_piece_tokenizer

Functions

def train_tokenizer(output_dir, vocab_size, max_lines_per_file, input_sentence_size)
Expand source code
def train_tokenizer(output_dir, vocab_size, max_lines_per_file, input_sentence_size):
    """Train a BPE tokenizer with SentencePiece: https://github.com/google/sentencepiece

    Training text is read from ``<output_dir>/raw_files/*.raw``.

    The files are shuffled (using NumPy's global RNG). Only part of them is passed
    to SentencePiece: twice the number estimated to hold ``input_sentence_size``
    lines, and always at least one file.

    The trained model is written with the prefix ``<output_dir>/spm_tokenizer``.
    SentencePiece appends the ``.model`` and ``.vocab`` suffixes.

    Args:
        output_dir: Directory containing the ``raw_files`` sub-directory. The
            trained model is also written here.
        vocab_size: Target vocabulary size of the tokenizer.
        max_lines_per_file: Expected (maximum) number of lines per ``.raw``
            file. Presumably this is the limit used when the files were
            written; confirm against the producer. It is used only to
            estimate how many files to feed the trainer.
        input_sentence_size: Maximum number of sentences (lines) SentencePiece
            samples for training.

    Raises:
        FileNotFoundError: If no ``*.raw`` files exist under
            ``<output_dir>/raw_files``. This also covers the case where the
            directory itself is missing.
    """
    from pathlib import Path

    import numpy as np

    path_dir = Path(output_dir) / "raw_files"
    # Sort before shuffling, so the chosen subset depends only on NumPy's RNG
    # state and not on the filesystem's arbitrary directory-listing order.
    files = sorted(path_dir.glob("*.raw"))
    if not files:
        raise FileNotFoundError(f"No '*.raw' files found in {path_dir}")

    # Permute indices rather than the list itself, so `files` stays a list of
    # Path objects instead of becoming a NumPy object array.
    order = np.random.permutation(len(files))
    files = [files[i] for i in order]

    # Estimate how many files cover `input_sentence_size` lines, then take
    # twice that many as a margin. Always take at least one file: otherwise
    # SentencePiece would be handed an empty input list.
    n_files = max(1, int(input_sentence_size // max_lines_per_file) * 2)
    selected = files[:n_files]

    # Imported here so the dependency is loaded only when training actually runs.
    import sentencepiece as spm

    spm.SentencePieceTrainer.train(
        input=",".join(str(f) for f in selected),
        # Use the text exactly as written (no Unicode normalization).
        normalization_rule_name="identity",
        model_type="bpe",
        model_prefix=str(Path(output_dir) / "spm_tokenizer"),
        vocab_size=vocab_size,
        input_sentence_size=input_sentence_size,
        allow_whitespace_only_pieces=False,
        remove_extra_whitespaces=True,
        # Upper bound on the length (in characters) of a single learned piece.
        max_sentencepiece_length=50,
        # Let pieces cross whitespace, and do not prepend the "▁" word-boundary
        # marker. The input is presumably not whitespace-delimited words
        # (genome-sequence module); verify against the raw-file format.
        split_by_whitespace=False,
        add_dummy_prefix=False,
    )