Module molcrawl.molecule_nat_lang.utils.general
Functions
def compute_resource_aware_params(num_rows: int = 3300000,
                                  bytes_per_row_estimate: int = 2048,
                                  safety_factor: float = 0.5,
                                  max_workers: int = 8,
                                  target_batch_bytes: int = 268435456) -> dict
def compute_resource_aware_params(
    num_rows: int = 3_300_000,
    bytes_per_row_estimate: int = 2_048,
    safety_factor: float = 0.5,
    max_workers: int = 8,
    target_batch_bytes: int = 256 * 1024 * 1024,  # 256 MB
) -> dict:
    """
    Inspect available system memory and CPU count, then compute safe values for:

    - ``num_workers`` — parallelism for ``Dataset.map()``
    - ``batch_size`` — rows per batch for streaming parquet writes

    Parameters
    ----------
    num_rows:
        Estimated total number of rows across all splits.
    bytes_per_row_estimate:
        Estimated bytes per row in memory (input_ids + masks, etc.).
    safety_factor:
        Fraction of available memory that this process may consume.
    max_workers:
        Hard upper bound on the returned ``num_workers``.
    target_batch_bytes:
        Target memory footprint per parquet write batch.

    Returns
    -------
    dict with keys ``num_workers``, ``batch_size``, ``available_gb``, ``total_gb``.
    """
    available_bytes, total_bytes = get_available_memory_bytes()
    available_gb = available_bytes / 1024**3
    total_gb = total_bytes / 1024**3
    cpu_count = os.cpu_count() or 1

    budget_bytes = available_bytes * safety_factor
    dataset_bytes = max(num_rows * bytes_per_row_estimate, 1)

    # Each worker needs ~1 partition of the dataset; pick largest safe count
    workers_by_mem = max(1, int(budget_bytes / dataset_bytes))
    num_workers = min(workers_by_mem, cpu_count, max_workers)

    # Parquet batch: aim for target_batch_bytes per batch, clamp to [1_000, 200_000]
    batch_size = max(1_000, min(int(target_batch_bytes / bytes_per_row_estimate), 200_000))

    logger.info(
        "[ResourceAware] memory: available=%.1f GB / total=%.1f GB | "
        "CPUs=%d | estimated dataset=%.1f GB | "
        "→ num_workers=%d batch_size=%d",
        available_gb,
        total_gb,
        cpu_count,
        dataset_bytes / 1024**3,
        num_workers,
        batch_size,
    )
    return {
        "num_workers": num_workers,
        "batch_size": batch_size,
        "available_gb": available_gb,
        "total_gb": total_gb,
    }

Inspect available system memory and CPU count, then compute safe values for:
- num_workers — parallelism for Dataset.map()
- batch_size — rows per batch for streaming parquet writes

Parameters
num_rows: Estimated total number of rows across all splits.
bytes_per_row_estimate: Estimated bytes per row in memory (input_ids + masks, etc.).
safety_factor: Fraction of available memory that this process may consume.
max_workers: Hard upper bound on the returned num_workers.
target_batch_bytes: Target memory footprint per parquet write batch.

Returns
dict with keys num_workers, batch_size, available_gb, total_gb.
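A minimal usage sketch (the dataset and tokenize_fn names below are hypothetical). Note that batch_size is fully determined by the defaults, 268435456 target bytes / 2048 bytes per row = 131072 rows, which sits inside the [1_000, 200_000] clamp; num_workers additionally depends on the memory and CPU count of the host.

params = compute_resource_aware_params(num_rows=3_300_000, bytes_per_row_estimate=2_048)
# With the defaults: batch_size = 268_435_456 // 2_048 = 131_072 rows per parquet batch;
# num_workers = min(memory-budgeted workers, os.cpu_count(), 8) on the host machine.
print(params["num_workers"], params["batch_size"], f"{params['available_gb']:.1f} GB free")
# tokenized = dataset.map(tokenize_fn, num_proc=params["num_workers"])  # hypothetical names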
def get_available_memory_bytes() -> tuple[int, int]
def get_available_memory_bytes() -> tuple[int, int]:
    """
    Return (available_bytes, total_bytes) from /proc/meminfo.
    Falls back to psutil if /proc/meminfo is absent (non-Linux).
    """
    proc_meminfo = Path("/proc/meminfo")
    if proc_meminfo.exists():
        mem: dict[str, int] = {}
        with open(proc_meminfo) as fh:
            for line in fh:
                parts = line.split()
                if len(parts) >= 2:
                    mem[parts[0].rstrip(":")] = int(parts[1]) * 1024  # kB → bytes
        available = mem.get("MemAvailable", mem.get("MemFree", 8 * 1024**3))
        total = mem.get("MemTotal", available)
        return available, total
    try:
        import psutil  # type: ignore[import]
        vm = psutil.virtual_memory()
        return vm.available, vm.total
    except ImportError:
        fallback = 8 * 1024**3  # assume 8 GB when nothing else is available
        return fallback, fallback

Return (available_bytes, total_bytes) from /proc/meminfo. Falls back to psutil if /proc/meminfo is absent (non-Linux).
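A short sketch of calling the probe directly; on Linux the values come from /proc/meminfo, elsewhere from psutil when installed, and otherwise the 8 GB fallback is assumed:

available, total = get_available_memory_bytes()
print(f"available: {available / 1024**3:.1f} GB / total: {total / 1024**3:.1f} GB")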
def load_jsonl_dataset(dataset_path: Union[str, Path])
def load_jsonl_dataset(dataset_path: Union[str, Path]):
    """
    Load SMolInstruct dataset from JSONL files in raw/ directory

    Args:
        dataset_path: Path to SMolInstruct directory (contains raw/ subdirectory)

    Returns:
        DatasetDict with train/dev/test splits
    """
    dataset_path_obj = Path(dataset_path)
    raw_dir = dataset_path_obj / "raw"

    if not raw_dir.exists():
        raise FileNotFoundError(f"Raw directory not found: {raw_dir}")

    splits = {}
    for split_name in ["train", "dev", "test"]:
        split_dir = raw_dir / split_name
        if not split_dir.exists():
            logger.warning(f"Split directory not found: {split_dir}")
            continue

        logger.info(f"Loading {split_name} split from {split_dir}")

        # Load all JSONL files in the split directory
        all_data = []
        jsonl_files = list(split_dir.glob("*.jsonl"))
        logger.info(f"Found {len(jsonl_files)} JSONL files in {split_name}")

        for jsonl_file in jsonl_files:
            with open(jsonl_file, "r", encoding="utf-8") as f:
                for line in f:
                    if line.strip():
                        data = json.loads(line)
                        all_data.append(data)

        logger.info(f"Loaded {len(all_data)} samples for {split_name}")

        # Create Dataset from list of dicts
        # Explicitly define features to ensure proper serialization
        if all_data:
            # Infer features from first sample, all fields as string
            from datasets import Dataset, Features, Value

            features = Features({key: Value("string") for key in all_data[0].keys()})
            splits[split_name] = Dataset.from_list(all_data, features=features)
        else:
            from datasets import Dataset

            splits[split_name] = Dataset.from_list(all_data)

    # Rename 'dev' to 'valid' for consistency
    if "dev" in splits:
        splits["valid"] = splits.pop("dev")

    from datasets import DatasetDict

    return DatasetDict(splits)

Load SMolInstruct dataset from JSONL files in raw/ directory
Args
dataset_path: Path to SMolInstruct directory (contains raw/ subdirectory)
Returns
DatasetDict with train/valid/test splits (the on-disk 'dev' split is renamed to 'valid')
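A minimal sketch, assuming an SMolInstruct checkout whose raw/ directory holds train/, dev/, and test/ subdirectories of *.jsonl files; the path is hypothetical:

ds = load_jsonl_dataset("data/SMolInstruct")  # hypothetical path
print(list(ds.keys()))  # ['train', 'test', 'valid'] ('dev' is renamed to 'valid')
print(ds["train"][0])   # one JSONL record; all fields are declared as string features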
def read_dataset(dataset_path: Union[str, Path])
def read_dataset(dataset_path: Union[str, Path]):
    """
    Read dataset from disk, supporting JSONL, split directories, and DatasetDict format

    Args:
        dataset_path: Path to the dataset directory

    Returns:
        dict or DatasetDict: Dictionary of splits with Dataset objects
    """
    dataset_path_obj = Path(dataset_path)

    if not dataset_path_obj.exists():
        raise FileNotFoundError(f"Dataset path does not exist: {dataset_path_obj}")

    ## TODO: duplication check.
    # Check if this is a parquet file
    if dataset_path_obj.is_file() and dataset_path_obj.suffix == ".parquet":
        logger.info(f"Loading parquet file: {dataset_path_obj}")
        from datasets import Dataset

        dataset = Dataset.from_parquet(str(dataset_path_obj))

        # If the dataset has a 'split' column, split it accordingly
        if "split" in dataset.column_names:
            logger.info("Found 'split' column, splitting dataset by split values")
            split_values = dataset.unique("split")
            splits = {}
            for split_name in split_values:
                split_dataset = dataset.filter(lambda x, split_name=split_name: x["split"] == split_name)
                # Remove the split column as it's no longer needed
                split_dataset = split_dataset.remove_columns(["split"])
                splits[split_name] = split_dataset
                logger.info(f"Created {split_name} split with {len(split_dataset)} samples")

            from datasets import DatasetDict

            return DatasetDict(splits)
        else:
            # Return as DatasetDict with train split
            from datasets import DatasetDict

            return DatasetDict({"train": dataset})

    # Check if this is a SMolInstruct-style directory with raw/ subdirectory
    raw_dir = dataset_path_obj / "raw"
    if raw_dir.exists() and raw_dir.is_dir():
        logger.info(f"Detected JSONL format dataset at {dataset_path_obj}")
        return load_jsonl_dataset(dataset_path_obj)

    # Try to load as DatasetDict first (if it was saved with save_to_disk)
    try:
        logger.info(f"Attempting to load dataset as DatasetDict from {dataset_path_obj}")
        from datasets import DatasetDict

        dataset_dict = DatasetDict.load_from_disk(str(dataset_path_obj))
        logger.info(f"Successfully loaded DatasetDict with splits: {list(dataset_dict.keys())}")
        return dataset_dict
    except Exception as e:
        logger.debug(f"Not a DatasetDict format: {e}")

    # Fall back to loading individual split directories
    splits = {}
    try:
        for folder in os.listdir(dataset_path_obj):
            folder_path = dataset_path_obj / folder
            if folder_path.is_dir():
                # Skip cache and metadata directories
                if folder.startswith(".") or folder == "hf_cache":
                    continue
                try:
                    logger.info(f"Loading split: {folder}")
                    from datasets import Dataset

                    splits[folder] = Dataset.load_from_disk(str(folder_path))
                    logger.info(f"Loaded {folder} with {len(splits[folder])} samples")
                except Exception as split_error:
                    logger.warning(f"Failed to load split {folder}: {split_error}")

        if not splits:
            raise ValueError(f"No valid dataset splits found in {dataset_path_obj}")

        return splits
    except Exception as e:
        logger.error(f"Failed to read dataset from {dataset_path_obj}: {e}")
        raise

Read dataset from disk, supporting JSONL, split directories, and DatasetDict format
Args
dataset_path: Path to the dataset directory or .parquet file
Returns
dict or DatasetDict: Dictionary of splits with Dataset objects
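A hedged sketch of the three load paths the function distinguishes; every path below is hypothetical:

# 1) Parquet file: a 'split' column, if present, is expanded back into separate splits
ds = read_dataset("out/dataset.parquet")
# 2) SMolInstruct-style directory with a raw/ subdirectory of JSONL files
ds = read_dataset("data/SMolInstruct")
# 3) DatasetDict saved via save_to_disk(), else one Dataset directory per split
ds = read_dataset("data/my_dataset")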
def save_dataset(dataset, dataset_path: Union[str, Path], batch_size: int = 50000)
def save_dataset(dataset, dataset_path: Union[str, Path], batch_size: int = 50000):
    """
    Save dataset to disk or as parquet file

    Args:
        dataset: Dictionary of Dataset objects or DatasetDict
        dataset_path: Path to save the dataset (directory or .parquet file)
        batch_size: Number of rows per batch when writing parquet (avoids OOM)
    """
    dataset_path_obj = Path(dataset_path)

    # Check if saving as parquet file
    if dataset_path_obj.suffix == ".parquet":
        logger.info(f"Saving dataset as parquet to {dataset_path_obj}")
        os.makedirs(dataset_path_obj.parent, exist_ok=True)

        # Convert to DatasetDict if it's a dict
        from datasets import DatasetDict

        if not isinstance(dataset, DatasetDict):
            dataset = DatasetDict(dataset)

        import pyarrow as pa
        import pyarrow.parquet as pq

        writer = None
        total_saved = 0
        try:
            for split_name, split_dataset in dataset.items():
                logger.info(f"Writing split '{split_name}' ({len(split_dataset)} samples) to parquet...")
                num_rows = len(split_dataset)

                # .data returns a huggingface InMemoryTable wrapper; .table is the
                # actual pyarrow.lib.Table that ParquetWriter.write_table() requires.
                # .slice() on a pyarrow Table is zero-copy — no extra memory allocated.
                pa_table = split_dataset.data.table

                for start in range(0, num_rows, batch_size):
                    length = min(batch_size, num_rows - start)
                    batch_pa = pa_table.slice(start, length)
                    split_col = pa.array([split_name] * length, type=pa.string())
                    batch_pa = batch_pa.append_column("split", split_col)

                    if writer is None:
                        writer = pq.ParquetWriter(str(dataset_path_obj), batch_pa.schema)
                    writer.write_table(batch_pa)

                    total_saved += length
                    logger.info(f" ... written {start + length}/{num_rows} rows for '{split_name}'")
        finally:
            if writer is not None:
                writer.close()

        logger.info(f"Saved {total_saved} samples to parquet file")
        return

    # Otherwise save as directory structure
    os.makedirs(dataset_path_obj, exist_ok=True)
    logger.info(f"Saving dataset to {dataset_path_obj}")

    # If it's a DatasetDict, we can save it directly
    from datasets import DatasetDict

    if isinstance(dataset, DatasetDict):
        dataset.save_to_disk(str(dataset_path_obj))
        logger.info(f"Saved DatasetDict with {len(dataset)} splits")
        return

    # Otherwise, save each split separately
    for split in dataset.keys():
        split_path = dataset_path_obj / split
        logger.info(f"Saving {split} split to {split_path}")
        dataset[split].save_to_disk(str(split_path))
        logger.info(f"Saved {split} with {len(dataset[split])} samples")

Save dataset to disk or as parquet file
Args
dataset: Dictionary of Dataset objects or DatasetDict
dataset_path: Path to save the dataset (directory or .parquet file)
batch_size: Number of rows per batch when writing parquet (avoids OOM)
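A round-trip sketch under hypothetical paths: a .parquet target streams every split into a single file with an added split column, which read_dataset then uses to rebuild the splits.

from datasets import Dataset, DatasetDict

dd = DatasetDict({
    "train": Dataset.from_list([{"smiles": "CCO"}, {"smiles": "c1ccccc1"}]),
    "test": Dataset.from_list([{"smiles": "O=C=O"}]),
})
save_dataset(dd, "out/dataset.parquet", batch_size=1_000)  # hypothetical path
restored = read_dataset("out/dataset.parquet")             # splits rebuilt from 'split' column
assert sorted(restored.keys()) == ["test", "train"]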