Source code for Compocyte.core.models.fit_methods

from copy import deepcopy
import os
from typing import Union
import numpy as np
import pandas as pd
from scipy import sparse
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import robust_scale
import torch
import logging
import dask.array as da
from torch.utils.data import TensorDataset, random_split, DataLoader, IterableDataset, get_worker_info
from Compocyte.core.models.dense_torch import DenseTorch
from Compocyte.core.models.dummy_classifier import DummyClassifier
from Compocyte.core.models.log_reg import LogisticRegression
from Compocyte.core.models.trees import BoostedTrees
from balanced_loss import Loss as BalancedLoss

# Module-level logger named after this module; used by set_threads to report
# the worker/thread configuration that was applied.
logger = logging.getLogger(__name__)

def to_categorical(y, num_classes, dtype="float32"):
    """One-hot encode integer class indices.

    Simplified from keras to avoid the dependency and a premature
    conversion to a Tensor.

    Args:
        y: Array-like of class indices; values are cast to ``int``.
        num_classes: Length of the one-hot axis appended to the shape of ``y``.
        dtype: dtype of the returned array.

    Returns:
        A numpy array of shape ``y.shape + (num_classes,)`` holding a single
        1 per entry of ``y`` at the position given by its class index.
    """
    class_indices = np.array(y, dtype="int")
    # Row i of the identity matrix is the one-hot vector for class i, so
    # indexing it with the label array appends the one-hot axis in one step.
    one_hot_rows = np.eye(num_classes, dtype=dtype)
    return one_hot_rows[class_indices]
class DaskBatchDataset(IterableDataset):
    """Stream (features, labels) batches out of two chunked, lazily evaluated arrays.

    Every chunk of the inputs becomes one batch. Chunk order is re-shuffled
    per epoch (see ``set_epoch``) and rows are additionally shuffled inside a
    sliding window of several chunks.
    """

    def __init__(self, X, y):
        """Split ``X`` and ``y`` into flat sequences of delayed chunks.

        Args:
            X: Chunked feature array exposing ``to_delayed()``.
            y: Chunked label array exposing ``to_delayed()``; must produce
                the same number of chunks as ``X``.
        """
        # Convert both to lists of delayed chunks
        self.X_chunks = X.to_delayed().ravel()
        self.y_chunks = y.to_delayed().ravel()
        self._epoch = 0
        assert len(self.X_chunks) == len(self.y_chunks), \
            "Feature and label chunks must be aligned"

    def set_epoch(self, epoch):
        """Store the epoch number that seeds the next iteration's shuffling."""
        # Call before each epoch so __iter__ uses a fresh permutation.
        # Must be called in the main process before the DataLoader starts iterating
        # (before workers are spawned), so the updated value is pickled into each worker.
        self._epoch = epoch

    @staticmethod
    def _emit_window(window, rng):
        """Pool the buffered chunks, shuffle their rows and yield same-sized batches.

        Args:
            window: List of ``(X_block, y_block)`` numpy array pairs.
            rng: numpy random generator used for the row permutation.

        Yields:
            ``(X, y)`` pairs of float32 tensors, one per buffered chunk, each
            with as many rows as the corresponding original chunk.
        """
        rows_per_batch = [len(X_block) for X_block, _ in window]
        X_pool = np.concatenate([X_block for X_block, _ in window], axis=0)
        y_pool = np.concatenate([y_block for _, y_block in window], axis=0)
        row_order = rng.permutation(len(X_pool))
        X_pool = X_pool[row_order]
        y_pool = y_pool[row_order]
        stop = 0
        for n_rows in rows_per_batch:
            start, stop = stop, stop + n_rows
            yield (
                torch.from_numpy(X_pool[start:stop]).to(torch.float32),
                torch.from_numpy(y_pool[start:stop]).to(torch.float32),
            )

    def __iter__(self):
        # All workers derive the same permutation from the epoch seed, then each takes
        # a disjoint slice — this gives epoch-level shuffling that is safe with num_workers > 0.
        rng = np.random.default_rng(self._epoch)
        chunk_order = rng.permutation(len(self.X_chunks))
        worker = get_worker_info()
        if worker is not None:
            chunk_order = chunk_order[worker.id::worker.num_workers]

        # Number of chunks pooled together before their rows are shuffled.
        window_size = 5
        window = []
        for chunk_index in chunk_order:
            X_block = self.X_chunks[chunk_index].compute()
            y_block = self.y_chunks[chunk_index].compute()
            window.append((X_block, y_block))
            if len(window) == window_size:
                yield from self._emit_window(window, rng)
                window = []
        # Emit whatever is left when the chunk count is not a multiple of the window.
        if window:
            yield from self._emit_window(window, rng)
