import anndata
import numpy as np
from scipy import sparse
import networkx as nx
from copy import deepcopy
import pandas as pd
[docs]
def set_node_to_depth(dictionary, depth=0, node_to_depth=None):
    """Map every node of a nested hierarchy dict to its depth.

    Parameters
    ----------
    dictionary : dict
        Nested dict ``{node: {child: {...}}}``; an empty dict marks a leaf.
    depth : int, optional
        Depth assigned to the keys of ``dictionary``. Default 0.
    node_to_depth : dict, optional
        Accumulator that is filled in place and returned. When None, a fresh
        dict is created so that results never leak between calls.

    Returns
    -------
    dict
        Mapping ``node -> depth``.
    """
    if node_to_depth is None:
        node_to_depth = {}
    for node, children in dictionary.items():
        # Recurse first so children are inserted before their parent
        # (same insertion order as before); share one accumulator throughout.
        set_node_to_depth(children, depth=depth + 1, node_to_depth=node_to_depth)
        node_to_depth[node] = depth
    return node_to_depth
[docs]
def is_counts(matrix, n_rows_to_try=100):
    """Determines whether or not a matrix (such as adata.X, adata.raw.X or an adata layer) contains
    count data by manually checking a subsample of the supplied matrix.

    Parameters
    ----------
    matrix : scipy sparse matrix or array-like
        Matrix to test. Sparse and dense inputs are both supported.
    n_rows_to_try : int, optional
        Number of leading rows to inspect. Default 100.

    Returns
    -------
    bool
        True if the inspected rows contain no negative values and no
        non-integer values (NaN counts as non-integer).
    """
    test_data = matrix[:n_rows_to_try]
    # Densify only sparse inputs; dense inputs (ndarray, np.matrix, h5py
    # slices) are converted with np.asarray instead of the sparse API.
    if sparse.issparse(test_data):
        test_data = test_data.toarray()
    else:
        test_data = np.asarray(test_data)
    contains_negative_values = np.any(test_data < 0)
    contains_non_whole_numbers = np.any(test_data % 1 != 0)
    return not contains_negative_values and not contains_non_whole_numbers
[docs]
def dict_depth(dictionary, running_count=0):
    """Return the number of node levels in a nested hierarchy dict.

    ``'classifier'`` entries are not nodes and are ignored. A dict with no
    (non-classifier) children contributes no level.

    Parameters
    ----------
    dictionary : dict
        Nested dict ``{node: {child: {...}}}``.
    running_count : int, optional
        Offset added to the returned depth. Default 0.

    Returns
    -------
    int
        ``running_count`` plus the number of levels below this dict.

    Raises
    ------
    TypeError
        If ``dictionary`` (or any non-classifier value inside it) is not a dict.
    """
    if not isinstance(dictionary, dict):
        raise TypeError(f'Expected a dict, got {type(dictionary).__name__}.')
    child_depths = [
        dict_depth(child, running_count)
        for key, child in dictionary.items()
        if key != 'classifier'
    ]
    # Empty dict, or a dict holding only a classifier: nothing below here.
    # (max() of an empty list would raise ValueError.)
    if not child_depths:
        return running_count
    return max(child_depths) + 1
