Module molcrawl.rna.dataset.rna_dataset
RNA Transcriptome Dataset
Classes
class RNADataset(data_dir, split='train', vocab_file=None, test_size=0.1)

Loads tokenized RNA transcriptome data from the .arrow files in data_dir (falling back to the HuggingFace on-disk format), using the JSON token vocabulary at vocab_file if given, and reserving a test_size fraction of the samples for the validation split.
Source code
class RNADataset: """RNA Transcriptome Dataset""" def __init__(self, data_dir, split="train", vocab_file=None, test_size=0.1): import pyarrow as pa from datasets import Dataset, load_from_disk self.data_dir = data_dir self.split = split self.test_size = test_size # Load vocabulary if vocab_file and os.path.exists(vocab_file): with open(vocab_file, "r") as f: self.vocab = json.load(f) self.vocab_size = len(self.vocab) else: # Default RNA vocabulary self.vocab = {"<pad>": 0, "<unk>": 1, "<eos>": 2} self.vocab_size = 3 # Load dataset - direct arrow file reading to bypass metadata issues print(f"📂 Attempting to load data from {data_dir}") try: data_path = Path(data_dir) arrow_files = sorted(list(data_path.glob("*.arrow"))) if arrow_files: print(f"📁 Found {len(arrow_files)} arrow files: {[f.name for f in arrow_files]}") all_batches = [] for arrow_file in arrow_files: print(f"📖 Reading {arrow_file.name}...") try: # Try as memory mapped stream first with pa.memory_map(str(arrow_file)) as mmap: with pa.ipc.open_stream(mmap) as reader: table = reader.read_all() print(f"✓ Read table via stream: {len(table)} rows") all_batches.append(table) except Exception: try: # Fallback to RecordBatch file with pa.memory_map(str(arrow_file)) as mmap: with pa.ipc.open_file(mmap) as reader: table = reader.read_all() print(f"✓ Read table via file: {len(table)} rows") all_batches.append(table) except Exception as e: print(f"❌ Failed to read {arrow_file.name}: {e}") continue if all_batches: # Combine all tables combined_table = pa.concat_tables(all_batches) print(f"📊 Combined {len(all_batches)} tables: {len(combined_table)} total rows") # Convert PyArrow table to pandas DataFrame, then to HuggingFace Dataset df = combined_table.to_pandas() print(f"📋 Converted to pandas DataFrame: {len(df)} rows") # Convert numpy arrays to lists for HuggingFace compatibility if "token" in df.columns: df["token"] = df["token"].apply(lambda x: x.tolist() if hasattr(x, "tolist") else x) # Create dataset from pandas DataFrame (bypasses metadata issues) self.dataset = Dataset.from_pandas(df) print("✅ Created HuggingFace Dataset from pandas DataFrame") print(f"🔍 Dataset columns: {self.dataset.column_names}") else: raise ValueError("No arrow files could be read successfully") else: raise FileNotFoundError(f"No .arrow files found in {data_dir}") except Exception as e: print(f"❌ Arrow loading failed: {e}") # Fallback to other methods try: print("🔄 Trying HuggingFace format as fallback...") self.dataset = load_from_disk(data_dir) print(f"✅ Loaded HuggingFace dataset from {data_dir}") except Exception as e2: print(f"❌ All loading methods failed: {e2}") raise FileNotFoundError(f"Could not load data from {data_dir}") from e2 # Split into train/valid if needed if hasattr(self.dataset, "keys") and isinstance(self.dataset, dict) and "train" in self.dataset: # Already has splits if split == "train": self.data = self.dataset["train"] elif split == "valid" or split == "val": self.data = self.dataset.get("valid", self.dataset.get("test", self.dataset["train"])) else: # Single dataset, need to split total_size = len(self.dataset) if split == "train": self.data = self.dataset.select(range(int(total_size * (1 - self.test_size)))) else: # valid self.data = self.dataset.select(range(int(total_size * (1 - self.test_size)), total_size)) print(f"Loaded {len(self.data)} samples for {split}") # Sample a few examples to understand data structure if len(self.data) > 0: sample = self.data[0] print(f"Sample keys: {list(sample.keys())}") for key, value in 
sample.items(): if isinstance(value, (list, str)): print(f" {key}: {type(value)} of length {len(value)}") else: print(f" {key}: {type(value)} = {value}") def __len__(self): return len(self.data) def __getitem__(self, idx): import torch item = self.data[idx] # RNA data has 'token' column with numpy arrays tokens = None # Try 'token' column first (RNA data format) if "token" in item and item["token"] is not None: tokens = item["token"] else: # Try other possible token column names for key in ["input_ids", "tokens", "token_ids", "tokenized"]: if key in item and item[key] is not None: tokens = item[key] break if tokens is None: # If no tokens, try to find text and tokenize it text = None for key in ["text", "sequence", "input_text"]: if key in item and item[key] is not None: text = item[key] break if text is not None: # Simple tokenization (this is a fallback) tokens = [self.vocab.get(char, self.vocab.get("<unk>", 1)) for char in str(text)] else: # Last resort: use all numeric values as tokens numeric_values = [v for v in item.values() if isinstance(v, (int, list))] if numeric_values: tokens = numeric_values[0] if isinstance(numeric_values[0], list) else [numeric_values[0]] else: tokens = [0] # padding token # Handle numpy array or list if hasattr(tokens, "tolist"): # Convert numpy array to list tokens = tokens.tolist() elif not isinstance(tokens, list): tokens = list(tokens) # Convert to integers if needed try: tokens = [int(t) for t in tokens] except (ValueError, TypeError): tokens = [self.vocab.get("<unk>", 1) for _ in tokens] return torch.tensor(tokens, dtype=torch.long)RNA Transcriptome Dataset
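The vocab_file argument, when supplied, must point to a JSON object mapping token strings to integer ids; otherwise the constructor falls back to the three built-in special tokens. A minimal sketch of writing a compatible vocabulary file follows; only <pad>, <unk>, and <eos> come from the class defaults, and the nucleotide entries are illustrative:

import json

# Hypothetical nucleotide vocabulary; only <pad>/<unk>/<eos> are
# guaranteed by RNADataset's defaults, the other entries are illustrative.
vocab = {"<pad>": 0, "<unk>": 1, "<eos>": 2, "A": 3, "C": 4, "G": 5, "U": 6}
with open("rna_vocab.json", "w") as f:
    json.dump(vocab, f)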
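Because __getitem__ returns variable-length 1-D LongTensors, batching with a PyTorch DataLoader needs a padding collate function. A minimal usage sketch, assuming tokenized .arrow shards live under ./data/rna (a hypothetical path) and padding with the default <pad> id 0:

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from molcrawl.rna.dataset.rna_dataset import RNADataset

# "./data/rna" is a hypothetical path to a directory of .arrow shards.
train_set = RNADataset("./data/rna", split="train", test_size=0.1)

def pad_collate(batch):
    # __getitem__ yields 1-D LongTensors of varying length; right-pad
    # with the default <pad> id (0) so they stack into one batch tensor.
    return pad_sequence(batch, batch_first=True, padding_value=0)

loader = DataLoader(train_set, batch_size=8, shuffle=True, collate_fn=pad_collate)
batch = next(iter(loader))
print(batch.shape)  # (8, length of the longest sequence in the batch)

Using torch.nn.utils.rnn.pad_sequence keeps the collate function short and avoids hand-rolled padding loops; any model consuming these batches should mask positions equal to the pad id.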