Source code for Compocyte.core.tuner

"""Hyperparameter tuner utilities.

This module contains helper code for tuning hyperparameters of local
classifiers, storing results and orchestrating grid/random search runs.
"""

import sqlite3
import time
import warnings
from typing import Optional

import numpy as np
import pandas as pd
import scanpy as sc

from Compocyte.core.hierarchical_classifier import HierarchicalClassifier
from Compocyte.core.models.dense_torch import DenseTorch
from Compocyte.core.models.trees import BoostedTrees
from Compocyte.core.models.fit_methods import fit, predict_logits

class Tuner():
    """Run hyperparameter trials per hierarchy node and store them in SQLite.

    Trial results are written to a ``trials`` table (see :meth:`make_db`)
    and read back by :meth:`get_best_trial` / :meth:`train_from_tuner`.
    """

    # Column order of the ``trials`` table, excluding the trailing timestamp.
    _TRIAL_COLUMNS = (
        'node', 'trials', 'fraction_correct', 'fraction_matches',
        'n_features', 'hidden_layers', 'dropout', 'epochs', 'batch_size',
        'starting_lr', 'max_lr', 'momentum', 'beta', 'gamma', 'threshold',
    )

    def __init__(
            self,
            database_path: str,
            adata_path: str,
            hierarchy: dict,
            root_node: str,
            obs_names: list):
        """Open the SQLite database and remember the tuning configuration.

        Args:
            database_path: Path handed to ``sqlite3.connect``.
            adata_path: Path of the ``.h5ad`` file read by :meth:`trial_run`.
            hierarchy: Passed to ``HierarchicalClassifier`` as
                ``dict_of_cell_relations``.
            root_node: Root node passed to ``HierarchicalClassifier``.
            obs_names: Obs column names, indexed by node depth.
        """
        self.con = sqlite3.connect(database_path)
        self.cur = self.con.cursor()
        self.adata_path = adata_path
        self.hierarchy = hierarchy
        self.root_node = root_node
        self.obs_names = obs_names

    @staticmethod
    def _to_sql_value(value):
        """Convert NumPy scalars to plain Python values sqlite3 can bind."""
        item = getattr(value, 'item', None)
        return item() if callable(item) else value

    def _execute_with_retry(self, sql, params=(), attempts=3, commit=False):
        """Execute ``sql``, retrying on ``sqlite3.OperationalError``.

        Returns the cursor on success, or ``None`` if every attempt failed.
        """
        for _ in range(attempts):
            try:
                result = self.cur.execute(sql, params)
                if commit:
                    self.con.commit()
                return result
            except sqlite3.OperationalError:
                # Short pause before the next attempt.
                time.sleep(0.01)

        return None

    def train_from_tuner(
            self,
            save_path: str,
            adata: sc.AnnData,
            parallelize=True,
            max_cells: Optional[int] = None,
            stratify_by: Optional[str] = None,
            processes: Optional[int] = None) -> HierarchicalClassifier:
        """Build a classifier, attach the best stored settings, and train it.

        For every node with at least one child and at least 5 cells in its
        subset, the best trial (see :meth:`get_best_trial`) is stored in
        ``classifier.tuned_kwargs[node]``. Nodes without a stored trial are
        skipped and get no ``tuned_kwargs`` entry.

        Args:
            save_path: First positional argument of ``HierarchicalClassifier``.
            adata: Training data handed to the classifier.
            parallelize: Forwarded to ``train_all_child_nodes``.
            max_cells: With ``stratify_by``, forwarded to
                ``classifier.introduce_limit``; ignored unless both are set.
            stratify_by: See ``max_cells``.
            processes: Forwarded to ``train_all_child_nodes``; also used as
                ``classifier.num_threads`` when ``parallelize`` is false.

        Returns:
            The trained ``HierarchicalClassifier``.
        """
        classifier = HierarchicalClassifier(
            save_path,
            root_node=self.root_node,
            adata=adata,
            dict_of_cell_relations=self.hierarchy,
            obs_names=self.obs_names)

        for node in classifier.graph.nodes:
            n_children = len(list(classifier.graph.successors(node)))
            if n_children < 1:
                continue

            subset = classifier.select_subset(node)
            if len(subset) < 5:
                continue

            tup = self.get_best_trial(node)
            if tup is None:
                # No tuning results exist for this node.
                continue

            kwargs = {
                'n_features': tup[0],
                # NOTE(review): eval() runs whatever text is stored in the
                # database. make_entry writes str(list) here, but if the
                # database can come from an untrusted source this should be
                # replaced by a restricted parser.
                'hidden_layers': eval(tup[1]),
                'dropout': tup[2],
                'epochs': tup[3],
                'batch_size': tup[4],
                'starting_lr': tup[5],
                'max_lr': tup[6],
                'momentum': tup[7],
                'beta': tup[8],
                'gamma': tup[9],
                'threshold': tup[10],
            }
            if not hasattr(classifier, 'tuned_kwargs'):
                classifier.tuned_kwargs = {}

            classifier.tuned_kwargs[node] = kwargs

        if max_cells is not None and stratify_by is not None:
            classifier.introduce_limit(max_cells, stratify_by)

        # Without node-level parallelism `processes` is used as the thread
        # count; with it the thread count is pinned to 1.
        if not parallelize:
            classifier.num_threads = processes
        else:
            classifier.num_threads = 1

        classifier.train_all_child_nodes(
            parallelize=parallelize, processes=processes)

        return classifier

    def trial_run(
            self,
            cv_key: str,
            n_features: int,
            hidden_layers: list,
            dropout: float,
            epochs: int,
            batch_size: int,
            starting_lr: float,
            max_lr: float,
            momentum: float,
            beta: float,
            gamma: float,
            test_factor: int,
            parallelize: bool = True,
            num_threads=None,
            standardize_separately: Optional[str] = None) -> None:
        """Cross-validate one hyperparameter setting and store the results.

        Reads ``self.adata_path``, keeps a fixed-seed random sample of
        ``len(adata) / test_factor`` cells, and holds out each unique value
        of ``adata.obs[cv_key]`` in turn. For every node that was trained,
        one row per threshold in ``0.00 .. 0.99`` is written via
        :meth:`make_entry`, aggregated over all folds.

        Args:
            cv_key: Obs column whose unique values define the folds.
            n_features: Forwarded to ``run_feature_selection``.
            hidden_layers: List of layer sizes, or its string representation.
                A ``-1`` entry selects ``BoostedTrees`` instead of
                ``DenseTorch``.
            dropout: Forwarded to ``create_local_classifier``.
            epochs, batch_size, starting_lr, max_lr, momentum, beta, gamma:
                Forwarded to ``fit``.
            test_factor: Subsampling divisor; also forwarded to
                ``run_feature_selection``.
            parallelize: Accepted for interface compatibility; not used here.
            num_threads: Assigned to ``classifier.num_threads``.
            standardize_separately: Optional obs column; when given, the
                positional indices of each of its groups are passed to
                ``fit`` as ``standardize_idx``.
        """
        adata = sc.read_h5ad(self.adata_path)
        rng = np.random.default_rng(42)
        adata = adata[
            rng.choice(
                adata.obs_names,
                int(len(adata) / test_factor),
                replace=False)]

        if not isinstance(hidden_layers, list):
            # NOTE(review): eval() on a caller-supplied string; only safe
            # for trusted callers.
            hidden_layers = eval(hidden_layers)

        classifier_type = BoostedTrees if -1 in hidden_layers else DenseTorch

        # Collected per (fold, node, threshold); turned into a frame once.
        performance_columns = [
            'node', 'threshold', 'n_matches', 'max_correct', 'correct_total']
        performance_rows = []

        for dataset in adata.obs[cv_key].unique():
            train_adata = adata[adata.obs[cv_key] != dataset]
            val_adata = adata[adata.obs[cv_key] == dataset]
            classifier = HierarchicalClassifier(
                'testing',
                root_node=self.root_node,
                adata=train_adata,
                dict_of_cell_relations=self.hierarchy,
                obs_names=self.obs_names)
            classifier.num_threads = num_threads

            # Train a local classifier on every sufficiently populated
            # parent node.
            for node in classifier.graph.nodes:
                n_children = len(list(classifier.graph.successors(node)))
                if n_children < 1:
                    continue

                subset = classifier.select_subset(node)
                if len(subset) < 5:
                    continue

                features = classifier.run_feature_selection(
                    node=node,
                    overwrite=False,
                    n_features=n_features,
                    max_features=None,
                    min_features=30,
                    test_factor=test_factor)
                classifier.graph.nodes[node]['selected_var_names'] = features

                classifier.create_local_classifier(
                    node,
                    hidden_layers=hidden_layers,
                    dropout=dropout,
                    batchnorm=True,
                    classifier_type=classifier_type
                )
                features = classifier.graph.nodes[node]['selected_var_names']
                model = classifier.graph.nodes[node]['local_classifier']
                subset = classifier.select_subset(node, features=features)
                x = subset.X
                child_obs = classifier.obs_names[
                    classifier.node_to_depth[node] + 1]
                y = subset.obs[child_obs].values

                if standardize_separately is not None:
                    group_labels = subset.obs[standardize_separately]
                    idx = [
                        np.where(group_labels == group)
                        for group in group_labels.unique()]
                else:
                    idx = None

                fit(
                    model, x, y,
                    standardize_idx=idx,
                    epochs=epochs,
                    batch_size=batch_size,
                    starting_lr=starting_lr,
                    max_lr=max_lr,
                    momentum=momentum,
                    beta=beta,
                    gamma=gamma)

            # Evaluate on the held-out fold.
            classifier.load_adata(val_adata)
            for node in classifier.graph.nodes:
                if 'local_classifier' not in classifier.graph.nodes[node]:
                    continue

                features = classifier.graph.nodes[node]['selected_var_names']
                model = classifier.graph.nodes[node]['local_classifier']
                subset = classifier.select_subset_prediction(
                    node, features=features, for_trial=True)
                if len(subset) < 5:
                    continue

                x = subset.X
                child_obs = self.obs_names[classifier.node_to_depth[node] + 1]
                y = subset.obs[child_obs].values
                label_enc = model.labels_enc
                # Labels the model has never seen are encoded as -1 and can
                # therefore never match a prediction.
                y = np.array([
                    label_enc[label] if label in label_enc else -1
                    for label in y])
                logits = predict_logits(model, x)
                activations = np.max(logits, axis=1)
                predicted = np.argmax(logits, axis=1)
                matches = predicted == y

                if hasattr(model, 'labels_dec'):
                    # Write decoded predictions into the validation obs.
                    child_obs = f'{child_obs}_pred'
                    if child_obs not in classifier.adata.obs.columns:
                        classifier.adata.obs[child_obs] = ''

                    classifier.adata.obs[child_obs] = \
                        classifier.adata.obs[child_obs].astype(str)
                    pred = np.array([
                        model.labels_dec[p] for p in predicted.astype(int)])
                    classifier.adata.obs.loc[
                        subset.obs_names, child_obs
                    ] = pred

                # Independent of the threshold, so computed once.
                max_correct = len(matches)
                n_matches = np.sum(matches)
                for step in range(100):
                    threshold = step / 100
                    # A decision counts as correct if a match is accepted
                    # (activation above threshold) or a mismatch is rejected.
                    correct_positive = matches & (activations > threshold)
                    correct_negative = (~matches) & (activations <= threshold)
                    correct_total = (
                        np.sum(correct_positive) + np.sum(correct_negative))
                    performance_rows.append([
                        node, threshold, n_matches, max_correct,
                        correct_total])

        performance_per_cv = pd.DataFrame(
            performance_rows, columns=performance_columns)

        trials = len(adata.obs[cv_key].unique())
        for node in performance_per_cv.node.unique():
            node_performance = performance_per_cv[
                performance_per_cv.node == node]
            for threshold in node_performance.threshold.unique():
                threshold_performance = node_performance[
                    node_performance.threshold == threshold]
                n_matches = threshold_performance.n_matches.sum()
                correct_total = threshold_performance.correct_total.sum()
                max_total = threshold_performance.max_correct.sum()
                fraction_matches = n_matches / max_total
                fraction_correct = correct_total / max_total
                self.make_entry(
                    node=node,
                    trials=trials,
                    fraction_correct=fraction_correct,
                    fraction_matches=fraction_matches,
                    n_features=n_features,
                    hidden_layers=hidden_layers,
                    dropout=dropout,
                    epochs=epochs,
                    batch_size=batch_size,
                    starting_lr=starting_lr,
                    max_lr=max_lr,
                    momentum=momentum,
                    beta=beta,
                    gamma=gamma,
                    threshold=threshold)

    def make_db(self) -> None:
        """Create the ``trials`` table if it does not exist yet."""
        self.cur.execute("""CREATE TABLE IF NOT EXISTS trials(
            node, trials, fraction_correct, fraction_matches,
            n_features, hidden_layers, dropout, epochs, batch_size,
            starting_lr, max_lr, momentum, beta, gamma, threshold,
            t TIMESTAMP)""")
        self.con.commit()

    def make_entry(
            self,
            node: str,
            trials: int,
            fraction_correct: float,
            fraction_matches: int,
            n_features: int,
            hidden_layers: str,
            dropout: float,
            epochs: int,
            batch_size: int,
            starting_lr: float,
            max_lr: float,
            momentum: float,
            beta: float,
            gamma: float,
            threshold: float) -> None:
        """Insert one trial row, timestamped with ``DATETIME('now')``.

        ``node`` and ``hidden_layers`` are stored as their ``str()`` form.
        Values are bound as parameters, so labels containing quotes are
        stored verbatim. The insert is attempted up to 3 times on
        ``sqlite3.OperationalError``; if all attempts fail, a warning is
        emitted and the row is not written.
        """
        values = (
            str(node), trials, fraction_correct, fraction_matches,
            n_features, str(hidden_layers), dropout, epochs, batch_size,
            starting_lr, max_lr, momentum, beta, gamma, threshold,
        )
        params = tuple(self._to_sql_value(value) for value in values)
        placeholders = ', '.join('?' for _ in self._TRIAL_COLUMNS)
        result = self._execute_with_retry(
            f"INSERT INTO trials VALUES ({placeholders}, DATETIME('now'))",
            params,
            attempts=3,
            commit=True)
        if result is None:
            warnings.warn(
                f'Could not store trial result for node {node!r} '
                f'(threshold {threshold}); the entry was dropped.')

    def get_best_trial(self, node) -> Optional[tuple]:
        """Return the best stored hyperparameters for ``node``.

        Rows are ordered by ``fraction_matches`` and then
        ``fraction_correct``, both descending. The query is attempted up to
        10 times on ``sqlite3.OperationalError``.

        Returns:
            ``(n_features, hidden_layers, dropout, epochs, batch_size,
            starting_lr, max_lr, momentum, beta, gamma, threshold)`` of the
            top row, or ``None`` if there is no row for ``node`` or the
            query never succeeded.
        """
        res = self._execute_with_retry(
            """SELECT n_features, hidden_layers, dropout, epochs,
                batch_size, starting_lr, max_lr, momentum, beta, gamma,
                threshold
            FROM trials
            WHERE node = ?
            ORDER BY fraction_matches DESC, fraction_correct DESC""",
            (str(node),),
            attempts=10)
        if res is None:
            return None

        return res.fetchone()

    def __del__(self):
        """Close the database connection, if one was opened."""
        # __init__ may have failed before `con` was assigned.
        con = getattr(self, 'con', None)
        if con is not None:
            con.close()