def predict_logits(model, x):
    """Return the raw activations of a classifier for the samples in ``x``.

    ``x`` is scaled per sample (``robust_scale`` along ``axis=1``, no
    centering, unit variance) and densified before being handed to
    ``model.predict_logits``.

    Args:
        model: A ``DenseTorch``, ``LogisticRegression``, ``BoostedTrees`` or
            ``DummyClassifier`` instance.
        x: Feature matrix, dense or scipy sparse. Scaling is requested with
            ``copy=False``, so the caller's data may be modified in place.

    Returns:
        Whatever ``model.predict_logits`` returns for the scaled, dense input.

    Raises:
        Exception: If ``model`` is not one of the supported classifier types.
    """
    x = robust_scale(x, axis=1, with_centering=False, copy=False, unit_variance=True)
    # Densify every scipy sparse container, not only csr_matrix: robust_scale
    # hands CSC input back as CSC, and sparse *arrays* are not csr_matrix
    # instances, so a CSR-only check would pass them on still sparse.
    if sparse.issparse(x):
        x = x.toarray()

    supported = (DenseTorch, LogisticRegression, BoostedTrees, DummyClassifier)
    if not isinstance(model, supported):
        raise Exception('Unknown classifier type.')
    return model.predict_logits(x)
def predict(model, x, threshold=-1, monte_carlo: int=None):
    """Predict a label for every sample in ``x``.

    ``x`` is scaled per sample (``robust_scale`` along ``axis=1``, no
    centering, unit variance) and densified. The label with the highest
    activation is looked up in ``model.labels_dec``; samples whose highest
    activation does not exceed ``threshold`` receive the empty string.

    Args:
        model: A ``DenseTorch``, ``LogisticRegression``, ``BoostedTrees`` or
            ``DummyClassifier`` instance.
        x: Feature matrix, dense or scipy sparse. Scaling is requested with
            ``copy=False``, so the caller's data may be modified in place.
        threshold: Activations less than or equal to this value are treated
            as "no prediction" (label ``''``).
        monte_carlo: If given, the number of stochastic passes. For each pass
            dropout with ``p=0.5`` is applied to the *input* and the logits
            of all passes are averaged.

    Returns:
        For a ``DummyClassifier``: the result of ``model.predict(x)``.
        Otherwise an array of predicted labels, or the tuple
        ``(labels, all_logits)`` when ``monte_carlo`` is given, where
        ``all_logits`` stacks the logits of every pass.

    Raises:
        Exception: If ``model`` is not one of the supported classifier types.
    """
    x = robust_scale(x, axis=1, with_centering=False, copy=False, unit_variance=True)
    # Densify every scipy sparse container, not only csr_matrix: robust_scale
    # hands CSC input back as CSC, and sparse *arrays* are not csr_matrix
    # instances, so a CSR-only check would pass them on still sparse.
    if sparse.issparse(x):
        x = x.toarray()

    # The dummy classifier has no logits to threshold; it predicts directly.
    if isinstance(model, DummyClassifier):
        return model.predict(x)
    if not isinstance(model, (DenseTorch, LogisticRegression, BoostedTrees)):
        raise Exception('Unknown classifier type')

    if monte_carlo is not None:
        # A freshly constructed Dropout module is in training mode, so it
        # actually drops (and rescales) input features on every pass.
        dropout = torch.nn.Dropout(p=0.5)
        all_logits = np.array([
            model.predict_logits(np.array(dropout(torch.Tensor(x))))
            for _ in range(monte_carlo)
        ])
        logits = np.mean(all_logits, axis=0)
    else:
        logits = model.predict_logits(x)

    max_activation = np.max(logits, axis=1)
    pred = np.argmax(logits, axis=1).astype(int)
    pred = np.array([model.labels_dec[p] for p in pred])
    pred[max_activation <= threshold] = ''

    if monte_carlo is not None:
        return pred, all_logits
    return pred
def samples_per_class(y):
    """Count how many rows of a one-hot label matrix belong to each class.

    Args:
        y: 2-dimensional label matrix; the class of a row is the position of
            its largest entry (``argmax`` along axis 1).

    Returns:
        A list with one entry per column of ``y``. Classes that occur hold
        their sample count; classes that never occur keep the zero they were
        initialised with (a 0-dimensional torch tensor).
    """
    counts = list(torch.zeros(y.shape[1]))
    row_classes = np.argmax(y, axis=1)
    present, occurrences = np.unique(row_classes, return_counts=True)
    for position, class_index in enumerate(present):
        counts[class_index] = occurrences[position]
    return counts
