Source code for Compocyte.core.models.dense_torch
import numpy as np
import torch
import os
import pickle
import logging
logger = logging.getLogger(__name__)
[docs]
class DenseTorch(torch.nn.Module):
[docs]
def __init__(
self,
labels: list,
n_input: int,
n_output: int,
hidden_layers: list=[64, 64],
dropout: float=0.4,
batchnorm: bool=True):
super().__init__()
self.labels_enc = {label: i for i, label in enumerate(labels)}
self.labels_dec = {self.labels_enc[label]: label for label in self.labels_enc.keys()}
self.layers = torch.nn.ModuleList()
layers = [n_input] + hidden_layers + [n_output]
for i in range(len(layers) - 1):
n_in = layers[i]
n_out = layers[i + 1]
new_linear = torch.nn.Linear(n_in, n_out)
new_activation = torch.nn.LeakyReLU(0.1)
new_batchnorm = torch.nn.BatchNorm1d(n_out)
new_dropout = torch.nn.Dropout(dropout)
torch.nn.init.xavier_uniform_(
new_linear.weight,
gain=torch.nn.init.calculate_gain('leaky_relu', 0.1))
torch.nn.init.zeros_(new_linear.bias)
self.layers.append(new_linear)
if i < (len(layers) - 2):
self.layers.append(new_activation)
if batchnorm:
self.layers.append(new_batchnorm)
self.layers.append(new_dropout)
else:
self.layers.append(
torch.nn.Softmax(dim=1)
)
[docs]
def forward(self, x):
torch.autograd.set_detect_anomaly(True)
for layer in self.layers:
x = layer(x)
return x
[docs]
def predict_logits(self, x) -> np.array:
self.eval()
x = torch.from_numpy(x).to(torch.float32)
return self(x).detach().numpy()
[docs]
def predict(self, x) -> np.array:
logits = self.predict_logits(x)
pred = np.argmax(logits, axis=1)
pred = np.array(
[self.labels_dec[p] for p in pred]
)
return pred
[docs]
def reset_output(self, n_output):
in_features = self.layers[-2].in_features
del self.layers[-2] # last dense
del self.layers[-1] # softmax
self.layers.append(
torch.nn.Linear(in_features, n_output)
)
torch.nn.init.xavier_uniform_(self.layers[-1].weight, gain=torch.nn.init.calculate_gain('leaky_relu', 0.1))
torch.nn.init.zeros_(self.layers[-1].bias)
self.layers.append(
torch.nn.Softmax(dim=1)
)
def _save(self, path):
non_param_attr = ['histories', 'labels_enc', 'labels_dec']
non_param_dict = {}
for item in self.__dict__.keys():
if item in non_param_attr:
non_param_dict[item] = self.__dict__[item]
torch.save(self, os.path.join(path, 'model'))
with open(os.path.join(path, 'non_param_dict.pickle'), 'wb') as f:
pickle.dump(non_param_dict, f)
@classmethod
def _load(cls, path):
model = torch.load(os.path.join(path, 'model'), weights_only=False)
with open(os.path.join(path, 'non_param_dict.pickle'), 'rb') as f:
non_param_dict = pickle.load(f)
for item in non_param_dict.keys():
model.__dict__[item] = non_param_dict[item]
return model
[docs]
@classmethod
def import_external(
cls,
model,
labels,):
if not issubclass(type(model), torch.nn.Module):
raise TypeError('To import an external model as DenseTorch, it must be a subclass of torch.nn.Module.')
denseTorch = cls(labels, 2, 2)
denseTorch.layers = torch.nn.ModuleList([model])
return denseTorch