# Source code for Compocyte.core.hierarchical_classifier

"""Hierarchical classifier implementation.

This module provides `HierarchicalClassifier`, a composite classifier that
manages a hierarchy of local classifiers for annotating single-cell data.
The class orchestrates feature selection, training and prediction across the
hierarchy and provides save/load utilities.
"""

from typing import Union
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.preprocessing import robust_scale
from Compocyte.core.base.data_base import DataBase
from Compocyte.core.base.hierarchy_base import HierarchyBase
from Compocyte.core.base.export_import_base import ExportImportBase
from Compocyte.core.models.dummy_classifier import DummyClassifier
from Compocyte.core.models.fit_methods import fit, predict
from Compocyte.core.models.log_reg import LogisticRegression
from Compocyte.core.models.dense_torch import DenseTorch
from time import time
from scipy import sparse
import numpy as np
import os
import pickle
import scanpy as sc
import multiprocessing as mp
from Compocyte.core.models.trees import BoostedTrees
from Compocyte.core.tools import z_transform_properties


class HierarchicalClassifier(
        DataBase,
        HierarchyBase,
        ExportImportBase):
    """Composite classifier managing a hierarchy of local classifiers.

    One local classifier is trained per parent node of the hierarchy and
    predicts the labels of that node's children. Data, hierarchy and
    export/import handling are inherited from ``DataBase``,
    ``HierarchyBase`` and ``ExportImportBase``.
    """

    def __init__(
            self,
            save_path,
            adata=None,
            root_node=None,
            dict_of_cell_relations=None,
            obs_names=None,
            default_input_data='normlog',
            num_threads=1,
            ignore_counts=False,
            temp_path=None):
        """Create the classifier and optionally attach data and a hierarchy.

        Parameters:
            save_path: Directory used by ``save`` and, by default, ``load``.
            adata: Optional AnnData object, passed to ``load_adata``.
            root_node: Root node of the hierarchy.
            dict_of_cell_relations: Hierarchy description, passed to
                ``set_cell_relations``.
            obs_names: Names of the ``adata.obs`` label columns; the other
                methods index this list by node depth.
            default_input_data (str): Stored as-is.
            num_threads (int): Threads per training process.
            ignore_counts (bool): Stored as-is.
            temp_path: Forwarded to ``set_cell_relations``.

        The hierarchy is only set up when ``root_node``,
        ``dict_of_cell_relations`` and ``obs_names`` are all given.
        """
        self.save_path = save_path
        self.default_input_data = default_input_data
        self.num_threads = num_threads
        self.adata = None
        self.var_names = None
        self.dict_of_cell_relations = None
        self.root_node = None
        self.obs_names = None
        self.ignore_counts = ignore_counts
        if adata is not None:
            self.load_adata(adata)
        if (root_node is not None
                and dict_of_cell_relations is not None
                and obs_names is not None):
            self.set_cell_relations(
                root_node, dict_of_cell_relations, obs_names, temp_path)
