Compocyte.core.hierarchical_classifier
Classes
- class Compocyte.core.hierarchical_classifier.HierarchicalClassifier(save_path, adata=None, root_node=None, dict_of_cell_relations=None, obs_names=None, default_input_data='normlog', num_threads=1, ignore_counts=False, temp_paths=None, graph=None, intermittent_saving=False)
Bases: DataBase, HierarchyBase, ExportImportBase
- __init__(save_path, adata=None, root_node=None, dict_of_cell_relations=None, obs_names=None, default_input_data='normlog', num_threads=1, ignore_counts=False, temp_paths=None, graph=None, intermittent_saving=False)
Initialize a HierarchicalClassifier.
- Parameters:
save_path (str) – Directory where classifier state is saved and loaded from.
adata (sc.AnnData, optional) – AnnData object containing single-cell data.
root_node (str, optional) – Name of the root node in the cell-type hierarchy.
dict_of_cell_relations (dict, optional) – Nested dictionary defining the cell-type hierarchy, e.g. {'root': {'T cell': {'CD4+': {}, 'CD8+': {}}}}.
obs_names (list of str, optional) – List of adata.obs column names indexed by hierarchy depth (index 0 = root level, index 1 = first child level, …).
default_input_data (str) – Input data representation to use. One of 'normlog' (log-normalized counts in adata.X) or 'counts' (raw counts). Defaults to 'normlog'.
num_threads (int) – Number of CPU threads to use per training process. Defaults to 1.
ignore_counts (bool) – If True, the check for raw count data is skipped. This is required, for example, if .X contains dimensionality-reduced or manually normalized data. Defaults to False.
temp_paths (str or list of str, optional) – Temporary directory paths used when importing or exporting classifier model files.
graph (networkx.DiGraph, optional) – Pre-built hierarchy graph. When provided, dict_of_cell_relations is inferred from it via infer_dict.
intermittent_saving (bool) – If True, the classifier state is saved to disk after each node is trained. Defaults to False.
Example
>>> hc = HierarchicalClassifier(
...     save_path='/data/classifier',
...     adata=adata,
...     root_node='root',
...     dict_of_cell_relations={'root': {'T cell': {}, 'B cell': {}}},
...     obs_names=['cell_type_l1', 'cell_type_l2'],
... )
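The mapping between hierarchy depth and obs_names can be sketched with a small helper. It walks a nested dict_of_cell_relations to find a node's depth and then looks up the matching adata.obs column. It assumes obs_names[depth] holds the labels at that depth, as the parameter description states. node_depth and obs_column_for are hypothetical helpers, not part of Compocyte.

```python
from typing import Optional


def node_depth(relations: dict, node: str, depth: int = 0) -> Optional[int]:
    """Depth of `node` in a nested relations dict (root = 0), or None if absent."""
    for name, children in relations.items():
        if name == node:
            return depth
        found = node_depth(children, node, depth + 1)
        if found is not None:
            return found
    return None


def obs_column_for(relations: dict, obs_names: list, node: str,
                   child_level: bool = False) -> str:
    """Hypothetical helper: obs column holding `node`'s label (or its children's)."""
    depth = node_depth(relations, node)
    if depth is None:
        raise KeyError(f'{node!r} is not in the hierarchy')
    index = depth + 1 if child_level else depth
    if index >= len(obs_names):
        raise IndexError(f'no obs column for depth {index}')
    return obs_names[index]


relations = {'root': {'T cell': {'CD4+': {}, 'CD8+': {}}, 'B cell': {}}}
obs_names = ['level_0', 'level_1', 'level_2']
print(node_depth(relations, 'CD8+'))                                     # 2
print(obs_column_for(relations, obs_names, 'T cell', child_level=True))  # level_2
```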
- save(save_adata=False)
Save the classifier state to disk.
Serializes instance attributes, per-node graph content, and local classifiers to timestamped subdirectories under self.save_path. Each call creates a new timestamped snapshot; loading always restores the most recent snapshot.
- Parameters:
save_adata (bool) – If True, also writes self.adata to an .h5ad file under <save_path>/data/. Defaults to False.
Example
>>> hc.save()
- load(load_path=None, load_adata=False)
Load the most recent classifier snapshot from disk.
Restores instance attributes from the latest timestamped pickle under <load_path>/hierarchical_classifiers/, optionally reloads adata, and deserializes per-node graph content and local classifiers. Reconstructs self.graph if it does not already exist.
- Parameters:
load_path (str, optional) – Directory to load the classifier from. Defaults to None.
load_adata (bool) – If True, also reloads adata from disk. Defaults to False.
Example
>>> hc = HierarchicalClassifier(save_path='/data/classifier')
>>> hc.load()
>>> hc.load(load_path='/backup/classifier', load_adata=True)
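The timestamped snapshot scheme used by save() and load() can be illustrated with the standard library alone. This is a sketch under assumptions. The directory mirrors the <save_path>/hierarchical_classifiers/ path mentioned above. The file naming, the timestamp format, and the helper names save_snapshot and load_latest are hypothetical. Only a plain attribute dict is pickled here, with no graph content or local classifiers.

```python
import os
import pickle
from datetime import datetime


def save_snapshot(save_path: str, state: dict) -> str:
    """Pickle `state` into a new timestamped file and return its path."""
    directory = os.path.join(save_path, 'hierarchical_classifiers')
    os.makedirs(directory, exist_ok=True)
    # Zero-padded timestamp, so lexicographic order equals chronological order.
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
    path = os.path.join(directory, f'hierarchical_classifier_{stamp}.pickle')
    with open(path, 'wb') as handle:
        pickle.dump(state, handle)
    return path


def load_latest(load_path: str) -> dict:
    """Unpickle the newest snapshot under <load_path>/hierarchical_classifiers/."""
    directory = os.path.join(load_path, 'hierarchical_classifiers')
    snapshots = sorted(f for f in os.listdir(directory) if f.endswith('.pickle'))
    if not snapshots:
        raise FileNotFoundError(f'no snapshots in {directory}')
    with open(os.path.join(directory, snapshots[-1]), 'rb') as handle:
        return pickle.load(handle)
```

Every call writes a new file, so older snapshots stay on disk and can be recovered manually. load_latest only ever reads the newest one.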
- limit_cells(subset, max_cells, stratify_by)
Randomly subsample cells from a subset, stratified by a given obs column, if the number of cells exceeds max_cells.
Cells are sampled equally from each stratum defined by stratify_by. If the total after stratified sampling falls below max_cells, additional cells are drawn at random from the remaining pool to reach the cap.
- Parameters:
subset (sc.AnnData) – Cells to subsample.
max_cells (int) – Maximum number of cells to keep.
stratify_by (str) – adata.obs column whose values define the strata.
- Returns:
Subsampled AnnData with at most max_cells observations. Returns the original subset unchanged if len(subset) <= max_cells.
- Return type:
sc.AnnData
Example
>>> limited = hc.limit_cells(
...     subset, max_cells=10_000, stratify_by='dataset'
... )
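The sampling rule described for limit_cells() can be sketched in plain Python on a list of per-cell stratum labels. This is a minimal illustration, not Compocyte's implementation. It works on integer cell indices rather than an AnnData object. It assumes "sampled equally" means max_cells // n_strata cells per stratum, with smaller strata contributing all their cells. It uses a seeded random.Random for reproducibility. The function name stratified_limit is hypothetical.

```python
import random
from collections import defaultdict


def stratified_limit(strata: list, max_cells: int, seed: int = 0) -> list:
    """Sorted indices of at most `max_cells` cells, balanced across strata."""
    if len(strata) <= max_cells:
        return list(range(len(strata)))  # under the cap: keep every cell
    rng = random.Random(seed)
    by_stratum = defaultdict(list)
    for index, label in enumerate(strata):
        by_stratum[label].append(index)

    # Equal share per stratum; small strata contribute all of their cells.
    per_stratum = max_cells // len(by_stratum)
    chosen = set()
    for indices in by_stratum.values():
        chosen.update(rng.sample(indices, min(per_stratum, len(indices))))

    # Top up at random from the remaining cells until the cap is reached.
    remaining = [i for i in range(len(strata)) if i not in chosen]
    shortfall = max_cells - len(chosen)
    chosen.update(rng.sample(remaining, min(shortfall, len(remaining))))
    return sorted(chosen)
```

For example, with 90 cells from dataset 'a', 10 from dataset 'b', and max_cells=40, every 'b' cell is kept. The cap is then filled with 'a' cells.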
- introduce_limit(max_cells, stratify_by)
Set a per-training-run cell cap and stratification column.
The limit is applied during select_subset() whenever max_cells is passed to that call. Cells are subsampled via limit_cells() using the specified stratify_by column.
- Parameters:
max_cells (int) – Maximum number of cells per training run.
stratify_by (str) – adata.obs column used to stratify the subsampling.
Example
>>> hc.introduce_limit(max_cells=50_000, stratify_by='dataset')
- select_subset(node, features=None, max_cells=None)
Select training cells for a given hierarchy node.
Returns cells labeled as node at the node’s depth level that also carry a non-empty child-level label (i.e. cells with a known subtype). Optionally restricts to a specified feature set and applies a stratified cell count cap.
- Parameters:
node (str) – Name of the hierarchy node to select cells for.
features (list of str, optional) – Gene/feature names to restrict the returned subset to. If None, all features are included.
max_cells (int, optional) – Maximum number of cells to return. Only applied when introduce_limit() has been called first (to set self.stratify_by); otherwise ignored.
- Returns:
Subset of self.adata containing only cells labeled as node with a non-empty child-level annotation.
- Return type:
sc.AnnData
Example
>>> subset = hc.select_subset('T cell')
>>> subset = hc.select_subset(
...     'T cell', features=['CD3D', 'CD4'], max_cells=10_000
... )
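The training-cell rule can be expressed as a simple row filter. In this sketch each cell's annotations are a plain dict that stands in for a row of adata.obs. The function names, the explicit depth argument, and the treatment of None, empty strings and 'nan' as empty labels are assumptions made for illustration.

```python
from typing import Optional


def is_labeled(value: Optional[str]) -> bool:
    """Treat None, empty strings and 'nan' as missing annotations (assumption)."""
    return value is not None and str(value).strip() not in ('', 'nan')


def training_cell_ids(obs: dict, obs_names: list, depth: int, node: str) -> list:
    """IDs of cells labeled `node` at `depth` that also carry a child-level label."""
    node_col, child_col = obs_names[depth], obs_names[depth + 1]
    return [
        cell_id
        for cell_id, row in obs.items()
        if row.get(node_col) == node and is_labeled(row.get(child_col))
    ]


# Toy annotations: c2 is a T cell without a known subtype, so it is excluded.
obs = {
    'c1': {'level_0': 'root', 'level_1': 'T cell', 'level_2': 'CD4+'},
    'c2': {'level_0': 'root', 'level_1': 'T cell', 'level_2': ''},
    'c3': {'level_0': 'root', 'level_1': 'B cell', 'level_2': 'naive B'},
    'c4': {'level_0': 'root', 'level_1': 'T cell', 'level_2': 'CD8+'},
}
print(training_cell_ids(obs, ['level_0', 'level_1', 'level_2'], 1, 'T cell'))
```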
- select_subset_prediction(node, features=None, for_trial=False)
Select cells assigned to a given node for inference.
In normal prediction mode, cells are selected by their previously predicted label at the node’s depth level. When the predicted obs column does not yet exist, all cells are returned (root-level fallback). With for_trial=True, ground-truth labels are used instead of predictions.
- Parameters:
node (str) – Name of the hierarchy node.
features (list of str, optional) – Gene/feature names to restrict the returned subset to. If None, all features are included.
for_trial (bool) – If True, uses ground-truth obs labels instead of predicted labels to select cells, so that all cells relevant to node according to the ground truth are included. Defaults to False.
- Returns:
Subset of self.adata containing the cells associated with node for prediction purposes.
- Return type:
sc.AnnData
Example
>>> subset = hc.select_subset_prediction('T cell')
>>> subset = hc.select_subset_prediction('T cell', for_trial=True)
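The three selection modes can be sketched with adata.obs modeled as a column-oriented dict. The '_pred' suffix follows the <child_obs>_pred naming given for predict_single_node(). Everything else is a hypothetical simplification, including the function name and the use of "column not in obs" to stand in for "the predicted column does not exist yet".

```python
def prediction_cell_indices(obs: dict, obs_names: list, depth: int, node: str,
                            for_trial: bool = False) -> list:
    """Indices of cells routed to `node` for inference (column-oriented `obs`)."""
    n_cells = len(next(iter(obs.values())))
    column = obs_names[depth]
    if not for_trial:
        column = f'{column}_pred'  # previously predicted labels at this depth
        if column not in obs:
            # No predictions at this level yet: root-level fallback, all cells.
            return list(range(n_cells))
    return [i for i, label in enumerate(obs[column]) if label == node]


obs = {
    'level_0': ['root', 'root', 'root'],
    'level_1': ['T cell', 'B cell', 'T cell'],        # ground truth
    'level_1_pred': ['T cell', 'T cell', 'B cell'],   # earlier predictions
}
names = ['level_0', 'level_1', 'level_2']
print(prediction_cell_indices(obs, names, 1, 'T cell'))                  # [0, 1]
print(prediction_cell_indices(obs, names, 1, 'T cell', for_trial=True))  # [0, 2]
```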
- run_feature_selection(node, overwrite=False, n_features=-1, max_features=None, min_features=30, test_factor=1.0, max_cells=100000)
Select the most informative features for classifying the children of a node.
Uses ANOVA F-statistics (SelectKBest with f_classif) on robustly scaled expression values to rank and select features. The target feature count is either specified directly via n_features or inferred from the sample count using a heuristic of 1 feature per 100 cells, then clamped to [min_features, max_features].
- Parameters:
node (str) – Name of the hierarchy node whose children are the prediction targets.
overwrite (bool) – If True, overwrite any previously stored feature selection for this node. Defaults to False.
n_features (int) – Exact number of features to select. Use -1 to infer the count from sample size. Defaults to -1.
max_features (int, optional) – Upper bound on the number of features selected. Defaults to the total number of features in self.adata.
min_features (int) – Lower bound on the number of features selected. Defaults to 30.
test_factor (float) – Multiplier applied to the sample-size-inferred feature count. Defaults to 1.0.
max_cells (int) – Maximum number of cells used for feature selection. Defaults to 100_000.
- Returns:
Names of the selected features (genes).
- Return type:
- Raises:
Exception – If features have already been selected for node and overwrite is False.
Example
>>> features = hc.run_feature_selection('T cell', n_features=200)
>>> features = hc.run_feature_selection(
...     'T cell', overwrite=True, min_features=50, max_features=500
... )
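The feature-count heuristic and the ANOVA F ranking can be sketched without scikit-learn. Several details are assumptions:

- The inferred count is int(n_cells / 100 * test_factor).
- The clamp to [min_features, max_features] is applied to an explicit n_features too.
- The upper bound wins if the two bounds conflict.
- total_features stands in for the number of features in self.adata.

The sketch computes the one-way ANOVA F statistic per feature, which is the quantity f_classif returns. It works on already-scaled values and skips the robust scaling and the max_cells cap.

```python
from statistics import mean


def target_feature_count(n_cells, n_features=-1, min_features=30,
                         max_features=None, test_factor=1.0, total_features=None):
    """Explicit n_features, or 1 feature per 100 cells scaled by test_factor, then clamped."""
    if max_features is None:
        # Default upper bound: every available feature (unbounded if unknown).
        max_features = total_features if total_features is not None else float('inf')
    k = n_features if n_features != -1 else int(n_cells / 100 * test_factor)
    # Clamp to [min_features, max_features]; the upper bound wins on conflict.
    return int(min(max(k, min_features), max_features))


def anova_f(values, labels):
    """One-way ANOVA F of one feature across classes (needs >= 2 classes, n > classes)."""
    groups = {}
    for value, label in zip(values, labels):
        groups.setdefault(label, []).append(value)
    grand_mean = mean(values)
    k, n = len(groups), len(values)
    group_means = {label: mean(g) for label, g in groups.items()}
    ss_between = sum(len(g) * (group_means[label] - grand_mean) ** 2
                     for label, g in groups.items())
    ss_within = sum((v - group_means[label]) ** 2
                    for label, g in groups.items() for v in g)
    if ss_within == 0:
        return float('inf') if ss_between > 0 else 0.0
    return (ss_between / (k - 1)) / (ss_within / (n - k))


def select_top_k(matrix, feature_names, labels, k):
    """Rank features (columns; rows are cells) by F statistic and keep the top k."""
    scores = [anova_f([row[j] for row in matrix], labels)
              for j in range(len(feature_names))]
    ranked = sorted(range(len(feature_names)), key=lambda j: scores[j], reverse=True)
    return [feature_names[j] for j in ranked[:k]]


matrix = [[5.0, 1.0], [5.2, 0.0], [0.1, 1.1], [0.0, 0.1]]  # 4 cells x 2 genes
labels = ['CD4+', 'CD4+', 'CD8+', 'CD8+']
print(target_feature_count(n_cells=12_000))                # 120
print(select_top_k(matrix, ['CD4', 'GAPDH'], labels, k=1))  # ['CD4']
```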
- create_local_classifier(node, overwrite=False, classifier_type=<class 'Compocyte.core.models.dense_torch.DenseTorch'>, **classifier_kwargs)
Instantiate and attach a local classifier to a hierarchy node.
Creates a classifier of the specified type using the node’s previously selected features and unique child labels. If the node has only one child label, a DummyClassifier is created regardless of classifier_type.
- Parameters:
node (str) – Name of the hierarchy node.
overwrite (bool) – If True, replace any existing classifier at this node. Defaults to False.
classifier_type (type or str) – Classifier class to instantiate, or one of the strings 'DenseTorch', 'LogisticRegression', or 'BoostedTrees'. Defaults to DenseTorch.
**classifier_kwargs – Additional keyword arguments forwarded to the classifier constructor (e.g. hidden_layers, dropout).
- Raises:
Exception – If a classifier already exists at node and overwrite is False.
Exception – If run_feature_selection() has not been called for node first.
Exception – If classifier_type is an unrecognized string.
Example
>>> hc.create_local_classifier(
...     'T cell', classifier_type='DenseTorch',
...     hidden_layers=[64, 32], dropout=0.2
... )
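The rules for choosing a classifier class can be sketched as a small resolver. The classes below are hypothetical stand-ins that only record their constructor arguments. They do not reproduce the real DenseTorch, LogisticRegression, BoostedTrees or DummyClassifier signatures. Only the rules come from the description above:

- A string must be one of the three known names.
- A class is used as-is.
- A single child label always yields a dummy classifier.

The function name build_local_classifier is hypothetical.

```python
class _StubClassifier:
    """Hypothetical stand-in that just remembers its constructor arguments."""

    def __init__(self, features, labels, **kwargs):
        self.features, self.labels, self.kwargs = features, labels, kwargs


class DenseTorch(_StubClassifier):
    pass


class LogisticRegression(_StubClassifier):
    pass


class BoostedTrees(_StubClassifier):
    pass


class DummyClassifier(_StubClassifier):
    pass


KNOWN_TYPES = {
    'DenseTorch': DenseTorch,
    'LogisticRegression': LogisticRegression,
    'BoostedTrees': BoostedTrees,
}


def build_local_classifier(features, child_labels, classifier_type=DenseTorch, **kwargs):
    """Resolve `classifier_type` and instantiate it with the node's features and labels."""
    if features is None:
        raise Exception('Run feature selection for this node first.')
    if isinstance(classifier_type, str):
        if classifier_type not in KNOWN_TYPES:
            raise Exception(f'Unknown classifier type: {classifier_type!r}')
        classifier_type = KNOWN_TYPES[classifier_type]
    labels = sorted(set(child_labels))
    if len(labels) == 1:
        # Nothing to discriminate: a single child label always gets a dummy model.
        classifier_type = DummyClassifier
    return classifier_type(features, labels, **kwargs)
```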
- train_single_node(node, standardize_separately=None, **fit_kwargs)
Train the local classifier at a single hierarchy node.
Runs feature selection and classifier creation if they have not been performed yet, then fits the model on cells labeled as node. When self.tuned_kwargs contains an entry for node, those hyperparameters override any provided fit_kwargs. Returns None when the node has fewer than 5 labeled cells.
The method returns a dict of updated node parameters rather than modifying self.graph in place, which is required for safe use with multiprocessing.Pool.
- Parameters:
node (str) – Name of the hierarchy node to train.
standardize_separately (str, optional) – adata.obs column name (e.g. 'dataset'). When provided, cells are grouped by the unique values of this column and each group is robustly scaled independently before training.
**fit_kwargs – Additional keyword arguments forwarded to fit() (e.g. epochs, batch_size, starting_lr).
- Returns:
Dictionary of updated node parameters including 'learning_curve', or None if the node has too few cells.
- Return type:
dict or None
- Raises:
Exception – If num_threads is not set on the instance and not provided in fit_kwargs.
Example
>>> params = hc.train_single_node('T cell', epochs=50, batch_size=256)
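The group-wise scaling behind standardize_separately can be sketched with the standard library. The sketch follows scikit-learn's RobustScaler defaults: subtract the median, divide by the interquartile range, and use a scale of 1 when the IQR is zero. It is an illustration, not Compocyte's code. It operates on a list of per-cell feature rows plus a parallel list of group labels such as dataset names. The function names are hypothetical.

```python
from statistics import median, quantiles


def robust_scale_column(values):
    """(x - median) / IQR for one feature; an IQR of 0 is replaced by 1."""
    if len(values) < 2:
        q1 = q3 = values[0]
    else:
        # 'inclusive' matches the linear interpolation used by numpy.percentile.
        q1, _, q3 = quantiles(values, n=4, method='inclusive')
    center = median(values)
    scale = (q3 - q1) or 1.0
    return [(v - center) / scale for v in values]


def robust_scale_by_group(rows, groups):
    """Scale each group of cells (e.g. each dataset) independently, feature by feature."""
    scaled = [list(row) for row in rows]
    members = {}
    for index, group in enumerate(groups):
        members.setdefault(group, []).append(index)
    n_features = len(rows[0])
    for indices in members.values():
        for j in range(n_features):
            column = robust_scale_column([rows[i][j] for i in indices])
            for i, value in zip(indices, column):
                scaled[i][j] = value
    return scaled
```

Two datasets measured on different scales, for example [1, 2, 3] and [10, 20, 30], both map to [-1, 0, 1]. This is the point of scaling each batch separately.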
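- set_classifier_type(node, classifier_type)
Override the default classifier type for one or more hierarchy nodes.
Stores the mapping in self.specified_classifier_types, which is consulted by train_single_node() when creating local classifiers. Can be called with a single node name or a list of node names.
- Parameters:
node (str or list of str) – Name of a hierarchy node, or a list of node names.
classifier_type (type) – Classifier class to use at the given node(s), e.g. LogisticRegression or BoostedTrees.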
Example
>>> hc.set_classifier_type('T cell', LogisticRegression)
>>> hc.set_classifier_type(['B cell', 'NK cell'], BoostedTrees)
- train_all_child_nodes(parallelize=False, processes=None)
Train local classifiers for all non-leaf nodes in the hierarchy.
Iterates over every node that has at least one child and calls train_single_node(). In sequential mode, saves the classifier state after each node when self.intermittent_saving is True. In parallel mode, node parameters are collected after all workers finish and written back to the graph.
- Parameters:
parallelize (bool) – If True, train nodes in parallel worker processes using multiprocessing.Pool. Defaults to False.
processes (int, optional) – Number of worker processes. Required when parallelize=True.
- Returns:
None
- Return type:
None
- Raises:
Exception – If parallelize=True and processes is not specified.
Example
>>> hc.train_all_child_nodes()
>>> hc.train_all_child_nodes(parallelize=True, processes=8)
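The parallel mode can be sketched with multiprocessing.Pool. Workers return plain dicts of node parameters, and the parent process writes them back into the graph. In this sketch the graph is a plain dict of node attributes. train_node is a hypothetical placeholder for train_single_node(): it returns a dummy learning curve so the control flow can run, and None for nodes with fewer than 5 cells.

```python
from multiprocessing import Pool


def train_node(args):
    """Hypothetical stand-in for train_single_node(): returns (node, params or None)."""
    node, n_cells = args
    if n_cells < 5:
        return node, None  # too few labeled cells, nothing to train
    # Placeholder result; a real worker would fit the local classifier here.
    return node, {'learning_curve': [1.0 / (epoch + 1) for epoch in range(3)]}


def train_all(graph, cell_counts, parallelize=False, processes=None):
    """Train every non-leaf node and merge the returned params into `graph`."""
    jobs = [(node, cell_counts.get(node, 0))
            for node, attrs in graph.items() if attrs['children']]
    if parallelize:
        if processes is None:
            raise Exception('Specify the number of processes for parallel training.')
        with Pool(processes) as pool:
            results = pool.map(train_node, jobs)
    else:
        results = [train_node(job) for job in jobs]
    for node, params in results:
        if params is not None:
            graph[node].update(params)  # only the parent process mutates the graph
```

Workers never touch the shared graph. Each returns a picklable dict, which is why train_single_node() returns its parameters instead of updating self.graph in place. On platforms that start workers with spawn (Windows, macOS), the parallel call must run inside an `if __name__ == '__main__':` block.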
- predict_single_node(node, threshold=-1, monte_carlo=None)
Run inference at a single hierarchy node and write predictions to adata.obs.
Selects cells currently routed to node, runs the node’s local classifier, and writes predicted child-level labels to <child_obs>_pred in self.adata.obs. If an 'overclustering' column is present, predictions are harmonized per cluster by majority vote.
When monte_carlo is set, the method also computes the per-sample mean and standard deviation of the winning label’s activation across MC iterations and stores them in self.adata.obs['monte_carlo_mean'] and self.adata.obs['monte_carlo_std'].
- Parameters:
node (str) – Name of the hierarchy node at which to run prediction.
threshold (float) – Minimum confidence required to assign a label. Use -1 to disable thresholding and always assign the top-scoring label. Defaults to -1.
monte_carlo (int, optional) – Number of Monte Carlo dropout forward passes for uncertainty estimation. If None, standard deterministic inference is used.
- Returns:
Array of predicted child-level cell-type labels for cells routed to node.
- Return type:
Example
>>> pred = hc.predict_single_node('T cell')
>>> pred = hc.predict_single_node('T cell', threshold=0.9, monte_carlo=100)
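Three post-processing steps described above can be sketched on plain probability lists: thresholding, per-cluster majority vote, and the Monte Carlo summary. Several details are assumptions, and Compocyte may handle them differently:

- A cell whose top probability is below the threshold is left unassigned (None).
- The majority vote counts every label in a cluster, including None.
- The MC "winning label" is the argmax of the mean activation across passes.
- The spread is the population standard deviation, as numpy.std computes by default.

```python
from collections import Counter, defaultdict
from statistics import mean, pstdev


def assign_labels(probs, classes, threshold=-1):
    """Top-scoring class per cell; None if its probability is below `threshold`.

    Probabilities are never negative, so threshold=-1 always assigns a label.
    """
    labels = []
    for row in probs:
        best = max(range(len(classes)), key=lambda j: row[j])
        labels.append(classes[best] if row[best] >= threshold else None)
    return labels


def harmonize_by_cluster(labels, clusters):
    """Replace each cell's label with the most common label in its overcluster."""
    votes = defaultdict(Counter)
    for label, cluster in zip(labels, clusters):
        votes[cluster][label] += 1
    winner = {cluster: counts.most_common(1)[0][0] for cluster, counts in votes.items()}
    return [winner[cluster] for cluster in clusters]


def monte_carlo_summary(mc_probs):
    """mc_probs[t][i][j] = pass t, cell i, class j; returns (winner, mean, std) per cell."""
    n_cells, n_classes = len(mc_probs[0]), len(mc_probs[0][0])
    summary = []
    for i in range(n_cells):
        mean_probs = [mean(p[i][j] for p in mc_probs) for j in range(n_classes)]
        winner = max(range(n_classes), key=lambda j: mean_probs[j])
        activations = [p[i][winner] for p in mc_probs]
        summary.append((winner, mean(activations), pstdev(activations)))
    return summary
```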
- predict_all_child_nodes(node, threshold=-1, mlnp=False, monte_carlo=None)
Recursively predict cell types from a starting node down the hierarchy.
Calls predict_single_node() at node, then recurses into each child that itself has children (i.e. non-leaf children). Per-node thresholds stored in self.graph.nodes[node]['threshold'] take precedence over the threshold argument unless mlnp=True.
- Parameters:
node (str) – Root of the subtree to predict. Typically self.root_node.
threshold (float) – Default confidence threshold applied at each node. Use -1 to always assign a label. Defaults to -1.
mlnp (bool) – Mandatory leaf-node prediction. If True, per-node thresholds are ignored and prediction is forced to leaf nodes. Defaults to False.
monte_carlo (int, optional) – Number of Monte Carlo dropout iterations, forwarded to predict_single_node().
Example
>>> hc.predict_all_child_nodes(hc.root_node)
>>> hc.predict_all_child_nodes(hc.root_node, threshold=0.9)
>>> hc.predict_all_child_nodes(hc.root_node, monte_carlo=50)
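The recursion and the threshold precedence rule can be sketched over a nested dict_of_cell_relations. predict_node is a hypothetical callback standing in for predict_single_node(). node_thresholds is a hypothetical dict standing in for the 'threshold' attribute stored on graph nodes. The sketch only records which nodes are visited and which threshold each one would use.

```python
def predict_subtree(node, children, predict_node, node_thresholds,
                    threshold=-1, mlnp=False):
    """Predict at `node`, then recurse into each child that has children of its own."""
    # A per-node threshold takes precedence over the default unless mlnp=True.
    effective = threshold if mlnp else node_thresholds.get(node, threshold)
    predict_node(node, effective)
    for child, grandchildren in children.items():
        if grandchildren:  # leaf children have no local classifier to run
            predict_subtree(child, grandchildren, predict_node, node_thresholds,
                            threshold, mlnp)


relations = {'root': {'T cell': {'CD4+': {}, 'CD8+': {}}, 'B cell': {}}}
calls = []
predict_subtree('root', relations['root'], lambda n, t: calls.append((n, t)),
                {'T cell': 0.8}, threshold=0.5)
print(calls)  # [('root', 0.5), ('T cell', 0.8)]
```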