import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from Compocyte.core.tools import flatten_dict, dict_depth, hierarchy_names_unique, \
make_graph_from_edges, set_node_to_depth, delete_dict_entries
from Compocyte.core.models.log_reg import LogisticRegression
from copy import deepcopy
[docs]
class HierarchyBase():
    """Bookkeeping for a hierarchy of cell labels.

    Stores a nested dictionary of cell relations, the obs keys belonging to the
    hierarchy levels and a networkx ``DiGraph`` built from that dictionary, and
    offers helpers to validate, query, plot and update the hierarchy.

    The methods read and write ``root_node``, ``dict_of_cell_relations``,
    ``obs_names``, ``all_nodes``, ``node_to_depth`` and ``graph`` on the
    instance. They also call ``import_classifiers`` and ``ensure_normlog``,
    which are not defined in this class and are presumably provided by the
    class this base is combined with (TODO confirm).
    """

    def set_cell_relations(self, root_node, dict_of_cell_relations, obs_names, temp_paths):
        """Initialise the hierarchy from a nested dictionary of cell relations.

        Once set, cell relations cannot be replaced by calling this method
        again; use ``update_hierarchy`` for that.

        Parameters
        ----------
        root_node
            Name of the root node of the hierarchy.
        dict_of_cell_relations
            Nested dictionary describing the hierarchy. 'classifier' entries
            are stripped before the hierarchy is stored; if any were present,
            the unstripped copy is handed to ``import_classifiers``.
        obs_names
            Annotation keys; their number must equal the depth of the
            hierarchy.
        temp_paths
            Forwarded unchanged to ``import_classifiers``.

        Raises
        ------
        RuntimeError
            If root node, cell relations and obs names have all been set
            already.
        ValueError
            If the hierarchy fails the depth or uniqueness checks.
        """
        already_initialised = (
            self.root_node is not None
            and self.dict_of_cell_relations is not None
            and self.obs_names is not None
        )
        if already_initialised:
            raise RuntimeError('To redefine cell relations after initialization, call update_hierarchy.')

        # Keep an untouched copy: the 'classifier' entries are stripped below
        # but are still needed when importing the classifiers.
        dict_of_cell_relations_with_classifiers = deepcopy(dict_of_cell_relations)
        dict_of_cell_relations, contains_classifier = delete_dict_entries(dict_of_cell_relations, 'classifier')

        self.root_node = root_node
        # Validate before storing anything derived from the dictionary.
        self.ensure_depth_match(dict_of_cell_relations, obs_names)
        self.ensure_unique_nodes(dict_of_cell_relations)

        self.dict_of_cell_relations = dict_of_cell_relations
        self.obs_names = obs_names
        self.all_nodes = flatten_dict(self.dict_of_cell_relations)
        self.node_to_depth = set_node_to_depth(self.dict_of_cell_relations)
        self.make_classifier_graph()

        if contains_classifier:
            # NOTE(review): update_hierarchy passes `temp_path=` (singular) to
            # import_classifiers while this call passes `temp_paths=`; confirm
            # which keyword import_classifiers actually accepts.
            self.import_classifiers(
                dict_of_cell_relations_with_classifiers,
                temp_paths=temp_paths,
                parent_key=root_node,
            )

        self.ensure_normlog()

    def ensure_depth_match(self, dict_of_cell_relations, obs_names):
        """Check that ``obs_names`` holds one annotation key per hierarchy level.

        Raises
        ------
        ValueError
            If the depth of ``dict_of_cell_relations`` differs from the number
            of entries in ``obs_names``.
        """
        if dict_depth(dict_of_cell_relations) != len(obs_names):
            raise ValueError('obs_names must contain an annotation key for every level of the '
                             'hierarchy supplied in dict_of_cell_relations.')

    def ensure_unique_nodes(self, dict_of_cell_relations):
        """Check that names in the hierarchy are unique across all levels.

        Unique names are a requirement for uniquely identifying graph nodes
        with networkx.

        Raises
        ------
        ValueError
            If ``hierarchy_names_unique`` reports duplicate names.
        """
        if not hierarchy_names_unique(dict_of_cell_relations):
            raise ValueError('Names given in the hierarchy must be unique.')

    def make_classifier_graph(self):
        """Build ``self.graph`` as a directed graph of the stored cell relations."""
        self.graph = nx.DiGraph()
        make_graph_from_edges(self.dict_of_cell_relations, self.graph)

    def plot_hierarchy(self):
        """Draw the hierarchy graph with node labels and arrows.

        The layout is computed by graphviz' ``twopi`` program through
        ``nx.drawing.nx_agraph.graphviz_layout``. Nothing is returned; the
        graph is drawn onto a newly created 10x10 matplotlib figure.
        """
        fig, ax = plt.subplots(figsize=(10, 10))
        pos = nx.drawing.nx_agraph.graphviz_layout(self.graph, prog='twopi')
        nx.draw(self.graph, pos, with_labels=True, arrows=True, ax=ax)

    def get_children_obs_key(self, parent_node):
        """Return the obs key holding the labels one level below ``parent_node``.

        E. g. for T cells this returns the obs key under which alpha beta
        T cell labels are saved, and so on.

        Raises ``IndexError`` if ``obs_names`` has no entry for the level
        below ``parent_node``.
        """
        depth_parent = self.node_to_depth[parent_node]
        return self.obs_names[depth_parent + 1]

    def get_parent_obs_key(self, parent_node):
        """Return the obs key holding the labels at the level of ``parent_node``.

        E. g. for T cells this returns the obs key in which true T cells are
        labelled as such.
        """
        depth_parent = self.node_to_depth[parent_node]
        return self.obs_names[depth_parent]

    def get_child_nodes(self, node):
        """Return the keys view of ``self.graph.adj[node]``, i. e. the nodes
        that ``node`` has an edge to."""
        return self.graph.adj[node].keys()

    def get_leaf_nodes(self):
        """Return all nodes of ``self.graph`` with no outgoing edge and exactly
        one incoming edge."""
        return [
            candidate for candidate in self.graph.nodes()
            if self.graph.out_degree(candidate) == 0
            and self.graph.in_degree(candidate) == 1
        ]

    def get_parent_node(self, node, graph=None):
        """Return the node that has an edge leading to ``node``.

        Parameters
        ----------
        node
            Node whose parent is looked up.
        graph
            Object exposing an iterable ``edges`` whose items start with
            ``(source, target)``. Defaults to ``self.graph``.

        Returns
        -------
        The source of the first edge whose target equals ``node``, as the
        original node object stored in the graph.

        Raises
        ------
        IndexError
            If no edge leads to ``node``, e. g. for a root node.
        """
        if graph is None:
            graph = self.graph

        # In a tree-shaped directed graph there should only be one edge leading
        # TO any given node, so the first match is the parent. The edges are
        # scanned directly rather than via a numpy array, because the array
        # conversion coerces node labels to a common dtype and thereby breaks
        # the comparison for mixed or non-string labels.
        for edge in graph.edges:
            if edge[1] == node:
                return edge[0]

        raise IndexError(f'No edge leads to node {node!r}, it has no parent node.')

    def update_hierarchy(self, dict_of_cell_relations, temp_path=None, root_node=None, overwrite=False):
        """Replace the stored hierarchy while preserving existing node attributes.

        A new graph is built from ``dict_of_cell_relations``. Attributes of
        nodes that already existed (such as a local classifier) are deep-copied
        onto the new graph. For nodes carrying a local classifier, the output
        structure is reset only if the node itself or one of its children is
        new or was assigned to a different parent.

        Parameters
        ----------
        dict_of_cell_relations
            New nested dictionary describing the hierarchy; may contain
            'classifier' entries.
        temp_path
            Forwarded unchanged to ``import_classifiers``.
        root_node
            If not None, replaces ``self.root_node``. This happens before any
            validation and before the early return described below.
        overwrite
            Forwarded unchanged to ``import_classifiers``.

        Notes
        -----
        If the stripped dictionary equals the stored one, the method returns
        immediately: no classifiers are imported and ``ensure_normlog`` is not
        called.

        Raises
        ------
        ValueError
            If the new hierarchy fails the depth or uniqueness checks.
        """
        # Keep an untouched copy: the 'classifier' entries are stripped below
        # but are still needed when importing the classifiers.
        dict_of_cell_relations_with_classifiers = deepcopy(dict_of_cell_relations)
        dict_of_cell_relations, contains_classifier = delete_dict_entries(dict_of_cell_relations, 'classifier')
        if root_node is not None:
            self.root_node = root_node

        if dict_of_cell_relations == self.dict_of_cell_relations:
            return

        self.ensure_depth_match(dict_of_cell_relations, self.obs_names)
        self.ensure_unique_nodes(dict_of_cell_relations)

        all_nodes_pre = flatten_dict(self.dict_of_cell_relations)
        self.dict_of_cell_relations = dict_of_cell_relations
        all_nodes_post = flatten_dict(self.dict_of_cell_relations)
        self.all_nodes = all_nodes_post
        self.node_to_depth = set_node_to_depth(self.dict_of_cell_relations)

        new_graph = nx.DiGraph()
        make_graph_from_edges(self.dict_of_cell_relations, new_graph)

        new_nodes = [n for n in all_nodes_post if n not in all_nodes_pre]
        moved_nodes = []
        classifier_nodes = []
        for node in all_nodes_post:
            if node in new_nodes:
                # Nothing to carry over for nodes that did not exist before.
                continue

            # Check if node was moved within the hierarchy, i. e. assigned
            # to a different parent node.
            # Does not change the strategy of assigning the previous node
            # attributes but decides below whether a classifier is reset.
            if node != self.root_node:
                parent_post = self.get_parent_node(node, graph=new_graph)
                parent_pre = self.get_parent_node(node)
                if parent_pre != parent_post:
                    moved_nodes.append(node)

            # Transfer properties, such as local classifier, from old graph
            # to new graph.
            old_attributes = self.graph.nodes[node]
            for item in old_attributes:
                new_graph.nodes[node][item] = deepcopy(old_attributes[item])

            has_classifier = "local_classifier" in old_attributes
            if has_classifier:
                # Remember nodes that contain a classifier.
                classifier_nodes.append(node)

            print(f'Transfered to {node}, local classifier {"transferred" if has_classifier else "not transferred"}')

        self.graph = new_graph
        for node in classifier_nodes:
            print(f'Ensuring correct output architecture for {node}.')
            child_nodes = self.get_child_nodes(node)
            # Only reset classifiers whose node or children changed, to
            # conserve as much of the training progress as possible:
            # resetting as little as possible, as much as necessary.
            affected = any(
                n in new_nodes or n in moved_nodes
                for n in [node] + list(child_nodes)
            )
            if not affected:
                continue

            if type(self.graph.nodes[node]['local_classifier']) is LogisticRegression:
                print('Cannot adjust output structure of logistic regression classifier.')
                continue

            # Reset label encoding, unproblematic because the final layer is
            # reinitialized anyway.
            self.graph.nodes[node]['label_encoding'] = {}  # TODO: leads to problems?
            self.graph.nodes[node]['local_classifier'].reset_output(len(child_nodes))

        if contains_classifier:
            # NOTE(review): set_cell_relations passes `temp_paths=` (plural) to
            # import_classifiers while this call passes `temp_path=`; confirm
            # which keyword import_classifiers actually accepts.
            self.import_classifiers(
                dict_of_cell_relations_with_classifiers,
                temp_path=temp_path,
                overwrite=overwrite,
            )

        self.ensure_normlog()