[docs] def save(self, save_adata=False): # save all attributes # get types, for adata use adatas write function with hash of adata # save state of all local classifiers (what does dumping self.graph do?) # save state of all nodes in the graph, label encoders, var names ... data_path = os.path.join( self.save_path, 'data' ) timestamp = str(time()).replace('.', '_') hc_path = os.path.join( self.save_path, 'hierarchical_classifiers', timestamp ) if not os.path.exists(hc_path): os.makedirs(hc_path) settings_dict = {} for key in self.__dict__.keys(): if key == 'adata': if self.adata is None or not save_adata: continue if not os.path.exists(data_path): os.makedirs(data_path) self.adata.write(os.path.join(data_path, f'{timestamp}.h5ad')) elif key == 'graph': continue else: settings_dict[key] = self.__dict__[key] with open(os.path.join(hc_path, 'hierarchical_classifier_settings.pickle'), 'wb') as f: pickle.dump(settings_dict, f) for node in list(self.graph): node_content_path = os.path.join( self.save_path, 'node_content', node, timestamp ) if not os.path.exists(node_content_path): os.makedirs(node_content_path) for key in self.graph.nodes[node].keys(): if key == 'local_classifier': model_path = os.path.join( self.save_path, 'models', node, timestamp ) if not os.path.exists(model_path): os.makedirs(model_path) local_classifier = self.graph.nodes[node]['local_classifier'] if isinstance(local_classifier, DenseTorch) or isinstance(local_classifier, LogisticRegression) or isinstance(local_classifier, DummyClassifier) or isinstance(local_classifier, BoostedTrees): self.graph.nodes[node]['local_classifier']._save(model_path) continue with open(os.path.join(node_content_path, f'{key}.pickle'), 'wb') as f: pickle.dump(self.graph.nodes[node][key], f)
[docs] def load(self, load_path=None, load_adata=False): if load_path is None: load_path = self.save_path data_path = os.path.join( load_path, 'data' ) hc_path = os.path.join( load_path, 'hierarchical_classifiers' ) if os.path.exists(hc_path): timestamps = os.listdir(hc_path) last_timestamp = sorted(timestamps)[-1] with open(os.path.join(hc_path, last_timestamp, 'hierarchical_classifier_settings.pickle'), 'rb') as f: settings_dict = pickle.load(f) for key in settings_dict.keys(): self.__dict__[key] = settings_dict[key] if os.path.exists(data_path) and load_adata: timestamps = os.listdir(data_path) last_adata = sorted(timestamps)[-1] adata = sc.read_h5ad(os.path.join(data_path, last_adata)) self.load_adata(adata) if not hasattr(self, 'graph') or self.graph is None: self.make_classifier_graph() for node in list(self.graph): model_path = os.path.join( load_path, 'models', node ) node_content_path = os.path.join( load_path, 'node_content', node ) if os.path.exists(model_path): timestamps = os.listdir(model_path) last_timestamp = sorted(timestamps)[-1] contents = os.listdir(os.path.join(model_path, last_timestamp)) if len([c for c in contents if c.startswith('non_param_dict')]) > 0: classifier = DenseTorch._load(os.path.join(model_path, last_timestamp)) elif 'labels_dec.pickle' in contents and not 'model.cbm' in contents: classifier = LogisticRegression._load(os.path.join(model_path, last_timestamp)) elif 'labels_dec.pickle' in contents and 'model.cbm' in contents: classifier = BoostedTrees._load(os.path.join(model_path, last_timestamp)) else: classifier = DummyClassifier._load(os.path.join(model_path, last_timestamp)) self.graph.nodes[node]['local_classifier'] = classifier if os.path.exists(node_content_path): timestamps = os.listdir(node_content_path) last_timestamp = sorted(timestamps)[-1] properties = os.listdir(os.path.join(node_content_path, last_timestamp)) for p in properties: key = p.replace('.pickle', '') with open( os.path.join(node_content_path, 
last_timestamp, p), 'rb' ) as f: p = pickle.load(f) self.graph.nodes[node][key] = p
[docs] def limit_cells( self, subset: sc.AnnData, max_cells: int, stratify_by: str) -> sc.AnnData: if len(subset) > max_cells: rng = np.random.default_rng(42) datasets = subset.obs[stratify_by].unique() cells_per_dataset = max_cells // len(datasets) limited_indices = [] for dataset in datasets: dataset_indices = subset.obs[subset.obs['dataset'] == dataset].index if len(dataset_indices) > cells_per_dataset: dataset_indices = rng.choice(dataset_indices, cells_per_dataset, replace=False) limited_indices.extend(dataset_indices) # If we have fewer cells than max_cells, randomly sample additional cells to reach max_cells if len(limited_indices) < max_cells: additional_indices = rng.choice( subset.obs.index.difference(limited_indices), max_cells - len(limited_indices), replace=False) limited_indices.extend(additional_indices) subset = subset[limited_indices, :] return subset
[docs] def introduce_limit(self, max_cells: int, stratify_by: str): """ Introduces a limit on the number of cells per local classifier training and \ specifies a stratification criterion. Parameters: max_cells (int): The maximum number of cells allowed. stratify_by (str): The criterion by which to stratify the cells. """ self.max_cells = max_cells self.stratify_by = stratify_by
[docs] def select_subset( self, node: str, features: list=None, max_cells: int=None) -> sc.AnnData: obs = self.obs_names[self.node_to_depth[node]] child_obs = self.obs_names[self.node_to_depth[node] + 1] is_node = self.adata.obs[obs] == node has_child_label = self.adata.obs[child_obs] != '' subset = self.adata[is_node & has_child_label] if features is not None: subset = subset[:, features] stratify_by = getattr(self, 'stratify_by', None) if max_cells is not None and stratify_by is not None: subset = self.limit_cells(subset, max_cells, stratify_by) """ if max_cells is None: max_cells = getattr(self, 'max_cells', None) stratify_by = getattr(self, 'stratify_by', None) if max_cells is not None and stratify_by is not None: subset = self.limit_cells(subset, max_cells, stratify_by) """ return subset
[docs] def select_subset_prediction(self, node: str, features: list=None, for_trial=False) -> sc.AnnData: obs = self.obs_names[self.node_to_depth[node]] obs = f'{obs}_pred' if obs not in self.adata.obs.columns and not for_trial: subset = self.adata elif obs not in self.adata.obs.columns and for_trial: is_node = self.adata.obs[self.obs_names[self.node_to_depth[node]]] == node subset = self.adata[is_node] else: is_node = self.adata.obs[obs] == node subset = self.adata[is_node] if features is not None: subset = subset[:, features] return subset
[docs] def run_feature_selection( self, node: str, overwrite: bool=False, n_features: int=-1, max_features: int=None, min_features: int=30, test_factor: float=1.0, max_cells=100_000): has_features = 'selected_var_names' in self.graph.nodes[node].keys() if has_features and not overwrite: raise Exception(f'Features have already been selected at {node}.') subset = self.select_subset(node, max_cells=max_cells) x = sparse.csr_matrix.toarray(subset.X) child_obs = self.obs_names[self.node_to_depth[node] + 1] if len(subset.obs[child_obs].unique()) <= 1: return self.adata.var_names.tolist() # Rule of thumb from Google's rules of ML: # At least 100 samples per feature if n_features < 0: # test_factor should be taken account during hypopt with reduced sample numbers n_features = int(len(subset) / 100 * test_factor) n_features = max(min_features, n_features) if max_features is None: max_features = len(self.adata.var_names) n_features = min(n_features, max_features) x = np.asarray(x) x = robust_scale(x, axis=1, with_centering=False, copy=False, unit_variance=True) y = np.array(subset.obs[child_obs]) selecter = SelectKBest(f_classif, k=n_features) selecter.fit(x, y) features = self.adata.var_names[selecter.get_support()] return features.tolist()
[docs] def create_local_classifier( self, node: str, overwrite: bool=False, classifier_type: Union[DenseTorch, LogisticRegression, BoostedTrees]=DenseTorch, **classifier_kwargs): has_classifier = 'local_classifier' in self.graph.nodes[node].keys() if has_classifier and not overwrite: raise Exception(f'A classifier already exists at {node}.') features = self.graph.nodes[node].get('selected_var_names', None) if features is None: raise Exception(f'Cannot create classifier at {node} without features.\ Please run run_feature_selection first.') subset = self.select_subset(node) child_obs = self.obs_names[self.node_to_depth[node] + 1] labels = subset.obs[child_obs].unique().tolist() n_input = len(features) n_output = len(labels) if isinstance(classifier_type, str): if classifier_type == 'DenseTorch': classifier_type = DenseTorch elif classifier_type == 'LogisticRegression': classifier_type = LogisticRegression elif classifier_type == 'BoostedTrees': classifier_type = BoostedTrees else: raise Exception(f'Unknown classifier type: {classifier_type}') if n_output == 1: classifier_type = DummyClassifier local_classifier = classifier_type( labels, n_input=n_input, n_output=n_output, **classifier_kwargs) self.graph.nodes[node]['local_classifier'] = local_classifier
[docs] def train_single_node(self, node, standardize_separately: str=None, **fit_kwargs): if not hasattr(self, 'num_threads') and not 'num_threads' in fit_kwargs: raise Exception('Please specify the number of threads to use for training.') elif 'num_threads' in fit_kwargs: self.num_threads = fit_kwargs['num_threads'] has_classifier = 'local_classifier' in self.graph.nodes[node].keys() # This weird approach is currently necessary to allow for training with mp.pool if hasattr(self, 'tuned_kwargs') and node in self.tuned_kwargs: kwargs = self.tuned_kwargs[node] features_kwargs = { 'n_features': kwargs['n_features'] } classifier_kwargs = { 'hidden_layers': kwargs['hidden_layers'], 'dropout': kwargs['dropout'], } fit_kwargs = { 'epochs': kwargs['epochs'], 'batch_size': kwargs['batch_size'], 'starting_lr': kwargs['starting_lr'], 'max_lr': kwargs['max_lr'], 'momentum': kwargs['momentum'], 'beta': kwargs['beta'], 'gamma': kwargs['gamma'], 'max_cells': getattr(self, 'max_cells', 1_000_000) } self.graph.nodes[node]['threshold'] = kwargs['threshold'] else: features_kwargs = {} classifier_kwargs = {} if not has_classifier: subset = self.select_subset(node) if len(subset) < 5: return features = self.run_feature_selection(node, **features_kwargs) self.graph.nodes[node]['selected_var_names'] = features classifier_type = DenseTorch hidden_layers = classifier_kwargs.get('hidden_layers', []) if -1 in hidden_layers: classifier_type = BoostedTrees # If classifier types other than the standard have been set, use those specified_classifier_types = getattr(self, 'specified_classifier_types', {}) classifier_type = specified_classifier_types.get(node, classifier_type) self.create_local_classifier(node, classifier_type=classifier_type, **classifier_kwargs) child_obs = self.obs_names[self.node_to_depth[node] + 1] features = self.graph.nodes[node]['selected_var_names'] subset = self.select_subset(node, features=features) if len(subset) == 0: return model = 
self.graph.nodes[node]['local_classifier'] x = subset.X y = subset.obs[child_obs].values print(f'Training at {node}.') if standardize_separately is not None: idx = [] for dataset in subset.obs[standardize_separately].unique(): idx.append(np.where(subset.obs[standardize_separately] == dataset)) else: idx = None if not 'max_cells' in fit_kwargs: fit_kwargs['max_cells'] = getattr(self, 'max_cells', 1_000_000) if not 'num_threads' in fit_kwargs: fit_kwargs['num_threads'] = self.num_threads # Necessary to avoid data loss when using mp.pool return { **self.graph.nodes[node], 'learning_curve': fit(model, x, y, standardize_idx=idx, **fit_kwargs) }
[docs] def set_classifier_type(self, node, classifier_type): if isinstance(node, list): for n in node: self.set_classifier_type(n, classifier_type) else: if not hasattr(self, 'specified_classifier_types'): self.specified_classifier_types = {} self.specified_classifier_types[node] = classifier_type
[docs] def train_all_child_nodes( self, parallelize: bool=False, processes: int=None) -> None: nodes_to_train = [] for node in self.graph.nodes: n_children = len(list(self.graph.successors(node))) if n_children >= 1: nodes_to_train.append(node) if not parallelize: for node in nodes_to_train: self.train_single_node(node, parallelize=False) else: if processes is None: raise Exception('Please specify the number of processes to use for parallelization.') # When setting num_threads > 1 per training process, the number of processes should be limited if self.num_threads is not None: processes = int(processes / self.num_threads) print(f"Using multiprocessing for training with {mp.cpu_count()} available CPU cores.\n") with mp.Pool(processes=processes) as pool: all_trained_node_params = pool.map(self.train_single_node, nodes_to_train) for node, params in zip(nodes_to_train, all_trained_node_params): if params is not None: #this should only happen at nodes that have not been trained for key in params.keys(): if params.get(key) is not None: self.graph.nodes[node][key] = params.get(key)
[docs] def predict_single_node( self, node: str, threshold: float=-1, monte_carlo: int=None) -> np.array: if 'local_classifier' not in self.graph.nodes[node]: return [] features = self.graph.nodes[node]['selected_var_names'] subset = self.select_subset_prediction(node, features=features) if len(subset) == 0: return model = self.graph.nodes[node]['local_classifier'] x = subset.X print(f'Predicting at {node}.') pred = predict(model, x, threshold=threshold, monte_carlo=monte_carlo) all_logits = None if monte_carlo is not None and isinstance(pred, tuple): pred, all_logits = pred if len(all_logits.shape) < 3: all_logits = np.expand_dims(all_logits, axis=1) if 'overclustering' in subset.obs.columns: for cluster_name in subset.obs['overclustering'].unique(): cluster_indices = subset.obs['overclustering'] == cluster_name if np.sum(cluster_indices) > 0: cluster_preds = pred[cluster_indices] if len(cluster_preds) > 0: most_common = max(set(cluster_preds), key=list(cluster_preds).count) pred[cluster_indices] = most_common child_obs = self.obs_names[self.node_to_depth[node] + 1] child_obs = f'{child_obs}_pred' if child_obs not in self.adata.obs.columns: self.adata.obs[child_obs] = '' self.adata.obs[child_obs] = self.adata.obs[child_obs].astype(str) self.adata.obs.loc[ subset.obs_names, child_obs ] = pred if monte_carlo is not None and all_logits is not None: # all_logits: Shape: (iterations, samples, labels) # mean_activations_per_sample: Shape: (samples, labels) mean_activations_per_sample = np.mean(all_logits, axis=0) # idx_max_activations: Shape: (samples) idx_max_activations = np.argmax(mean_activations_per_sample, axis=1) # idx_tile: Shape: (iterations, samples) idx_tile = np.tile(idx_max_activations, (all_logits.shape[0], 1)) # Expand idx_tile dims to match test dimensions # For each iteration and sample, take the activation corresponding to # the label with the highest mean activation across iterations for this sample # activations_chosen_label_per_iteration: Shape: 
(iterations, samples, 1) activations_chosen_label_per_iteration = np.take_along_axis( all_logits, np.expand_dims(idx_tile, axis=2), axis=2) # activations_chosen_label_per_sample: Shape: (samples, iterations) activations_chosen_label_per_sample = np.squeeze(activations_chosen_label_per_iteration).T # when dealing with single samples, activations are squeezed to 1D axis = 0 if activations_chosen_label_per_sample.ndim == 1 else 1 # mean_activation_chosen_label_per_sample: Shape: (samples) mean_activation_chosen_label_per_sample = np.mean(activations_chosen_label_per_sample, axis=axis) std_activations_chosen_label_per_sample = np.std(activations_chosen_label_per_sample, axis=axis) self.adata.obs.loc[ subset.obs_names, 'monte_carlo_mean', ] = mean_activation_chosen_label_per_sample self.adata.obs.loc[ subset.obs_names, 'monte_carlo_std', ] = std_activations_chosen_label_per_sample return pred
[docs] def predict_all_child_nodes( self, node: str, threshold: float=-1, mlnp: bool=False, monte_carlo: int=None): # For mandatory leaf node prediction use -1 if not mlnp: threshold = self.graph.nodes[node].get('threshold', threshold) self.predict_single_node(node, threshold=threshold, monte_carlo=monte_carlo) for child_node in self.get_child_nodes(node): if len(self.get_child_nodes(child_node)) == 0: continue self.predict_all_child_nodes(child_node, threshold=threshold, mlnp=mlnp, monte_carlo=monte_carlo)