[docs]
def infer_levels(hierarchy, labels, root_node, adata=None, prefix_obs='Level_'):
    """Split each label's root-to-label path into one column per hierarchy level.

    Parameters
    ----------
    hierarchy : dict or nx.DiGraph
        Label tree. A nested dict (``{node: {child: {...}}}``) is converted into
        a DiGraph with ``make_graph_from_edges``; a DiGraph is converted into a
        nested dict with ``infer_dict``. Any other type raises ``TypeError``.
    labels : str, list, or array-like
        One of:
        - the name of an existing column of ``adata.obs``;
        - a list of labels;
        - any object with a ``tolist()`` method.
        Lists and array-likes are written into a new column named ``'label'``.
    root_node : str or int
        Node from which shortest paths to every label are computed.
    adata : anndata.AnnData, optional
        If an AnnData object is given, its ``obs`` DataFrame is used directly
        (and therefore modified in place). Otherwise a new empty DataFrame is
        used. Default is None.
    prefix_obs : str, optional
        Prefix of the level column names. Default ``'Level_'``, which gives
        ``'Level_0'``, ``'Level_1'``, ...

    Returns
    -------
    pd.DataFrame
        The frame holding the label column plus one column per level, filled
        with the node found at that level of each label's path.
    list
        Names of the level columns; their count is ``dict_depth(hierarchy)``.

    Raises
    ------
    TypeError
        If ``hierarchy`` or ``labels`` has an unsupported type.

    Notes
    -----
    - Paths shorter than the hierarchy depth are padded with ``''``.
    - A label that is not reachable from ``root_node`` makes
      ``nx.shortest_path`` raise.
    """
    # Make both representations available: a graph for path finding and a
    # nested dict for the depth computation.
    if isinstance(hierarchy, dict):
        graph = nx.DiGraph()
        make_graph_from_edges(hierarchy, graph)
    elif isinstance(hierarchy, nx.DiGraph):
        graph, hierarchy = hierarchy, infer_dict(hierarchy)
    else:
        raise TypeError('Hierarchy must be provided as a dict or a NetworkX DiGraph.')

    # A real AnnData's obs is used (and mutated) directly; otherwise start empty.
    obs = adata.obs if isinstance(adata, anndata.AnnData) else pd.DataFrame()

    if isinstance(labels, str) and labels in obs.columns:
        labels_key = labels
    else:
        labels_key = 'label'
        if hasattr(labels, 'tolist'):
            labels = labels.tolist()
        elif not isinstance(labels, list):
            raise TypeError('Labels must be provided as a list.')
        obs[labels_key] = labels

    depth = dict_depth(hierarchy)
    levels = [f'{prefix_obs}{i}' for i in range(depth)]
    for column in levels:
        obs[column] = ''

    for label in obs[labels_key].unique():
        route = nx.shortest_path(graph, root_node, label)
        padding = [''] * (depth - len(route))
        obs.loc[obs[labels_key] == label, levels] = route + padding
    return obs, levels
[docs]
def flatten_dict(dictionary, running_list_of_values=None):
    """Return all node names of a nested hierarchy dict in post-order.

    Children are listed before their parent; siblings keep dict order.
    ``'classifier'`` entries are skipped.

    Parameters
    ----------
    dictionary : dict
        Nested dict ``{node: {child: {...}}}``.
    running_list_of_values : list, optional
        Values placed before the flattened nodes. Not modified.

    Returns
    -------
    list
        A new list: ``running_list_of_values`` followed by the flattened nodes.

    Raises
    ------
    TypeError
        If ``dictionary`` (or any non-classifier value inside it) is not a dict.
    """
    if not isinstance(dictionary, dict):
        raise TypeError(f'Expected a dict, got {type(dictionary).__name__}.')
    # Always build a fresh list so no shared default can be mutated by callers.
    flattened = [] if running_list_of_values is None else list(running_list_of_values)
    for key, children in dictionary.items():
        if key == 'classifier':
            continue
        flattened.extend(flatten_dict(children))
        flattened.append(key)
    return flattened
[docs]
def hierarchy_names_unique(hierarchy_dict):
    """Return True if no node name occurs more than once in the hierarchy.

    Node names are collected with ``flatten_dict`` (so ``'classifier'``
    entries are not counted) and compared with their de-duplicated set.
    """
    names = flatten_dict(hierarchy_dict)
    distinct_names = set(names)
    return len(distinct_names) == len(names)
[docs]
def make_graph_from_edges(d, g, parent_key=''):
    """Add the nodes and parent->child edges of a nested hierarchy dict to ``g``.

    Parameters
    ----------
    d : dict
        Nested dict ``{node: {child: {...}}}``; an empty dict marks a leaf.
        ``'classifier'`` entries are skipped, as in ``dict_depth`` and
        ``flatten_dict``.
    g : nx.DiGraph
        Graph that is modified in place.
    parent_key : hashable, optional
        Parent of the keys of ``d``. The default ``''`` means the keys of ``d``
        are roots and receive no incoming edge.
    """
    for key, children in d.items():
        if key == 'classifier':
            continue
        # Add the node explicitly so a childless root still ends up in the graph
        # (add_edge alone would never create it).
        g.add_node(key)
        if parent_key != '':
            g.add_edge(parent_key, key)
        if len(children) != 0:
            make_graph_from_edges(children, g, parent_key=key)
