Source code for Compocyte.core.tools

import anndata
import numpy as np
from scipy import sparse
import networkx as nx
from copy import deepcopy
import pandas as pd

def set_node_to_depth(dictionary, depth=0, node_to_depth=None):
    """Map every node of a nested hierarchy dict to its depth.

    Parameters
    ----------
    dictionary : dict
        Nested dict in which each key is a node and each value is the dict
        of that node's children (an empty dict for a node without children).
    depth : int, optional
        Depth assigned to the keys of ``dictionary``; their children receive
        ``depth + 1`` and so on. Default is 0.
    node_to_depth : dict, optional
        Mapping to fill in place. A fresh dict is created when omitted, so
        results of independent calls never share state.

    Returns
    -------
    dict
        Mapping of node -> depth for every node reachable from ``dictionary``.
    """
    if node_to_depth is None:
        node_to_depth = {}
    for node, children in dictionary.items():
        set_node_to_depth(children, depth=depth + 1, node_to_depth=node_to_depth)
        # Assigned after descending, so this node's own depth overrides the
        # depth of any identically named node found below it.
        node_to_depth[node] = depth
    return node_to_depth
def is_counts(matrix, n_rows_to_try=100):
    """Check whether a matrix looks like raw count data.

    Only the first ``n_rows_to_try`` rows are inspected. The subsample counts
    as count data when it contains neither negative values nor values with a
    fractional part.

    Parameters
    ----------
    matrix : scipy sparse matrix or array-like
        Matrix to test, e.g. ``adata.X``, ``adata.raw.X`` or an adata layer.
        Both sparse and dense inputs are accepted.
    n_rows_to_try : int, optional
        Number of leading rows to inspect. Default is 100.

    Returns
    -------
    bool
        True if the inspected rows hold only non-negative whole numbers.
    """
    test_data = matrix[:n_rows_to_try]
    # Densify only the subsample; dense inputs are passed through unchanged
    # instead of being forced through the sparse conversion routine.
    if sparse.issparse(test_data):
        test_data = test_data.toarray()
    else:
        test_data = np.asarray(test_data)
    contains_negative_values = np.any(test_data < 0)
    contains_non_whole_numbers = np.any(test_data % 1 != 0)
    return not contains_negative_values and not contains_non_whole_numbers
def dict_depth(dictionary, running_count=0):
    """Return the number of nesting levels in a hierarchy dict.

    Keys named ``'classifier'`` are ignored; they are not treated as nodes
    and their values are not descended into.

    Parameters
    ----------
    dictionary : dict
        Nested dict of node -> dict of child nodes.
    running_count : int, optional
        Offset added to the result; the depth reported for a dict without
        any child nodes. Default is 0.

    Returns
    -------
    int
        ``running_count`` for a dict without child nodes, otherwise one more
        than the deepest child.

    Raises
    ------
    TypeError
        If ``dictionary`` (or any descended value) is not a dict.
    """
    if not isinstance(dictionary, dict):
        raise TypeError('dict_depth expects a (nested) dict.')
    child_depths = [
        dict_depth(children, running_count)
        for key, children in dictionary.items()
        if key != 'classifier'
    ]
    # A dict that is empty, or that only carries a 'classifier' entry, has no
    # child nodes and therefore terminates the recursion.
    if not child_depths:
        return running_count
    return max(child_depths) + 1