def set_threads(num_threads, parallelize):
    """Configure torch's thread count and choose the number of DataLoader workers.

    Args:
        num_threads: Total number of threads available to this fit.
        parallelize: When truthy, no DataLoader workers are used and torch is
            limited to a single thread.

    Returns:
        The number of DataLoader workers to use (4 or 0).
    """
    has_thread_budget = num_threads > 10 and not parallelize
    if has_thread_budget:
        # 4 for dask, 4 for dataloader, the rest for torch
        num_workers, torch_threads = 4, num_threads - 4 - 4
    else:
        num_workers, torch_threads = 0, 1

    logger.info(f'num_workers set to {num_workers}')
    torch.set_num_threads(torch_threads)
    logger.info(f'num_threads set to {torch.get_num_threads()}')
    return num_workers
def dataloaders_from_dask(x, y, batch_size, num_workers):
    """Build chunk-streaming train/validation loaders backed by dask arrays.

    The samples are shuffled with a fixed seed and split 80/20. Each part is
    wrapped in dask arrays whose chunks are ``batch_size`` rows tall, so one
    chunk corresponds to one batch; feature chunks are densified lazily.

    Args:
        x: Feature matrix; its blocks are densified with
            ``sparse.csr_matrix.toarray``.
        y: 2-dimensional label matrix with one row per sample of ``x``.
        batch_size: Requested number of rows per batch (capped at the size of
            each split).
        num_workers: Number of DataLoader worker processes.

    Returns:
        ``(train_dataloader, val_dataloader, num_batches, num_batches_val)``.
    """
    n_samples = x.shape[0]
    shuffled = np.arange(n_samples)
    np.random.default_rng(12345).shuffle(shuffled)
    n_train = int(np.floor(n_samples * .8))
    train_rows, val_rows = shuffled[:n_train], shuffled[n_train:]
    x_train, y_train = x[train_rows], y[train_rows]
    x_val, y_val = x[val_rows], y[val_rows]

    def _chunked_loader(features, labels, rows_per_chunk):
        # One dask chunk per batch; densification is deferred to compute time.
        lazy_x = da.from_array(features, chunks=(rows_per_chunk, features.shape[1]))
        lazy_x = lazy_x.map_blocks(sparse.csr_matrix.toarray, dtype=np.float32)
        lazy_y = da.from_array(labels, chunks=(rows_per_chunk, labels.shape[1]))
        dataset = DaskBatchDataset(lazy_x, lazy_y)
        n_batches = len(dataset.X_chunks)
        # batch_size=None: the dataset already yields complete batches.
        loader = DataLoader(dataset, batch_size=None, num_workers=num_workers)
        return loader, n_batches

    train_batch_rows = min(batch_size, x_train.shape[0])
    train_dataloader, num_batches = _chunked_loader(x_train, y_train, train_batch_rows)

    # The validation cap starts from the (possibly reduced) training batch size.
    val_batch_rows = min(train_batch_rows, x_val.shape[0])
    val_dataloader, num_batches_val = _chunked_loader(x_val, y_val, val_batch_rows)

    return train_dataloader, val_dataloader, num_batches, num_batches_val
def dataloaders_from_dense(x, y, batch_size, num_workers):
    """Build in-memory train/validation loaders from a sparse feature matrix.

    The features are densified up front, converted to float32 tensors
    together with the labels, and randomly split 80/20.

    Args:
        x: Feature matrix, densified with ``sparse.csr_matrix.toarray``.
        y: numpy label array with one row per sample of ``x``.
        batch_size: Requested batch size (capped at the size of each split).
        num_workers: Number of DataLoader worker processes.

    Returns:
        ``(train_dataloader, val_dataloader)``. Both shuffle, and both drop
        the last batch when it would contain exactly one sample.
    """
    features = torch.from_numpy(sparse.csr_matrix.toarray(x)).to(torch.float32)
    targets = torch.from_numpy(y).to(torch.float32)
    train_dataset, val_dataset = random_split(
        TensorDataset(features, targets), [0.8, 0.2])

    loaders = []
    for subset in (train_dataset, val_dataset):
        # The cap carries over: the validation batch size starts from the
        # (possibly reduced) training batch size.
        batch_size = min(batch_size, len(subset))
        single_sample_tail = len(subset) % batch_size == 1
        loaders.append(DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=True,
            drop_last=single_sample_tail,
            num_workers=num_workers))

    train_dataloader, val_dataloader = loaders
    return train_dataloader, val_dataloader