[docs]
def get_last_annotation(obs_names, adata, barcodes=None, true_only=False):
    """Collapse per-level annotation columns into the last available annotation per cell.

    For each key in ``obs_names`` the column ``<key>`` (true annotation) and,
    unless ``true_only``, the column ``<key>_pred`` (predicted annotation) are
    read from ``adata.obs``. Keys are processed in the given order. Each later
    key overwrites the running value wherever it holds a non-missing entry, so
    every cell ends up with its last non-missing annotation.

    Parameters
    ----------
    obs_names : list of str
        Annotation column names in ``adata.obs``, in processing order
        (presumably coarse to fine levels; the last non-missing one wins).
    adata : anndata.AnnData
        Object whose ``obs`` holds the annotation columns.
    barcodes : index-like, optional
        Rows (cells) to use. Defaults to ``adata.obs_names``.
    true_only : bool, optional
        If True, only the true-annotation columns are processed and no
        ``*_pred`` column is read. Default False.

    Returns
    -------
    pd.DataFrame
        Indexed by barcode, with column ``'true_last'`` and (unless
        ``true_only``) ``'pred_last'``, cast to ``str``. Only cells whose
        first-key annotation (true and, unless ``true_only``, predicted) is
        non-missing are kept.

    Notes
    -----
    The values ``NaN``, ``""`` and ``"nan"`` are treated as missing.
    ``obs_names`` must be non-empty; otherwise ``obs_df`` is never assigned
    and the final ``return`` raises ``UnboundLocalError``.
    """
    if barcodes is None:
        barcodes = adata.obs_names
    # Predicted annotations are read from columns named '<true_key>_pred'.
    obs_names_pred = [f'{x}_pred' for x in obs_names]
    for i, (true_key, pred_key) in enumerate(zip(obs_names, obs_names_pred)):
        if i == 0:
            # First key: seed the result frame. Cells with a missing annotation
            # here are dropped and never re-added by later keys.
            if true_only:
                obs_df = adata.obs.loc[barcodes, [true_key]]
            else:
                obs_df = adata.obs.loc[barcodes, [true_key, pred_key]]
            obs_df = obs_df[~obs_df[true_key].isin([np.nan, "", "nan"])]
            obs_df.rename(columns={true_key: 'true_last'}, inplace=True)
            if not true_only:
                obs_df = obs_df[~obs_df[pred_key].isin([np.nan, "", "nan"])]
                obs_df.rename(columns={pred_key: 'pred_last'}, inplace=True)
            obs_df = obs_df.astype(str)
        else:
            # Later keys: overwrite the running value wherever this key has a
            # non-missing annotation.
            if true_only:
                obs_df_level = adata.obs.loc[barcodes, [true_key]]
                obs_df_level.rename(columns={true_key: 'true_last'}, inplace=True)
            else:
                obs_df_level = adata.obs.loc[barcodes, [true_key, pred_key]]
                obs_df_level.rename(columns={true_key: 'true_last', pred_key: 'pred_last'}, inplace=True)
            # NaN values become the string "nan" after astype(str), which the
            # filters below also treat as missing.
            obs_df_level = obs_df_level.astype(str)
            obs_df_level_true = obs_df_level[~obs_df_level["true_last"].isin([np.nan, "", "nan"])]
            # Only update cells that survived the first-key filter.
            level_barcodes_true = [x for x in obs_df_level_true.index if x in obs_df.index]
            obs_df.loc[level_barcodes_true, 'true_last'] = obs_df_level_true.loc[level_barcodes_true, 'true_last']
            if not true_only:
                obs_df_level_pred = obs_df_level[~obs_df_level["pred_last"].isin([np.nan, "", "nan"])]
                level_barcodes_pred = [x for x in obs_df_level_pred.index if x in obs_df.index]
                obs_df.loc[level_barcodes_pred, 'pred_last'] = obs_df_level_pred.loc[level_barcodes_pred, 'pred_last']
    return obs_df
