Source code for Compocyte.core.base.hierarchy_base

"""Hierarchy graph management utilities.

This module provides helpers and a base class for constructing and managing
the hierarchical graph of cell types used by the hierarchical classifier.
"""

import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
from Compocyte.core.tools import flatten_dict, dict_depth, hierarchy_names_unique, \
    make_graph_from_edges, set_node_to_depth, delete_dict_entries
from Compocyte.core.models.log_reg import LogisticRegression
from copy import deepcopy

class HierarchyBase():
    """Build and maintain the directed graph describing the cell type hierarchy.

    The hierarchy is supplied as a nested dictionary (``dict_of_cell_relations``)
    and mirrored in a ``networkx.DiGraph`` stored in ``self.graph``. Node
    attributes of that graph (e.g. ``local_classifier``, ``label_encoding``)
    carry per-node state.

    This class does not define ``__init__``. Its methods rely on the following
    being provided by the class it is combined with:

    * attributes ``root_node``, ``dict_of_cell_relations`` and ``obs_names``
      (read by ``set_cell_relations`` before it assigns them),
    * methods ``import_classifiers`` and ``ensure_normlog`` (called here but
      not defined in this class).
    """

    def set_cell_relations(self, root_node, dict_of_cell_relations, obs_names, temp_path):
        """Define the hierarchy for the first time.

        Once set, cell relations can not be replaced by calling this method
        again; use ``update_hierarchy`` instead.

        Parameters
        ----------
        root_node
            Label of the root node of the hierarchy.
        dict_of_cell_relations
            Nested dictionary describing the hierarchy. May contain
            ``'classifier'`` entries, which are stripped from the stored
            hierarchy and handed to ``import_classifiers``.
        obs_names
            Annotation keys, one per hierarchy level.
        temp_path
            Forwarded to ``import_classifiers``.

        Raises
        ------
        Exception
            If cell relations were already defined, if the hierarchy depth
            does not match ``obs_names`` or if node names are not unique.
        """
        already_initialized = (
            self.root_node is not None
            and self.dict_of_cell_relations is not None
            and self.obs_names is not None
        )
        if already_initialized:
            raise Exception('To redefine cell relations after initialization, call update_hierarchy.')

        # Keep an untouched copy: delete_dict_entries returns the hierarchy
        # without 'classifier' entries, which are needed again for the import.
        dict_of_cell_relations_with_classifiers = deepcopy(dict_of_cell_relations)
        dict_of_cell_relations, contains_classifier = delete_dict_entries(
            dict_of_cell_relations, 'classifier')

        self.root_node = root_node
        self.ensure_depth_match(dict_of_cell_relations, obs_names)
        self.ensure_unique_nodes(dict_of_cell_relations)

        self.dict_of_cell_relations = dict_of_cell_relations
        self.obs_names = obs_names
        self.all_nodes = flatten_dict(self.dict_of_cell_relations)
        self.node_to_depth = set_node_to_depth(self.dict_of_cell_relations)
        self.make_classifier_graph()

        if contains_classifier:
            self.import_classifiers(
                dict_of_cell_relations_with_classifiers,
                temp_path=temp_path,
                parent_key=root_node)

        self.ensure_normlog()

    def ensure_depth_match(self, dict_of_cell_relations, obs_names):
        """Check that ``obs_names`` holds one annotation key per hierarchy level.

        Raises
        ------
        Exception
            If the depth of ``dict_of_cell_relations`` differs from the
            number of entries in ``obs_names``.
        """
        if dict_depth(dict_of_cell_relations) != len(obs_names):
            raise Exception('obs_names must contain an annotation key for every level of the '\
                'hierarchy supplied in dict_of_cell_relations.')

    def ensure_unique_nodes(self, dict_of_cell_relations):
        """Check that node names are unique across all levels of the hierarchy.

        Unique names are required to identify graph nodes with networkx.

        Raises
        ------
        Exception
            If ``hierarchy_names_unique`` reports duplicate names.
        """
        if not hierarchy_names_unique(dict_of_cell_relations):
            raise Exception('Names given in the hierarchy must be unique.')

    def make_classifier_graph(self):
        """Build ``self.graph`` as a directed graph from ``self.dict_of_cell_relations``."""
        self.graph = nx.DiGraph()
        make_graph_from_edges(self.dict_of_cell_relations, self.graph)

    def plot_hierarchy(self):
        """Plot the hierarchy graph with node labels using a graphviz ``twopi`` layout."""
        fig, ax = plt.subplots(figsize=(10, 10))
        pos = nx.drawing.nx_agraph.graphviz_layout(self.graph, prog='twopi')
        nx.draw(self.graph, pos, with_labels=True, arrows=True, ax=ax)

    def get_children_obs_key(self, parent_node):
        """Return the obs key holding the labels one level below ``parent_node``.

        E.g. for T cells this returns the obs key under which alpha beta
        T cell labels (and their siblings) are stored.
        """
        depth_parent = self.node_to_depth[parent_node]
        return self.obs_names[depth_parent + 1]

    def get_parent_obs_key(self, parent_node):
        """Return the obs key holding the labels at the level of ``parent_node``.

        E.g. for T cells this returns the obs key in which true T cells are
        labelled as such.
        """
        depth_parent = self.node_to_depth[parent_node]
        return self.obs_names[depth_parent]

    def get_child_nodes(self, node):
        """Return a view of the direct successors of ``node`` in ``self.graph``."""
        return self.graph.adj[node].keys()

    def get_leaf_nodes(self):
        """Return all nodes with no outgoing edge and exactly one incoming edge.

        A root node without children has no incoming edge and is therefore
        not reported as a leaf.
        """
        return [
            x for x in self.graph.nodes()
            if self.graph.out_degree(x) == 0
            and self.graph.in_degree(x) == 1
        ]

    def get_parent_node(self, node, graph=None):
        """Return the parent (predecessor) of ``node``.

        Parameters
        ----------
        node
            Label of the node whose parent is requested.
        graph
            Graph to search; defaults to ``self.graph``.

        Returns
        -------
        The label of the parent node exactly as stored in the graph. In a
        hierarchy every node has at most one incoming edge; should there be
        several, the first predecessor is returned.

        Raises
        ------
        IndexError
            If ``node`` is not part of the graph or has no parent
            (e.g. the root node).
        """
        if graph is None:
            graph = self.graph

        # Look the parent up directly instead of scanning an array of all
        # edges: this is cheaper and keeps the node label's original type.
        try:
            parents = graph.pred[node]
        except KeyError:
            raise IndexError(f'Node {node!r} is not part of the graph.') from None

        if not parents:
            raise IndexError(f'Node {node!r} has no parent node.')

        return next(iter(parents))

    def update_hierarchy(self, dict_of_cell_relations, temp_path=None, root_node=None, overwrite=False):
        """Replace the hierarchy while preserving existing node state.

        Attributes of nodes that exist in both the old and the new hierarchy
        (such as local classifiers) are deep-copied onto the new graph. For
        nodes carrying a classifier, the output layer is reset only if the
        node itself or one of its children is new or was assigned to a
        different parent, so that as much training progress as possible is
        conserved.

        Parameters
        ----------
        dict_of_cell_relations
            New nested hierarchy dictionary; may contain ``'classifier'``
            entries, which are forwarded to ``import_classifiers``.
        temp_path
            Forwarded to ``import_classifiers``.
        root_node
            New root node label; the current one is kept if ``None``.
        overwrite
            Forwarded to ``import_classifiers``.

        Raises
        ------
        Exception
            If the new hierarchy does not match ``self.obs_names`` in depth
            or contains duplicate node names. In that case the stored
            hierarchy and root node are left untouched.
        """
        dict_of_cell_relations_with_classifiers = deepcopy(dict_of_cell_relations)
        dict_of_cell_relations, contains_classifier = delete_dict_entries(
            dict_of_cell_relations, 'classifier')

        if dict_of_cell_relations == self.dict_of_cell_relations:
            # Nothing to rebuild, but a newly supplied root is still recorded.
            if root_node is not None:
                self.root_node = root_node
            return

        # Validate before mutating any state so a rejected hierarchy does not
        # leave a new root node paired with the old graph.
        self.ensure_depth_match(dict_of_cell_relations, self.obs_names)
        self.ensure_unique_nodes(dict_of_cell_relations)
        if root_node is not None:
            self.root_node = root_node

        all_nodes_pre = flatten_dict(self.dict_of_cell_relations)
        self.dict_of_cell_relations = dict_of_cell_relations
        all_nodes_post = flatten_dict(self.dict_of_cell_relations)
        self.all_nodes = all_nodes_post
        self.node_to_depth = set_node_to_depth(self.dict_of_cell_relations)

        new_graph = nx.DiGraph()
        make_graph_from_edges(self.dict_of_cell_relations, new_graph)

        # Node labels are hashable (networkx requires it), so sets give
        # constant-time membership tests in the loops below.
        previous_nodes = set(all_nodes_pre)
        new_nodes = {n for n in all_nodes_post if n not in previous_nodes}
        moved_nodes = set()
        classifier_nodes = []

        for node in all_nodes_post:
            if node in new_nodes:
                continue

            # Check if node was moved within the hierarchy, i.e. assigned
            # to a different parent node.
            # Does not change the strategy of assigning the previous node
            # attributes but may end up a fact of interest.
            if not node == self.root_node:
                parent_post = self.get_parent_node(node, graph=new_graph)
                parent_pre = self.get_parent_node(node)
                if parent_pre != parent_post:
                    moved_nodes.add(node)

            # Transfer properties, such as local classifier, from old graph
            # to new graph.
            old_attributes = self.graph.nodes[node]
            for item in old_attributes:
                new_graph.nodes[node][item] = deepcopy(old_attributes[item])

            has_classifier = "local_classifier" in old_attributes
            if has_classifier:
                # Remember nodes that contain a classifier.
                classifier_nodes.append(node)

            status = "transferred" if has_classifier else "not transferred"
            print(f'Transfered to {node}, local classifier {status}')

        self.graph = new_graph

        for node in classifier_nodes:
            print(f'Ensuring correct output architecture for {node}.')
            child_nodes = self.get_child_nodes(node)

            # Previously reset all classifier nodes.
            # Bad idea because you want to conserve as much of the training
            # progress as possible, resetting as little as possible, as much
            # as necessary.
            affected = any(
                n in new_nodes or n in moved_nodes
                for n in [node, *child_nodes]
            )
            if not affected:
                continue

            if isinstance(self.graph.nodes[node]['local_classifier'], LogisticRegression):
                print('Cannot adjust output structure of logistic regression classifier.')
                continue

            # Reset label encoding, unproblematic because the final layer is
            # reinitialized anyway.
            self.graph.nodes[node]['label_encoding'] = {}  # TODO: leads to problems?
            self.graph.nodes[node]['local_classifier'].reset_output(len(child_nodes))

        if contains_classifier:
            self.import_classifiers(
                dict_of_cell_relations_with_classifiers,
                temp_path=temp_path,
                overwrite=overwrite)

        self.ensure_normlog()