Module molcrawl.rna.utils.bert_tokenizer

BERT-compatible tokenizer wrapper for RNA. Wraps TranscriptomeTokenizer in a format compatible with BERT training.

Functions

def create_bert_rna_tokenizer(**kwargs) ‑> BertRnaTokenizer
Expand source code
def create_bert_rna_tokenizer(**kwargs) -> BertRnaTokenizer:
    """
    Build a BERT-compatible RNA tokenizer.

    Args:
        **kwargs: Forwarded unchanged to the ``BertRnaTokenizer`` constructor.

    Returns:
        A newly constructed BertRnaTokenizer instance
    """
    tokenizer = BertRnaTokenizer(**kwargs)
    return tokenizer

Create a BERT-compatible RNA tokenizer

Returns

BertRnaTokenizer instance

Classes

class BertRnaTokenizer (**kwargs)
Expand source code
class BertRnaTokenizer:
    """
    RNA TranscriptomeTokenizer wrapper for BERT compatibility

    This class wraps the TranscriptomeTokenizer to make it compatible
    with BERT training and testing pipelines. It exposes a small,
    BERT-style surface (``encode``, ``decode``, ``__call__``, ``get_vocab``,
    ``pad``) on top of a whitespace tokenizer and a vocabulary made of five
    special tokens plus the gene tokens of the wrapped tokenizer.
    """

    # Override model input names to use standard BERT format
    model_input_names = ["input_ids", "attention_mask"]

    # String spellings of the ``padding`` / ``truncation`` flags that mean
    # "disabled". "do_not_pad" and "do_not_truncate" are the explicit
    # opt-out strategy names used by Hugging Face tokenizers.
    _DISABLED_FLAG_STRINGS = frozenset(
        {"false", "0", "no", "none", "", "do_not_pad", "do_not_truncate"}
    )

    def __init__(self, **kwargs):
        """
        Build the wrapped tokenizer and the BERT-style vocabulary.

        Args:
            **kwargs: Forwarded unchanged to ``TranscriptomeTokenizer``.
        """
        self.tokenizer = TranscriptomeTokenizer(**kwargs)

        # BERT compatibility attributes
        self.pad_token = "[PAD]"
        self.unk_token = "[UNK]"
        self.cls_token = "[CLS]"
        self.sep_token = "[SEP]"
        self.mask_token = "[MASK]"

        self.pad_token_id = 0
        self.unk_token_id = 1
        self.cls_token_id = 2
        self.sep_token_id = 3
        self.mask_token_id = 4

        # Create a simple vocab mapping for testing purposes
        self.vocab = {
            "[PAD]": 0,
            "[UNK]": 1,
            "[CLS]": 2,
            "[SEP]": 3,
            "[MASK]": 4,
        }

        # Add gene tokens from the underlying tokenizer
        if hasattr(self.tokenizer, "gene_token_dict"):
            for gene_id, token_id in self.tokenizer.gene_token_dict.items():
                # Offset by 5 to account for special tokens; IDs that are not
                # plain ints are copied through unchanged.
                adjusted_token_id = token_id + 5 if isinstance(token_id, int) else token_id
                self.vocab[gene_id] = adjusted_token_id

    def get_vocab(self):
        """Get vocabulary dictionary (the live mapping, not a copy)"""
        return self.vocab

    def __len__(self):
        """Return vocabulary size (number of entries in the vocab mapping)"""
        return len(self.vocab)

    def tokenize(self, text: str) -> List[str]:
        """
        Tokenize text (for RNA, this is mostly for compatibility)

        Args:
            text: Whitespace-separated token string.

        Returns:
            The whitespace-split tokens, or an empty list when ``text`` is
            not a string.
        """
        # For RNA data, we typically work with gene expression vectors
        # This is a simplified tokenization for testing purposes
        if isinstance(text, str):
            # Simple word-level tokenization for testing
            tokens = text.split()
            return tokens
        return []

    def encode(
        self,
        text: Union[str, List[str]],
        add_special_tokens: bool = True,
        max_length: Optional[int] = None,
        padding: Union[bool, str] = False,
        truncation: Union[bool, str] = False,
        return_tensors: Optional[str] = None,
    ) -> Union[List[int], torch.Tensor]:
        """
        Encode text to token IDs

        Args:
            text: A string (split on whitespace) or an already tokenized
                list of tokens.
            add_special_tokens: Wrap the sequence in ``[CLS]`` ... ``[SEP]``.
            max_length: Target length used by truncation and padding; both
                are skipped when it is ``None`` or 0.
            padding: Flag or strategy string; when enabled the sequence is
                right-padded with the pad ID up to ``max_length``.
            truncation: Flag or strategy string; when enabled sequences
                longer than ``max_length`` are cut.
            return_tensors: ``"pt"`` returns a 1-D ``torch.Tensor``; any
                other value returns a plain list.

        Returns:
            The token IDs; tokens missing from the vocab map to the
            unknown-token ID.
        """
        if isinstance(text, str):
            tokens = self.tokenize(text)
        else:
            tokens = text

        # Convert tokens to IDs
        token_ids = []

        if add_special_tokens:
            token_ids.append(self.cls_token_id)

        for token in tokens:
            token_id = self.vocab.get(token, self.unk_token_id)
            token_ids.append(token_id)

        if add_special_tokens:
            token_ids.append(self.sep_token_id)

        padding_enabled = self._normalize_bool(padding)
        truncation_enabled = self._normalize_bool(truncation)

        # Apply truncation
        if max_length and truncation_enabled:
            if len(token_ids) > max_length:
                if add_special_tokens:
                    # Keep CLS token and ensure SEP token at the end
                    token_ids = token_ids[: max_length - 1] + [self.sep_token_id]
                else:
                    token_ids = token_ids[:max_length]

        # Apply padding
        if max_length and padding_enabled:
            while len(token_ids) < max_length:
                token_ids.append(self.pad_token_id)

        if return_tensors == "pt":
            import torch

            return torch.tensor(token_ids)

        return token_ids

    def __call__(
        self,
        text: Union[str, List[str]],
        add_special_tokens: bool = True,
        padding: Union[bool, str] = False,
        truncation: Union[bool, str] = False,
        max_length: Optional[int] = None,
        return_tensors: Optional[str] = None,
        **kwargs,
    ) -> Dict[str, Any]:
        """
        Tokenize and encode text

        A list is treated as a batch (one entry per sample); anything else
        is treated as a single sample.

        Args:
            text: One sample, or a list of samples.
            add_special_tokens: Passed through to ``encode``.
            padding: Passed through to ``encode``. In addition, when padding
                is enabled, no ``max_length`` is given and the strategy is
                not ``"max_length"``, every row is right-padded to the
                longest row of this call so the batch is rectangular.
            truncation: Passed through to ``encode``.
            max_length: Passed through to ``encode``.
            return_tensors: ``"pt"`` converts both outputs to 2-D
                ``torch.Tensor`` objects (a single sample gets a leading
                batch dimension of 1).
            **kwargs: Accepted for call-site compatibility and ignored.

        Returns:
            Dict with ``input_ids`` and ``attention_mask``. Without
            ``return_tensors`` a single sample yields flat lists and a batch
            yields lists of lists. The mask is 0 wherever the ID equals the
            pad ID and 1 elsewhere.
        """
        is_batch = isinstance(text, list)
        samples = text if is_batch else [text]

        rows: List[List[int]] = [
            cast(
                List[int],
                self.encode(
                    sample,
                    add_special_tokens=add_special_tokens,
                    max_length=max_length,
                    padding=padding,
                    truncation=truncation,
                    return_tensors=None,
                ),
            )
            for sample in samples
        ]

        # Without a max_length, encode() cannot pad. Pad to the longest row
        # here so that ragged batches do not break torch.tensor() below.
        if not max_length and self._pads_to_longest(padding):
            longest = max((len(row) for row in rows), default=0)
            rows = [row + [self.pad_token_id] * (longest - len(row)) for row in rows]

        # Create attention masks
        masks: List[List[int]] = [
            [1 if token_id != self.pad_token_id else 0 for token_id in row] for row in rows
        ]

        result: Dict[str, Any]
        if is_batch or return_tensors == "pt":
            # For a single sample with "pt", the one-row nesting supplies the
            # batch dimension.
            result = {"input_ids": rows, "attention_mask": masks}
        else:
            result = {"input_ids": rows[0], "attention_mask": masks[0]}

        if return_tensors == "pt":
            import torch

            result["input_ids"] = torch.tensor(result["input_ids"])
            result["attention_mask"] = torch.tensor(result["attention_mask"])

        return result

    def decode(
        self,
        token_ids: Union[List[int], torch.Tensor],
        skip_special_tokens: bool = True,
    ) -> str:
        """
        Decode token IDs back to text

        Args:
            token_ids: Flat list of IDs or a ``torch.Tensor`` of IDs.
            skip_special_tokens: Drop ``[PAD]``, ``[CLS]`` and ``[SEP]`` from
                the output (``[MASK]`` and ``[UNK]`` are kept).

        Returns:
            The tokens joined by single spaces; IDs missing from the vocab
            decode to the unknown token.
        """
        import torch

        if isinstance(token_ids, torch.Tensor):
            token_ids = token_ids.tolist()

        # Create reverse mapping (rebuilt per call so vocab edits are seen)
        id_to_token = {v: k for k, v in self.vocab.items()}

        tokens = []
        for token_id in cast(List[int], token_ids):
            token = id_to_token.get(token_id, self.unk_token)

            if skip_special_tokens and token in [
                self.pad_token,
                self.cls_token,
                self.sep_token,
            ]:
                continue

            tokens.append(token)

        return " ".join(tokens)

    def pad(self, encoded_inputs, **kwargs):
        """
        Pad encoded inputs (for compatibility)

        This is a pass-through: ``encoded_inputs`` is returned unchanged and
        the keyword arguments are ignored.
        """
        return encoded_inputs

    def _pads_to_longest(self, padding: Union[bool, str, None]) -> bool:
        """Return True when ``padding`` asks for padding other than the fixed "max_length" strategy."""
        if isinstance(padding, str) and padding.lower() == "max_length":
            return False
        return self._normalize_bool(padding)

    def _normalize_bool(self, value: Union[bool, str, None]) -> bool:
        """
        Interpret a ``padding`` / ``truncation`` argument as an on/off flag.

        Booleans are returned as-is, ``None`` is off, strings are off when
        (case-insensitively) listed in ``_DISABLED_FLAG_STRINGS`` and on
        otherwise, and any other value falls back to its truthiness.
        """
        if isinstance(value, bool):
            return value
        if value is None:
            return False
        if isinstance(value, str):
            return value.lower() not in self._DISABLED_FLAG_STRINGS
        return bool(value)