[docs]
def get_leaf_nodes(hierarchy):
    """Return the leaves (nodes whose child dict is empty) of a nested hierarchy.

    Leaves are reported in depth-first order, following dict order at every
    level.
    """
    leaves = []
    # Explicit DFS stack of (node, subtree) pairs; pushing in reverse keeps
    # the original left-to-right visiting order when popping.
    pending = list(reversed(list(hierarchy.items())))
    while pending:
        name, subtree = pending.pop()
        if len(subtree.keys()) == 0:
            leaves.append(name)
        else:
            pending.extend(reversed(list(subtree.items())))
    return leaves
[docs]
def delete_dict_entries(dictionary, del_key='classifier', first_run=True, deleted_key=False):
    """Recursively remove every entry with key ``del_key`` from a nested dict.

    Parameters
    ----------
    dictionary : dict
        Nested dict to clean.
    del_key : hashable, optional
        Key to remove at every level. Default ``'classifier'``.
    first_run : bool, optional
        If True (the default), work on a deep copy so the input is left
        untouched. Recursive calls pass False and edit the copy in place.
    deleted_key : bool, optional
        Running flag threaded through the recursion.

    Returns
    -------
    tuple
        ``(cleaned_dict, deleted_key)``, where ``deleted_key`` is True if at
        least one entry was removed anywhere (or was already True on entry).
    """
    target = deepcopy(dictionary) if first_run else dictionary
    for name in list(target.keys()):
        if name != del_key:
            target[name], deleted_key = delete_dict_entries(
                target[name],
                del_key=del_key,
                first_run=False,
                deleted_key=deleted_key,
            )
            continue
        del target[name]
        deleted_key = True
    return target, deleted_key
[docs]
def flatten_labels(pred_h_labels, graph, root_node, verbose=False):
    """Reduce per-cell hierarchical label paths to one label per cell.

    Column 0 of every row is set to ``root_node``. The label kept for each cell
    is the one at position ``n_valid - 1``, where ``n_valid`` is the number of
    entries of that row that are nodes of ``graph``. This assumes the valid
    labels form a prefix of the row (starting at index 0).

    Parameters
    ----------
    pred_h_labels : array-like, 2-D
        One row per cell, one column per hierarchy level. Not modified.
    graph : nx.DiGraph
        Hierarchy; only ``graph.nodes`` is used, to decide which labels are valid.
    root_node : hashable
        Label written into column 0 of every row.
    verbose : bool, optional
        If True, print the distinct labels that are not nodes of ``graph``.

    Returns
    -------
    np.ndarray
        1-D array with one label per row (shape ``(n_rows,)``, also for a
        single row).
    """
    # Work on a copy so the caller's array is not modified.
    pred_h_labels = np.array(pred_h_labels, copy=True)
    # A fixed-width unicode array may be too narrow for root_node; widen it so
    # the assignment below does not silently truncate the root label.
    if pred_h_labels.dtype.kind == 'U' and isinstance(root_node, str):
        pred_h_labels = pred_h_labels.astype(
            np.promote_types(pred_h_labels.dtype, np.array(root_node).dtype)
        )
    pred_h_labels[:, 0] = root_node  # Some predictions did not have the root label as their first value
    # Calculates extent of intersections between predicted labels and valid labels as per the provided graph by cell
    in_graph = np.isin(pred_h_labels, graph.nodes)
    n_valid_labels = np.sum(in_graph, axis=1)
    if verbose:
        invalid_labels = np.unique(pred_h_labels[~in_graph]).tolist()
        print(f'The hierarchical annotations contained {len(invalid_labels)} invalid labels:\n{invalid_labels}.\nThis is only problematic if these are labels you intended to be counted as valid.')
    # The last valid label per cell is at n_valid_labels - 1 assuming valid labels start at index 0
    idx_last_valid_label = np.maximum(n_valid_labels - 1, 0).astype(int)
    pred_labels_flat = np.take_along_axis(
        pred_h_labels,
        idx_last_valid_label[:, np.newaxis],
        axis=1,
    )
    # Drop only the helper column axis (np.squeeze would also drop a length-1 row axis).
    return pred_labels_flat[:, 0]