def fit_torch(
        model: DenseTorch,
        x: np.array,
        y: np.array,
        epochs: int=40,
        batch_size: int=64,
        starting_lr: float=0.01,
        max_lr: float=0.1,
        momentum: float=0.5,
        parallelize: bool=True,
        num_threads: int=1,
        beta: float=0.8,
        gamma: float=2.0,
        class_balance: bool=True,
        max_cells: int=1_000_000):
    """Train a ``DenseTorch`` model with SGD, a one-cycle LR schedule and focal loss.

    The data is split 80/20 into training and validation parts. After every
    epoch the model weights are snapshotted; at the end the snapshot with the
    lowest validation loss is restored.

    Args:
        model: Network to train; must expose ``labels_enc``.
        x: Feature matrix (sparse; densified by the dataloader helpers).
        y: Integer-encoded class labels, one per row of ``x``.
        epochs: Number of training epochs.
        batch_size: Requested batch size.
        starting_lr: Initial learning rate passed to SGD.
        max_lr: Peak learning rate of the one-cycle schedule.
        momentum: SGD momentum.
        parallelize: Forwarded to ``set_threads``.
        num_threads: Forwarded to ``set_threads``.
        beta: Class-balanced loss beta.
        gamma: Focal loss gamma.
        class_balance: Whether the loss is class balanced.
        max_cells: Above this number of samples the dask-backed streaming
            loaders are used instead of in-memory loaders.

    Returns:
        A DataFrame indexed by epoch with the columns ``loss``, ``val_loss``
        and ``lr`` (the learning rate after the epoch's last step).
    """
    num_workers = set_threads(num_threads, parallelize)
    y = to_categorical(y, num_classes=len(model.labels_enc.keys()))

    total_samples = x.shape[0]
    if total_samples > max_cells:
        train_dataloader, val_dataloader, num_batches, num_batches_val = dataloaders_from_dask(
            x, y, batch_size, num_workers)
    else:
        train_dataloader, val_dataloader = dataloaders_from_dense(x, y, batch_size, num_workers)
        num_batches = len(train_dataloader)
        num_batches_val = len(val_dataloader)

    model.train()
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=starting_lr,
        momentum=momentum
    )
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=max_lr,
        div_factor=10,
        epochs=epochs,
        steps_per_epoch=num_batches
    )
    loss_function = BalancedLoss(
        loss_type="focal_loss",
        samples_per_class=samples_per_class(y),
        beta=beta,  # class-balanced loss beta
        fl_gamma=gamma,  # focal loss gamma
        class_balanced=class_balance,
        safe=True
    )

    state_dicts = []
    learning_curve = pd.DataFrame(columns=['loss', 'val_loss', 'lr'])
    for epoch in range(epochs):
        # Streaming datasets reshuffle per epoch; in-memory ones have no set_epoch.
        if hasattr(train_dataloader.dataset, 'set_epoch'):
            train_dataloader.dataset.set_epoch(epoch)

        model.train()
        cumulative_loss = 0
        for xb, yb in train_dataloader:
            logits = model(xb)
            # NOTE(review): the outputs are clamped to [0, 1] before the loss;
            # kept as in the original — confirm this is intended.
            logits = torch.clamp(logits, 0, 1)
            loss = loss_function(logits, torch.argmax(yb, dim=-1).to(torch.int64))
            loss.backward()
            cumulative_loss += loss.item()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
        cumulative_loss = cumulative_loss / num_batches

        model.eval()
        running_vloss = 0.0
        if hasattr(val_dataloader.dataset, 'set_epoch'):
            val_dataloader.dataset.set_epoch(epoch)
        # Validation only reads the loss value, so skip building autograd graphs.
        with torch.no_grad():
            for xb, yb in val_dataloader:
                logits = model(xb)
                logits = torch.clamp(logits, 0, 1)
                val_loss = loss_function(
                    logits, torch.argmax(yb, dim=-1).to(torch.int64)).item()
                running_vloss += val_loss
        val_loss = running_vloss / num_batches_val

        # get_last_lr() returns one value per parameter group; the optimizer
        # above has a single group, so record its learning rate as a scalar.
        current_lr = scheduler.get_last_lr()[0]
        learning_curve.loc[epoch, ['loss', 'val_loss', 'lr']] = cumulative_loss, val_loss, current_lr
        state_dicts.append(deepcopy(model.state_dict()))

    # Restore the weights of the epoch with the lowest validation loss.
    model.load_state_dict(state_dicts[np.argmin(learning_curve['val_loss'].values)])
    model.is_fitted = True
    return learning_curve