def infer_levels(hierarchy, labels, root_node, adata=None, prefix_obs='Level_'):
    """Expand flat labels into one column per hierarchy level.

    For every distinct label, the shortest path from ``root_node`` to that
    label is looked up in the hierarchy graph and written, node by node,
    into the level columns of the rows carrying that label.

    Parameters
    ----------
    hierarchy : dict or nx.DiGraph
        The hierarchy. A dict is converted to a DiGraph with
        ``make_graph_from_edges``; a DiGraph is converted to a dict with
        ``infer_dict`` (the dict form is needed to determine the depth).
    labels : str, list, or array-like
        Either the name of an existing column of the obs dataframe, a list of
        labels, or any object offering ``tolist()``. Lists and array-likes
        are stored in a column named ``'label'``.
    root_node : str or int
        Node from which the shortest paths are computed.
    adata : anndata.AnnData, optional
        If an AnnData object is given, its ``obs`` dataframe is used directly
        (it is not copied here), so the new columns are written onto it.
        Otherwise a new empty DataFrame is created. Default is None.
    prefix_obs : str, optional
        Prefix of the level column names (``'Level_0'``, ``'Level_1'``, ...).
        Default is ``'Level_'``.

    Returns
    -------
    pd.DataFrame
        The obs dataframe including the label column and the level columns.
    list
        Names of the level columns.

    Raises
    ------
    TypeError
        If ``hierarchy`` is neither a dict nor a DiGraph, or if ``labels``
        matches none of the accepted forms.

    Notes
    -----
    - The number of level columns equals ``dict_depth`` of the hierarchy.
    - Paths shorter than that depth are padded with empty strings.
    """
    # Obtain both representations of the hierarchy: graph and nested dict.
    if isinstance(hierarchy, dict):
        tree = hierarchy
        digraph = nx.DiGraph()
        make_graph_from_edges(tree, digraph)
    elif isinstance(hierarchy, nx.DiGraph):
        digraph = hierarchy
        tree = infer_dict(digraph)
    else:
        raise TypeError('Hierarchy must be provided as a dict or a NetworkX DiGraph.')

    frame = adata.obs if isinstance(adata, anndata.AnnData) else pd.DataFrame()

    # Resolve which column of the frame holds the labels.
    column = 'label'
    if isinstance(labels, str) and labels in frame.columns:
        column = labels
    elif hasattr(labels, 'tolist'):
        frame[column] = labels.tolist()
    elif isinstance(labels, list):
        frame[column] = labels
    else:
        raise TypeError('Labels must be provided as a list.')

    n_levels = dict_depth(tree)
    level_columns = [f'{prefix_obs}{i}' for i in range(n_levels)]
    for name in level_columns:
        frame[name] = ''

    # One path lookup per distinct label, broadcast to all rows with it.
    for value in frame[column].unique():
        route = nx.shortest_path(digraph, root_node, value)
        padding = [''] * (n_levels - len(route))
        frame.loc[frame[column] == value, level_columns] = route + padding

    return frame, level_columns
def flatten_dict(dictionary, running_list_of_values=None):
    """List every node name contained in a nested hierarchy dict.

    Each node is listed after all nodes below it (children first, then the
    node itself). Keys named ``'classifier'`` are skipped and their values
    are not descended into.

    Parameters
    ----------
    dictionary : dict
        Nested dict of node -> dict of child nodes.
    running_list_of_values : list, optional
        Names to put in front of the result. The list is copied, never
        modified. Default is None (start with an empty list).

    Returns
    -------
    list
        A new list of node names; duplicates are kept.

    Raises
    ------
    TypeError
        If ``dictionary`` (or any descended value) is not a dict.
    """
    if not isinstance(dictionary, dict):
        raise TypeError('flatten_dict expects a (nested) dict.')
    # Always build a fresh list so callers can never get hold of (and
    # mutate) state that is shared between calls.
    flattened = [] if running_list_of_values is None else list(running_list_of_values)
    for key, children in dictionary.items():
        if key == 'classifier':
            continue
        flattened.extend(flatten_dict(children))
        flattened.append(key)
    return flattened
def hierarchy_names_unique(hierarchy_dict):
    """Tell whether every node name in a hierarchy dict occurs only once.

    Parameters
    ----------
    hierarchy_dict : dict
        Nested hierarchy dict, passed on to ``flatten_dict``.

    Returns
    -------
    bool
        True if the flattened list of node names contains no duplicates.
    """
    names = flatten_dict(hierarchy_dict)
    distinct_names = set(names)
    return len(distinct_names) == len(names)