[docs]
class Hierarchical_Metric():
    """Hierarchical precision (hP), recall (hR) and F-beta (hF) scores.

    Every label is expanded to its *augmented set*: the nodes on the shortest
    path from ``root_node`` to the label in ``hierarchy_structure``, with both
    the root and the label included. Labels that are not nodes of the graph
    are mapped to ``root_node``. hP and hR are ratios of augmented-set overlaps
    summed over the samples. A prediction that is more specific than the true
    label is not penalised in hP (see ``calculate_intersects``).

    Augmented sets and per-(true, predicted) overlaps are cached on the
    instance in ``augmented_lookups`` and ``intersect_lookups``.
    """

    def __init__(self, true_labels, predicted_labels, hierarchy_structure, root_node='Blood'):
        """Store the labels and the hierarchy.

        Parameters
        ----------
        true_labels, predicted_labels : array-like
            Per-sample labels, converted with ``np.array``. They are paired
            with ``zip``, so surplus entries of the longer one are ignored.
        hierarchy_structure : nx.DiGraph
            NetworkX graph of the hierarchical classifier.
        root_node : hashable, optional
            Root of the hierarchy. Default ``'Blood'``.
        """
        self.true_labels = np.array(true_labels)
        self.predicted_labels = np.array(predicted_labels)
        self.hierarchy_structure = hierarchy_structure
        self.root_node = root_node
        self.augmented_lookups = {}
        self.intersect_lookups = {}

    def augmented_set_of_node_n(self, node):
        """Return the path from the root to ``node``, root and node included.

        Assumes a tree hierarchy. A ``node`` that is not in the graph is
        replaced by ``root_node``, so its augmented set is just the root.
        """
        if node not in self.hierarchy_structure.nodes:
            node = self.root_node
        # Cache per node to avoid a shortest-path search for every sample.
        if node not in self.augmented_lookups:
            ancestors = nx.shortest_path(self.hierarchy_structure, self.root_node, node)
            self.augmented_lookups[node] = np.array(ancestors)
        return self.augmented_lookups[node]

    def calculate_intersects(self, t_label, p_label, t_label_augmented, p_label_augmented):
        """Cache ``(|T & P|, corrected |P|)`` for the pair ``(t_label, p_label)``.

        If the true augmented set is fully contained in the predicted one (the
        prediction is correct but over-specialised), ``|P|`` is cut down to
        ``|T & P|`` so the extra depth does not lower precision.
        """
        cardinality_intersect_t_p = len(np.intersect1d(t_label_augmented, p_label_augmented))
        cardinality_p_label_augmented = len(p_label_augmented)
        if cardinality_intersect_t_p == len(t_label_augmented):
            cardinality_p_label_augmented = cardinality_intersect_t_p
        self.intersect_lookups.setdefault(t_label, {})[p_label] = (
            cardinality_intersect_t_p,
            cardinality_p_label_augmented,
        )

    def _cached_intersect(self, t_label, p_label):
        """Return the cached ``(|T & P|, corrected |P|)`` for a pair, computing it on first use."""
        pair_cache = self.intersect_lookups.get(t_label)
        if pair_cache is None or p_label not in pair_cache:
            self.calculate_intersects(
                t_label,
                p_label,
                self.augmented_set_of_node_n(t_label),
                self.augmented_set_of_node_n(p_label),
            )
        return self.intersect_lookups[t_label][p_label]

    def _select(self, idcs=None):
        """Return the (true, predicted) label arrays, restricted to positions ``idcs`` if given."""
        if idcs is None:
            return self.true_labels, self.predicted_labels
        return self.true_labels[idcs], self.predicted_labels[idcs]

    @staticmethod
    def _f_beta(hP, hR, beta):
        """Combine hierarchical precision and recall into the F-beta score."""
        return (beta**2 + 1) * hP * hR / (beta**2 * hP + hR)

    def hP(self, idcs=None):
        """Hierarchical precision: sum of |T & P| over sum of corrected |P|.

        ``idcs`` optionally restricts the computation to those sample
        positions (used for per-label scores). By default all samples are used.
        """
        true_labels, predicted_labels = self._select(idcs)
        numerator = []
        denominator = []
        for t_label, p_label in zip(true_labels, predicted_labels):
            cardinality_intersect_t_p, cardinality_p_label_augmented = self._cached_intersect(t_label, p_label)
            numerator.append(cardinality_intersect_t_p)
            denominator.append(cardinality_p_label_augmented)
        return np.sum(np.array(numerator)) / np.sum(np.array(denominator))

    def hR(self, idcs=None):
        """Hierarchical recall: sum of |T & P| over sum of |T|.

        ``idcs`` optionally restricts the computation to those sample
        positions (used for per-label scores). By default all samples are used.
        """
        true_labels, predicted_labels = self._select(idcs)
        numerator = []
        denominator = []
        for t_label, p_label in zip(true_labels, predicted_labels):
            cardinality_intersect_t_p, _ = self._cached_intersect(t_label, p_label)
            numerator.append(cardinality_intersect_t_p)
            denominator.append(len(self.augmented_set_of_node_n(t_label)))
        return np.sum(np.array(numerator)) / np.sum(np.array(denominator))

    def hF(self, beta):
        """Micro-averaged hierarchical F-beta score over all samples."""
        return self._f_beta(self.hP(), self.hR(), beta)

    def macro_hF(self, beta):
        """Macro-averaged hF score: the mean of the per-true-label hF scores.

        Each label's score uses only the samples whose true label is that label.
        """
        labels = pd.Series(self.true_labels).value_counts().keys()
        label_Fb = []
        for label in labels:
            true_label_idcs = np.where(self.true_labels == label)[0]
            label_Fb.append(self._f_beta(self.hP(true_label_idcs), self.hR(true_label_idcs), beta))
        return np.sum(np.array(label_Fb)) / len(labels)

    def list_micro_metrics(self, beta):
        """Return per-true-label hF, hR and hP scores, rounded to 2 decimals.

        Each row uses only the samples whose true label is the row label.
        """
        label_metrics = pd.DataFrame(columns=[f'hF{beta}', 'hR', 'hP'])
        for label in np.unique(self.true_labels):
            true_label_idcs = np.where(self.true_labels == label)[0]
            hP = self.hP(true_label_idcs)
            hR = self.hR(true_label_idcs)
            Fb = self._f_beta(hP, hR, beta)
            label_metrics.loc[label] = [np.round(Fb, 2), np.round(hR, 2), np.round(hP, 2)]
        return label_metrics
