Module molcrawl.rna.utils.preprocess

Functions

def binning(row: Union[np.ndarray, torch.Tensor], n_bins: int) -> Union[np.ndarray, torch.Tensor]
Expand source code
def binning(row: Union[np.ndarray, torch.Tensor], n_bins: int) -> Union[np.ndarray, torch.Tensor]:
    """Bin the values of ``row`` using quantile-based bin edges.

    The bin edges are ``n_bins - 1`` evenly spaced quantiles (from 0 to 1
    inclusive) of the values being binned. Each value is then mapped to a bin
    index by ``_digitize``. When the row contains any value ``<= 0``, only the
    non-zero entries are binned: the edges are computed from the non-zero
    values alone, and zero entries stay at 0.

    Args:
        row: Values to bin, as a numpy array or a torch tensor. Presumably one
            row of an expression matrix (TODO: confirm with callers). Tensors
            are detached and copied to CPU before processing.
        n_bins: Controls the number of quantile edges (``n_bins - 1``).

    Returns:
        For numpy input: a numpy array cast back to the input dtype.
        For tensor input: a CPU tensor. It is int64 when the row contains
        values ``<= 0``; otherwise its dtype is whatever ``_digitize`` returns.
        A row whose maximum is 0 (with a warning logged) and an empty row
        return all-zero output with the input's shape and dtype.
    """
    import torch

    is_tensor = isinstance(row, torch.Tensor)
    dtype = row.dtype
    # detach() so that tensors requiring grad can still be converted to numpy.
    values = row.detach().cpu().numpy() if is_tensor else row
    # TODO: use torch.quantile and torch.bucketize

    # Empty rows have nothing to bin (and .max() would raise on them). Rows
    # whose maximum is 0 are returned as zeros, with a warning.
    if values.size == 0 or values.max() == 0:
        if values.size:
            logging.warning("The input data contains row of zeros. Please make sure this is expected.")
        if is_tensor:
            # `values` is a numpy array at this point, so torch.zeros_like would
            # reject it; build the tensor from the shape instead.
            return torch.zeros(values.shape, dtype=dtype)
        return np.zeros_like(values, dtype=dtype)

    if values.min() <= 0:
        # Compute the edges from the non-zero entries only, so zeros do not skew
        # the quantiles; zero entries keep bin index 0.
        non_zero_ids = values.nonzero()
        non_zero_row = values[non_zero_ids]
        bins = np.quantile(non_zero_row, np.linspace(0, 1, n_bins - 1))
        binned_row = np.zeros_like(values, dtype=np.int64)
        binned_row[non_zero_ids] = _digitize(non_zero_row, bins)
    else:
        bins = np.quantile(values, np.linspace(0, 1, n_bins - 1))
        binned_row = _digitize(values, bins)
    return torch.from_numpy(binned_row) if is_tensor else binned_row.astype(dtype)

Bin the values of the row into quantile-based bins; `n_bins - 1` evenly spaced quantiles of the values serve as the bin edges.