def z_transform_properties(data_arr, discretization=False):
    """Center the columns of ``data_arr`` and rescale by its standard deviation.

    Parameters
    ----------
    data_arr : array-like
        Input values; the mean is taken along axis 0.
    discretization : bool, optional
        If True, the transformed values are replaced by their bin index
        with respect to the boundaries ``[-0.675, 0, 0.675]``.

    Returns
    -------
    np.ndarray
        Transformed values, or bin indices when ``discretization`` is set.
    """
    centered = data_arr - np.mean(data_arr, axis=0)
    # NOTE(review): the standard deviation is taken over the whole array while
    # the mean is per column - confirm this asymmetry is intended.
    scaled = centered / np.std(data_arr)
    if not discretization:
        return np.array(scaled)
    return np.array(np.digitize(scaled, [-0.675, 0, 0.675]))
def make_graph_from_edges(d, g, parent_key=''):
    """Fill a graph with the parent -> child edges of a nested hierarchy dict.

    Parameters
    ----------
    d : dict
        Nested dict of node -> dict of child nodes.
    g : graph object
        Graph to populate in place; must offer ``add_edge`` and ``add_node``
        (e.g. a ``networkx.DiGraph``).
    parent_key : str, optional
        Node that the keys of ``d`` are attached to. The default ``''`` marks
        the top level, whose keys have no parent.

    Returns
    -------
    None
        The graph ``g`` is modified in place.
    """
    for key, children in d.items():
        if parent_key != '':
            g.add_edge(parent_key, key)
        else:
            # Top-level nodes have no incoming edge; register them explicitly
            # so a node without children is not lost from the graph.
            g.add_node(key)
        if len(children) != 0:
            make_graph_from_edges(children, g, parent_key=key)
def get_last_annotation(obs_names, adata, barcodes=None, true_only=False):
    """Collect, per cell, the last non-missing annotation across several obs columns.

    The columns listed in ``obs_names`` are processed in the given order
    (presumably from the coarsest to the finest hierarchy level - confirm
    against callers). The first column defines which cells are kept; every
    later column overwrites the stored value wherever it holds a non-missing
    entry. Missing means NaN, ``""`` or ``"nan"``.

    Parameters
    ----------
    obs_names : iterable of str
        Names of the columns in ``adata.obs`` holding the true labels. For
        each name ``x`` the predicted labels are read from ``f'{x}_pred'``.
    adata : object with ``obs`` and ``obs_names``
        Source of the annotations (``adata.obs`` is indexed with ``.loc``).
    barcodes : index-like, optional
        Rows of ``adata.obs`` to use. Defaults to ``adata.obs_names``.
    true_only : bool, optional
        If True, only the true labels are read and returned.

    Returns
    -------
    pd.DataFrame
        Frame of strings with column ``'true_last'`` and, unless
        ``true_only`` is set, ``'pred_last'``. Cells whose first-level value
        is missing are dropped and are not re-added by later levels.
    """
    if barcodes is None:
        barcodes = adata.obs_names
    obs_names_pred = [f'{x}_pred' for x in obs_names]
    for i, (true_key, pred_key) in enumerate(zip(obs_names, obs_names_pred)):
        if i == 0:
            # First level: build the result frame and decide which cells stay.
            if true_only:
                obs_df = adata.obs.loc[barcodes, [true_key]]
            else:
                obs_df = adata.obs.loc[barcodes, [true_key, pred_key]]
            # Missing values are filtered BEFORE the string conversion here,
            # so real NaN entries are still recognisable as np.nan.
            obs_df = obs_df[~obs_df[true_key].isin([np.nan, "", "nan"])]
            obs_df.rename(columns={true_key: 'true_last'}, inplace=True)
            if not true_only:
                obs_df = obs_df[~obs_df[pred_key].isin([np.nan, "", "nan"])]
                obs_df.rename(columns={pred_key: 'pred_last'}, inplace=True)
            obs_df = obs_df.astype(str)
        else:
            # Later levels: read the level's columns under the final names.
            if true_only:
                obs_df_level = adata.obs.loc[barcodes, [true_key]]
                obs_df_level.rename(columns={true_key: 'true_last'}, inplace=True)
            else:
                obs_df_level = adata.obs.loc[barcodes, [true_key, pred_key]]
                obs_df_level.rename(columns={true_key: 'true_last', pred_key: 'pred_last'}, inplace=True)
            # After this conversion NaN shows up as the string "nan", which
            # the filters below catch.
            obs_df_level = obs_df_level.astype(str)
            obs_df_level_true = obs_df_level[~obs_df_level["true_last"].isin([np.nan, "", "nan"])]
            # Only cells that survived the first level can be updated.
            level_barcodes_true = [x for x in obs_df_level_true.index if x in obs_df.index]
            obs_df.loc[level_barcodes_true, 'true_last'] = obs_df_level_true.loc[level_barcodes_true, 'true_last']
            if not true_only:
                obs_df_level_pred = obs_df_level[~obs_df_level["pred_last"].isin([np.nan, "", "nan"])]
                level_barcodes_pred = [x for x in obs_df_level_pred.index if x in obs_df.index]
                obs_df.loc[level_barcodes_pred, 'pred_last'] = obs_df_level_pred.loc[level_barcodes_pred, 'pred_last']
    return obs_df
