Module molcrawl.bert.main
Classes
class RNADatasetForBERT (data_dir, split='train', vocab_file=None, test_size=0.1)
class RNADatasetForBERT:
    """Custom dataset class for loading RNA data"""

    def __init__(self, data_dir, split="train", vocab_file=None, test_size=0.1):
        self.data_dir = data_dir
        self.split = split
        self.test_size = test_size
        print(f"📂 Attempting to load BERT data from {data_dir}")

        # Load vocabulary if provided
        if vocab_file and os.path.exists(vocab_file):
            with open(vocab_file, "r") as f:
                self.vocab = json.load(f)
            print(f"📖 Loaded vocabulary: {len(self.vocab)} tokens")
        else:
            print("⚠️ No vocabulary file provided")
            self.vocab = None

        # Load data from Arrow files
        try:
            arrow_files = list(Path(data_dir).glob("*.arrow"))
            if arrow_files:
                print(f"📁 Found {len(arrow_files)} arrow files: {[f.name for f in arrow_files]}")
                all_batches = []
                for arrow_file in arrow_files:
                    try:
                        print(f"📖 Reading {arrow_file.name}...")
                        with pa.memory_map(str(arrow_file), "r") as mmap:
                            with pa.ipc.open_stream(mmap) as reader:
                                table = reader.read_all()
                                print(f"🔢 Read table via stream: {len(table)} rows")
                                all_batches.append(table)
                    except Exception as e:
                        print(f"❌ Failed to read {arrow_file.name}: {e}")
                        continue

                if all_batches:
                    # Combine all tables
                    combined_table = pa.concat_tables(all_batches)
                    print(f"📊 Combined {len(all_batches)} tables: {len(combined_table)} total rows")

                    # Convert PyArrow table to pandas DataFrame, then to HuggingFace Dataset
                    df = combined_table.to_pandas()
                    print(f"📋 Converted to pandas DataFrame: {len(df)} rows")

                    # Convert numpy arrays to lists for HuggingFace compatibility
                    if "token" in df.columns:
                        df["token"] = df["token"].apply(lambda x: x.tolist() if hasattr(x, "tolist") else x)

                    # Create dataset from pandas DataFrame (bypasses metadata issues)
                    self.dataset = Dataset.from_pandas(df)
                    print("✅ Created HuggingFace Dataset from pandas DataFrame")
                    print(f"🔍 Dataset columns: {self.dataset.column_names}")
                else:
                    raise ValueError("No arrow files could be read successfully")
            else:
                raise FileNotFoundError(f"No .arrow files found in {data_dir}")
        except Exception as e:
            print(f"❌ Arrow loading failed: {e}")
            raise FileNotFoundError(f"Could not load data from {data_dir}") from e

        # Split into train/valid if needed
        if hasattr(self.dataset, "keys") and isinstance(self.dataset, dict) and "train" in self.dataset:
            # Already has splits
            if split == "train":
                self.data = self.dataset["train"]
            elif split in ["valid", "val", "test"]:
                self.data = self.dataset.get("valid", self.dataset.get("test", self.dataset["train"]))
        else:
            # Create splits
            if test_size > 0:
                split_dataset = self.dataset.train_test_split(test_size=test_size, seed=42)
                if split == "train":
                    self.data = split_dataset["train"]
                elif split in ["valid", "val", "test"]:
                    self.data = split_dataset["test"]
            else:
                self.data = self.dataset

        print(f"Loaded {len(self.data)} samples for {split}")

        # Show sample
        if len(self.data) > 0:
            sample = self.data[0]
            print("Sample keys:", list(sample.keys()))
            for key, value in sample.items():
                if isinstance(value, list):
                    print(f" {key}: {type(value)} of length {len(value)}")
                else:
                    print(f" {key}: {type(value)} = {value}")

    def get_dataset(self):
        """Return the HuggingFace Dataset object"""
        return self.data

Custom dataset class for loading RNA data
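A minimal usage sketch follows. The directory and vocabulary paths are hypothetical; the only assumption about the inputs is what the loader itself requires, namely a directory of Arrow IPC stream files (*.arrow).

from molcrawl.bert.main import RNADatasetForBERT

# "data/rna_tokens" and "data/vocab.json" are hypothetical paths.
train_ds = RNADatasetForBERT(
    data_dir="data/rna_tokens",    # directory containing *.arrow stream files
    split="train",
    vocab_file="data/vocab.json",  # optional JSON token vocabulary
    test_size=0.1,                 # fraction held out for the validation split
)

# The split uses a fixed seed (42), so a second instance with split="valid"
# yields the complementary 10% of the same data.
valid_ds = RNADatasetForBERT("data/rna_tokens", split="valid",
                             vocab_file="data/vocab.json", test_size=0.1)

train_data = train_ds.get_dataset()  # datasets.Dataset
print(len(train_data), train_data.column_names)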
Methods
def get_dataset(self)
def get_dataset(self):
    """Return the HuggingFace Dataset object"""
    return self.data

Return the HuggingFace Dataset object
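Because get_dataset() returns a plain datasets.Dataset, the standard datasets API applies downstream. A hedged sketch, reusing the train_ds instance from the example above and assuming a "token" column of integer IDs (the one column the loader treats specially); padding and collation for BERT pre-training are tokenizer-specific and omitted.

ds = train_ds.get_dataset()
print(ds)  # datasets.Dataset summary: column names and num_rows

# Optional: have per-example indexing return torch tensors instead of Python lists
# (assumes PyTorch is installed).
ds_torch = ds.with_format("torch", columns=["token"])
print(ds_torch[0]["token"].shape)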