RNA TranscriptomeTokenizer wrapper for BERT compatibility

This class wraps the TranscriptomeTokenizer to make it compatible with BERT training and testing pipelines

Class variables

var model_input_names

Methods

def decode(self,
token_ids: Union[List[int], torch.Tensor],
skip_special_tokens: bool = True) ‑> str
Expand source code
def decode(
    self,
    token_ids: Union[List[int], torch.Tensor],
    skip_special_tokens: bool = True,
) -> str:
    """
    Turn a sequence of token IDs back into a space-separated string.

    IDs missing from the vocabulary decode to the unknown token. When
    ``skip_special_tokens`` is true, the pad, CLS and SEP tokens are left
    out of the result.
    """
    import torch

    ids = token_ids.tolist() if isinstance(token_ids, torch.Tensor) else token_ids

    # Invert the token -> ID vocabulary for lookup by ID
    reverse_vocab = dict((idx, tok) for tok, idx in self.vocab.items())

    # Tokens that must not appear in the output (none when skipping is off)
    hidden = [self.pad_token, self.cls_token, self.sep_token] if skip_special_tokens else []

    pieces = [reverse_vocab.get(idx, self.unk_token) for idx in cast(List[int], ids)]
    return " ".join(piece for piece in pieces if piece not in hidden)