[docs]
def infer_dict(graph, parent=None):
    """Convert a tree-shaped DiGraph into a nested dict ``{node: {child: {...}}}``.

    Parameters
    ----------
    graph : nx.DiGraph
        Directed graph whose edges point from parent to child. Assumed to be a
        tree; a cycle reachable from the root would recurse without end.
    parent : hashable, optional
        Node whose subtree is returned. If None, the root is inferred as a node
        without predecessors and the result is ``{root: {...}}``. If several
        such nodes exist, the last one in ``graph.nodes`` order is used.

    Returns
    -------
    dict
        Nested dict of children; leaves map to ``{}``. A ``parent`` that is
        not in the graph yields ``{}``.

    Raises
    ------
    ValueError
        If ``parent`` is None and no node without predecessors exists
        (e.g. an empty graph).
    """
    if parent is None:
        roots = [node for node in graph.nodes if len(list(graph.predecessors(node))) == 0]
        if not roots:
            # Without this check the call below would recurse with parent=None forever.
            raise ValueError('Cannot infer the root: the graph has no node without predecessors.')
        root = roots[-1]
        return {root: infer_dict(graph, parent=root)}
    if parent not in graph:
        return {}
    # successors() yields children in adjacency order, the same order in which
    # scanning graph.edges would find them, without the O(V*E) scan.
    return {child: infer_dict(graph, parent=child) for child in graph.successors(parent)}