def get_leaf_nodes(hierarchy):
    """Return the names of all nodes without children in a hierarchy dict.

    Parameters
    ----------
    hierarchy : dict
        Nested dict of node -> dict of child nodes.

    Returns
    -------
    list
        Leaf names in the order in which a depth-first walk encounters them.
    """
    leaves = []
    for name, subtree in hierarchy.items():
        if len(subtree.keys()) == 0:
            leaves.append(name)
        else:
            leaves.extend(get_leaf_nodes(subtree))
    return leaves
def delete_dict_entries(dictionary, del_key='classifier', first_run=True, deleted_key=False):
    """Remove every entry stored under ``del_key`` from a nested dict.

    Parameters
    ----------
    dictionary : dict
        Nested dict; every value that is not stored under ``del_key`` is
        descended into.
    del_key : hashable, optional
        Key to remove at all nesting levels. Default is ``'classifier'``.
    first_run : bool, optional
        True for the outermost call, which works on a deep copy so the input
        stays untouched. Recursive calls pass False and edit in place.
    deleted_key : bool, optional
        Running flag carried through the recursion.

    Returns
    -------
    dict
        The pruned dict.
    bool
        True if at least one entry was removed (or ``deleted_key`` was
        already True on entry).
    """
    pruned = deepcopy(dictionary) if first_run else dictionary
    # Iterate over a snapshot of the keys because entries are deleted below.
    for name in tuple(pruned.keys()):
        if name != del_key:
            pruned[name], deleted_key = delete_dict_entries(
                pruned[name],
                del_key=del_key,
                first_run=False,
                deleted_key=deleted_key,
            )
            continue
        del pruned[name]
        deleted_key = True
    return pruned, deleted_key
def flatten_labels(pred_h_labels, graph, root_node, verbose=False):
    """Reduce each cell's hierarchical label path to a single label.

    Parameters
    ----------
    pred_h_labels : np.ndarray
        2-D array with one row per cell and one column per hierarchy level.
        NOTE: the first column is overwritten in place with ``root_node``.
    graph : graph object
        Hierarchy graph; only ``graph.nodes`` is read, to decide which
        labels are valid.
    root_node : str
        Label written into the first column of every row.
    verbose : bool, optional
        If True, print the labels that are not nodes of ``graph``.

    Returns
    -------
    np.ndarray
        1-D array with one label per row of ``pred_h_labels`` (also for a
        single-row input).
    """
    # Some predictions did not have the root label as their first value
    pred_h_labels[:, 0] = root_node

    # Per cell, count how many of its labels are valid nodes of the graph.
    in_graph = np.isin(pred_h_labels, graph.nodes)
    n_valid_labels = np.sum(in_graph, axis=1)

    if verbose:
        invalid_labels = np.unique(pred_h_labels[~in_graph]).tolist()
        print(f'The hierarchical annotations contained {len(invalid_labels)} invalid labels:\n{invalid_labels}.\nThis is only problematic if these are labels you intended to be counted as valid.')

    # The last valid label per cell is at n_valid_labels - 1, assuming the
    # valid labels form a gap-free run starting at index 0. Rows without any
    # valid label fall back to index 0.
    idx_last_valid_label = np.maximum(n_valid_labels - 1, 0).astype(int)
    pred_labels_flat = np.take_along_axis(
        pred_h_labels,
        idx_last_valid_label[:, np.newaxis],
        axis=1,
    )
    # Select the single remaining column explicitly: unlike np.squeeze this
    # keeps the result 1-D when there is only one row.
    return pred_labels_flat[:, 0]