Decode token IDs back to text

def encode(self,
text: Union[str, List[str]],
add_special_tokens: bool = True,
max_length: Optional[int] = None,
padding: Union[bool, str] = False,
truncation: Union[bool, str] = False,
return_tensors: Optional[str] = None) ‑> Union[List[int], torch.Tensor]
Expand source code
def encode(
    self,
    text: Union[str, List[str]],
    add_special_tokens: bool = True,
    max_length: Optional[int] = None,
    padding: Union[bool, str] = False,
    truncation: Union[bool, str] = False,
    return_tensors: Optional[str] = None,
) -> Union[List[int], torch.Tensor]:
    """
    Convert a string or a list of tokens into vocabulary IDs.

    Strings are tokenized first; lists are used as tokens directly. Unknown
    tokens map to the unknown-token ID. Truncation and padding are applied
    only when ``max_length`` is set and the matching flag is enabled.
    ``return_tensors="pt"`` yields a ``torch.Tensor`` instead of a list.
    """
    pieces = self.tokenize(text) if isinstance(text, str) else text

    # Look up every token, then optionally wrap in CLS ... SEP
    body = [self.vocab.get(piece, self.unk_token_id) for piece in pieces]
    if add_special_tokens:
        token_ids = [self.cls_token_id, *body, self.sep_token_id]
    else:
        token_ids = body

    should_pad = self._normalize_bool(padding)
    should_truncate = self._normalize_bool(truncation)

    # Truncate over-long sequences
    if max_length and should_truncate and len(token_ids) > max_length:
        if add_special_tokens:
            # Leave room for a closing SEP in the last position
            token_ids = token_ids[: max_length - 1]
            token_ids.append(self.sep_token_id)
        else:
            del token_ids[max_length:]

    # Right-pad short sequences
    if max_length and should_pad:
        shortfall = max_length - len(token_ids)
        if shortfall > 0:
            token_ids.extend([self.pad_token_id] * shortfall)

    if return_tensors != "pt":
        return token_ids

    import torch

    return torch.tensor(token_ids)

Encode text to token IDs

def get_vocab(self)
Expand source code
def get_vocab(self):
    """Return the token-to-ID vocabulary mapping (the stored dict itself, not a copy)"""
    vocabulary = self.vocab
    return vocabulary

Get vocabulary dictionary

def pad(self, encoded_inputs, **kwargs)
Expand source code
def pad(self, encoded_inputs, **kwargs):
    """
    Compatibility stub for padding.

    No padding is performed: the inputs come back unchanged and any keyword
    options are accepted but ignored.
    """
    unchanged = encoded_inputs
    return unchanged

Pad encoded inputs (for compatibility)

def tokenize(self, text: str) ‑> List[str]
Expand source code
def tokenize(self, text: str) -> List[str]:
    """
    Split text into whitespace-separated tokens.

    This is a simplified, word-level tokenization kept for compatibility
    and testing; RNA data is typically handled as gene expression vectors.
    Non-string input yields an empty list.
    """
    if not isinstance(text, str):
        return []
    return text.split()

Tokenize text (for RNA, this is mostly for compatibility)