Compocyte.core.models.fit_methods
Shared fit and predict helper methods for different model wrappers.
This module centralises the training and prediction routines shared by the model wrappers (PyTorch, logistic regression, CatBoost, etc.).
- class Compocyte.core.models.fit_methods.DaskBatchDataset(*args: Any, **kwargs: Any)
Bases: IterableDataset
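DaskBatchDataset is documented only by its signature and base class. Its name and its `IterableDataset` base suggest that it streams training data in mini-batches instead of loading the full matrix into memory, but that is an assumption. The sketch below uses a hypothetical `ChunkedBatchIterable` class built on plain Python sequences, with no torch or dask dependency, so that it runs anywhere. A real version would presumably subclass `torch.utils.data.IterableDataset` and call `.compute()` on each dask slice.

```python
import random
from typing import Iterator, List, Sequence, Tuple


class ChunkedBatchIterable:
    """Hypothetical stand-in for DaskBatchDataset (an assumption, not Compocyte's code).

    Yields (x_batch, y_batch) pairs by slicing the inputs one batch at a time.
    In a torch setting this would subclass torch.utils.data.IterableDataset, and
    x could be a dask array whose slices are materialised with .compute().
    """

    def __init__(self, x: Sequence, y: Sequence, batch_size: int = 64,
                 shuffle: bool = True, seed: int = 0):
        if len(x) != len(y):
            raise ValueError("x and y must have the same number of rows")
        self.x = x
        self.y = y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed

    def __iter__(self) -> Iterator[Tuple[List, List]]:
        # Start offsets of every batch: 0, batch_size, 2 * batch_size, ...
        starts = list(range(0, len(self.x), self.batch_size))
        if self.shuffle:
            # Shuffle the order in which batches are visited, not the rows inside
            # them, so every slice stays contiguous (cheap to read from chunks).
            random.Random(self.seed).shuffle(starts)
        for start in starts:
            stop = start + self.batch_size
            # With dask, this is where x[start:stop].compute() would load one slice.
            yield list(self.x[start:stop]), list(self.y[start:stop])


x = [[float(i)] for i in range(10)]
y = [f"label_{i % 2}" for i in range(10)]
batches = list(ChunkedBatchIterable(x, y, batch_size=4, shuffle=False))
print([len(xb) for xb, _ in batches])  # [4, 4, 2]
```

Shuffling batch offsets rather than individual rows keeps every read contiguous, which matters when each slice triggers a chunk load from disk; the trade-off is coarser randomisation within an epoch.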
- Compocyte.core.models.fit_methods.fit(model: DenseTorch | LogisticRegression | DummyClassifier, x: numpy.array, y: numpy.array, standardize_idx: list = None, **fit_kwargs)
- Parameters:
model (Union[DenseTorch, LogisticRegression, DummyClassifier]) – Model to be fitted.
x (np.array) – Input data.
y (np.array) – Target labels, given as a 1-dimensional array of strings.
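Because `fit` accepts several wrapper types, it presumably routes each model to a type-specific helper such as `fit_torch` or `fit_logreg`; this routing is an assumption, not documented behaviour. The following sketch shows one way to express it with `functools.singledispatch`. The stub classes and the returned strings are hypothetical placeholders, and `standardize_idx` is not modelled.

```python
from functools import singledispatch


# Hypothetical stand-ins for the wrapper classes named in the signature of fit().
class DenseTorch:
    pass


class LogisticRegression:
    pass


class DummyClassifier:
    pass


@singledispatch
def fit(model, x, y, **fit_kwargs):
    """Fallback for model types without a registered helper."""
    raise TypeError(f"Unsupported model type: {type(model).__name__}")


@fit.register
def _(model: DenseTorch, x, y, **fit_kwargs):
    # A real implementation would call fit_torch(model, x, y, **fit_kwargs) here.
    return "fit_torch"


@fit.register
def _(model: LogisticRegression, x, y, **fit_kwargs):
    # A real implementation would call fit_logreg(model, x, y, **fit_kwargs) here.
    return "fit_logreg"


@fit.register
def _(model: DummyClassifier, x, y, **fit_kwargs):
    # A dummy baseline needs no special handling in this sketch.
    return "dummy"


print(fit(DenseTorch(), [[0.0, 1.0]], ["T cell"], epochs=1))  # fit_torch
```

`singledispatch` resolves subclasses through the method resolution order, so a subclass of one of the wrappers would reach the same helper without extra registration.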
- Compocyte.core.models.fit_methods.fit_logreg(model: LogisticRegression, x, y, **fit_kwargs)
- Compocyte.core.models.fit_methods.fit_torch(model: DenseTorch, x: numpy.array, y: numpy.array, epochs: int = 40, batch_size: int = 64, starting_lr: float = 0.01, max_lr: float = 0.1, momentum: float = 0.5, parallelize: bool = True, num_threads: int = 1, beta: float = 0.8, gamma: float = 2.0, class_balance: bool = True, max_cells: int = 1000000)
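The `beta` and `gamma=2.0` parameters of `fit_torch`, together with `class_balance`, resemble the class-balanced focal loss of Cui et al. (2019). In that formulation, a class c with n_c training examples receives the weight (1 - β) / (1 - β^n_c), and each sample's cross-entropy is scaled by (1 - p_t)^γ, where p_t is the predicted probability of the true class. Whether Compocyte uses exactly this loss is an assumption. The pure-Python sketch below only illustrates the arithmetic, and its function names are hypothetical.

```python
import math
from collections import Counter
from typing import Dict, List, Sequence


def class_balanced_weights(labels: Sequence[str], beta: float = 0.8) -> Dict[str, float]:
    """Effective-number weights (1 - beta) / (1 - beta**n_c), rescaled to sum to the class count."""
    counts = Counter(labels)
    raw = {c: (1.0 - beta) / (1.0 - beta ** n) for c, n in counts.items()}
    scale = len(raw) / sum(raw.values())
    return {c: w * scale for c, w in raw.items()}


def focal_loss(probs: List[float], target: int, gamma: float = 2.0, weight: float = 1.0) -> float:
    """Weighted focal loss -weight * (1 - p_t)**gamma * log(p_t) for a single sample."""
    p_t = probs[target]
    return -weight * (1.0 - p_t) ** gamma * math.log(p_t)


labels = ["T cell"] * 8 + ["B cell"] * 2
weights = class_balanced_weights(labels, beta=0.8)
print(weights["B cell"] > weights["T cell"])  # True: the rarer class is up-weighted

# A confident correct prediction contributes far less loss than under plain cross-entropy.
print(focal_loss([0.9, 0.1], target=0, gamma=2.0) < focal_loss([0.9, 0.1], target=0, gamma=0.0))  # True
```

With `gamma=0.0` the focal term disappears and the loss reduces to ordinary (optionally weighted) cross-entropy. With a small β such as 0.8, β^n_c approaches zero quickly, so the weights flatten out for all but the rarest classes.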
- Compocyte.core.models.fit_methods.fit_trees(model: BoostedTrees, x, y, **fit_kwargs)