class Hierarchical_Metric():
    """Hierarchical precision (hP), recall (hR) and F-score (hF).

    Each label is represented by the path from the root node to that label
    in the hierarchy graph. Precision and recall are computed from the
    overlap of the true and the predicted path, summed over cells.
    """

    def __init__(self, true_labels, predicted_labels, hierarchy_structure, root_node='Blood'):
        """Store labels and hierarchy and set up the lookup caches.

        Parameters
        ----------
        true_labels : array-like
            True label per cell.
        predicted_labels : array-like
            Predicted label per cell, aligned with ``true_labels``.
        hierarchy_structure : networkx graph
            NetworkX graph of the hierarchical classifier.
        root_node : str, optional
            Root of the hierarchy. Default is ``'Blood'``.
        """
        self.true_labels = np.array(true_labels)
        self.predicted_labels = np.array(predicted_labels)
        self.hierarchy_structure = hierarchy_structure
        self.root_node = root_node
        # node -> np.ndarray holding the path from the root to that node
        self.augmented_lookups = {}
        # true label -> predicted label -> (|T ∩ P|, |P|)
        self.intersect_lookups = {}

    def augmented_set_of_node_n(self, node):
        """Return the path from the root node to ``node`` as an array.

        The path comes from ``nx.shortest_path`` and therefore contains both
        the root and ``node`` itself. Labels that are not nodes of the
        hierarchy are treated as the root node. Results are cached.
        """
        if node not in self.hierarchy_structure.nodes:
            node = self.root_node
        # Cache the path to avoid querying networkx for every single true
        # and predicted label.
        if node not in self.augmented_lookups:
            path = nx.shortest_path(self.hierarchy_structure, self.root_node, node)
            self.augmented_lookups[node] = np.array(path)
        return self.augmented_lookups[node]

    def calculate_intersects(self, t_label, p_label, t_label_augmented, p_label_augmented):
        """Cache the overlap statistics for one (true, predicted) label pair.

        Stores ``(|T ∩ P|, |P|)`` in ``self.intersect_lookups``. When the
        overlap covers the complete true path, ``|P|`` is cut down to the
        overlap size.
        """
        cardinality_intersect_t_p = len(np.intersect1d(t_label_augmented, p_label_augmented))
        cardinality_p_label_augmented = len(p_label_augmented)
        # Test for over-specialization: if the whole true path is contained
        # in the predicted path, cut the predicted path to the true length.
        if cardinality_intersect_t_p == len(t_label_augmented):
            cardinality_p_label_augmented = cardinality_intersect_t_p
        if t_label not in self.intersect_lookups:
            self.intersect_lookups[t_label] = {}
        self.intersect_lookups[t_label][p_label] = (
            cardinality_intersect_t_p,
            cardinality_p_label_augmented,
        )

    def _cardinality_sums(self, idcs=None):
        """Sum |T ∩ P|, |P| and |T| over all cells or over the cells ``idcs``."""
        if idcs is None:
            true_labels, predicted_labels = self.true_labels, self.predicted_labels
        else:
            true_labels = self.true_labels[idcs]
            predicted_labels = self.predicted_labels[idcs]
        intersects, predicted_sizes, true_sizes = [], [], []
        for t_label, p_label in zip(true_labels, predicted_labels):
            t_label_augmented = self.augmented_set_of_node_n(t_label)
            if p_label not in self.intersect_lookups.get(t_label, {}):
                p_label_augmented = self.augmented_set_of_node_n(p_label)
                self.calculate_intersects(t_label, p_label, t_label_augmented, p_label_augmented)
            n_intersect, n_predicted = self.intersect_lookups[t_label][p_label]
            intersects.append(n_intersect)
            predicted_sizes.append(n_predicted)
            true_sizes.append(len(t_label_augmented))
        return (
            np.sum(np.array(intersects)),
            np.sum(np.array(predicted_sizes)),
            np.sum(np.array(true_sizes)),
        )

    @staticmethod
    def _f_beta(hP, hR, beta):
        """Combine precision and recall into the F-beta score."""
        return (beta**2 + 1) * hP * hR / (beta**2 * hP + hR)

    def hP(self):
        """Hierarchical precision over all cells."""
        n_intersect, n_predicted, _ = self._cardinality_sums()
        return n_intersect / n_predicted

    def hR(self):
        """Hierarchical recall over all cells."""
        n_intersect, _, n_true = self._cardinality_sums()
        return n_intersect / n_true

    def hF(self, beta):
        """Hierarchical F-beta score over all cells."""
        return self._f_beta(self.hP(), self.hR(), beta)

    def macro_hF(self, beta):
        '''Macro averaged hF-Score (average of micro hF's for each true label)'''
        labels = pd.Series(self.true_labels).value_counts().keys()
        label_Fb = []
        for label in labels:
            # Restrict the metric to the cells whose true label is `label`.
            true_label_idcs = np.where(self.true_labels == label)[0]
            n_intersect, n_predicted, n_true = self._cardinality_sums(true_label_idcs)
            label_Fb.append(self._f_beta(n_intersect / n_predicted, n_intersect / n_true, beta))
        return np.sum(np.array(label_Fb)) / len(labels)

    def list_micro_metrics(self, beta):
        """Return hF, hR and hP per true label, rounded to two decimals.

        Returns
        -------
        pd.DataFrame
            One row per distinct true label, columns ``f'hF{beta}'``,
            ``'hR'`` and ``'hP'``.
        """
        label_metrics = pd.DataFrame(columns=[f'hF{beta}', 'hR', 'hP'])
        for label in np.unique(self.true_labels):
            # Restrict the metric to the cells whose true label is `label`.
            true_label_idcs = np.where(self.true_labels == label)[0]
            n_intersect, n_predicted, n_true = self._cardinality_sums(true_label_idcs)
            hP = n_intersect / n_predicted
            hR = n_intersect / n_true
            Fb = self._f_beta(hP, hR, beta)
            label_metrics.loc[label] = [np.round(Fb, 2), np.round(hR, 2), np.round(hP, 2)]
        return label_metrics
def infer_dict(graph, parent=None):
    """Convert a directed hierarchy graph into a nested dict.

    Parameters
    ----------
    graph : networkx.DiGraph
        Hierarchy graph; ``nodes``, ``predecessors`` and ``successors`` are
        used.
    parent : node, optional
        Node whose children should be returned. With the default ``None``
        the result starts at the root level, i.e. at every node that has no
        predecessor.

    Returns
    -------
    dict
        With ``parent=None``: root node -> nested dict of its descendants
        (an empty dict for a graph without nodes). Otherwise: child of
        ``parent`` -> nested dict of that child's descendants.
    """
    if parent is None:
        return {
            node: infer_dict(graph, parent=node)
            for node in graph.nodes
            if not any(True for _ in graph.predecessors(node))
        }
    # Ask the graph for the children directly instead of scanning every
    # edge of the graph once per node.
    return {
        child: infer_dict(graph, parent=child)
        for child in graph.successors(parent)
    }