def fit_logreg(model: LogisticRegression, x, y, **fit_kwargs):
    """Fit the estimator wrapped by a ``LogisticRegression`` model.

    Args:
        model: Wrapper whose ``model`` attribute is the estimator to fit.
        x: Feature matrix.
        y: Target labels.
        **fit_kwargs: Accepted for a uniform fit interface; not used here.

    Returns:
        The value returned by the wrapped estimator's ``fit``.
    """
    estimator = model.model
    fitted = estimator.fit(x, y)
    # Only reached when fitting did not raise. The label encoding of the
    # wrapper is left untouched.
    model.is_fitted = True
    return fitted
def fit_trees(model: BoostedTrees, x, y, **fit_kwargs):
    """Fit the estimator wrapped by a ``BoostedTrees`` model with a held-out eval set.

    The data is split 75/25 (fixed seed) into a training part and an
    evaluation part that is passed to the estimator as ``eval_set``.

    Args:
        model: Wrapper whose ``model`` attribute is the estimator to fit.
        x: Feature matrix.
        y: Target labels, one per row of ``x``.
        **fit_kwargs: Accepted for a uniform fit interface; currently not
            forwarded to the estimator.

    Returns:
        The value returned by the wrapped estimator's ``fit``.
    """
    x, x_val, y, y_val = train_test_split(x, y, train_size=0.75, random_state=42)

    # If the validation set contains labels not in the training set, remove
    # them. The membership mask is computed once and applied to both arrays,
    # instead of re-deriving the training classes for every validation label.
    seen_in_training = np.isin(y_val, np.unique(y))
    if not np.all(seen_in_training):
        x_val = x_val[seen_in_training]
        y_val = y_val[seen_in_training]

    fit = model.model.fit(
        x, y,
        eval_set=[(x_val, y_val)],
    )
    model.is_fitted = True
    return fit
def fit(
        model: Union[DenseTorch, LogisticRegression, BoostedTrees, DummyClassifier],
        x: np.array,
        y: np.array,
        standardize_idx: list=None,
        **fit_kwargs):
    """Scale the input, encode the labels and dispatch to the model-specific fit routine.

    Args:
        model (Union[DenseTorch, LogisticRegression, BoostedTrees, DummyClassifier]):
            Model to be fitted; must expose ``labels_enc``.
        x (np.array): Input data. Scaling is requested with ``copy=False``,
            so the caller's data may be modified in place.
        y (np.array): Target data in the shape of a 1-dimensional array of
            label strings.
        standardize_idx (list, optional): One row indexer per dataset. When
            given, each group of rows is scaled separately; otherwise all
            rows are scaled in one call.
        **fit_kwargs: Forwarded to the model-specific fit routine.

    Returns:
        The result of the model-specific routine: the learning curve for
        ``DenseTorch``, otherwise whatever the wrapped ``fit`` call returns.

    Raises:
        Exception: If ``model`` is not one of the supported classifier types.
    """
    # Standardize batches separately if list of idxs per dataset is provided
    if standardize_idx is not None:
        for idx in standardize_idx:
            x[idx] = robust_scale(x[idx], axis=1, with_centering=False, copy=False, unit_variance=True)
    else:
        x = robust_scale(x, axis=1, with_centering=False, copy=False, unit_variance=True)

    # DenseTorch densifies inside its dataloaders; every other model needs a
    # dense matrix here. Input that is already dense is passed through as is.
    if not isinstance(model, DenseTorch) and sparse.issparse(x):
        x = x.toarray()

    y = np.array([model.labels_enc[label] for label in y])

    if isinstance(model, DenseTorch):
        return fit_torch(model, x, y, **fit_kwargs)
    elif isinstance(model, LogisticRegression):
        return fit_logreg(model, x, y, **fit_kwargs)
    elif isinstance(model, BoostedTrees):
        return fit_trees(model, x, y, **fit_kwargs)
    elif isinstance(model, DummyClassifier):
        model.is_fitted = True
        return model.fit(x, y)

    # Fail loudly instead of silently returning None for an unfitted model,
    # matching predict / predict_logits.
    raise Exception('Unknown classifier type.')