kinoml.ml.torch_loops
WIP
Module Contents
- kinoml.ml.torch_loops.multi_measurement_training_loop(dataloaders, observation_models, model, optimizer, loss_function, epochs=100)
Standard training loop with multiple dataloaders and observation models.
- Parameters
dataloaders (dict of str -> torch.utils.data.DataLoader) – each key names the measurement type of the data yielded by its dataloader
observation_models (dict of str -> callable) – keys must match those in dataloaders; each value is a PyTorch-compatible callable that converts delta_over_kt into the corresponding measurement type
model (torch.nn.Module) – instance of the model to train
optimizer (torch.optim.Optimizer) – instance of the optimization engine
loss_function (torch.nn.modules.loss._Loss) – instance of the loss function to apply (e.g. torch.nn.MSELoss())
epochs (int) – number of epochs (full passes over every dataloader) to run
- Returns
model (torch.nn.Module) – the trained model (the same instance passed in as model)
loss_timeseries (list of float) – Cumulative loss per epoch
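The loop described above can be sketched as follows. This is a hypothetical re-implementation for illustration, not kinoml's source: it assumes each dataloader yields (x, y) pairs, that the model's raw output is delta_over_kt, and that the loss is summed over all batches of all dataloaders to give one value per epoch. The measurement keys "measurement_a" and "measurement_b" and their observation models are invented for the example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def multi_measurement_training_loop_sketch(dataloaders, observation_models, model,
                                           optimizer, loss_function, epochs=100):
    """Hypothetical stand-in; assumes every batch is an (x, y) pair."""
    loss_timeseries = []
    model.train()
    for _ in range(epochs):
        cumulative_loss = 0.0
        for key, dataloader in dataloaders.items():
            observation_model = observation_models[key]  # same key as the dataloader
            for x, y in dataloader:
                optimizer.zero_grad()
                delta_over_kt = model(x)                     # model output
                prediction = observation_model(delta_over_kt)  # map to measurement type
                loss = loss_function(prediction, y)
                loss.backward()
                optimizer.step()
                cumulative_loss += loss.item()
        loss_timeseries.append(cumulative_loss)  # one summed loss per epoch
    return model, loss_timeseries


# Usage with synthetic data: two hypothetical measurement types sharing one model.
torch.manual_seed(0)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(64, 4)
target = x.sum(dim=1, keepdim=True)  # synthetic delta_over_kt
dataloaders = {
    "measurement_a": DataLoader(TensorDataset(x, target), batch_size=16),
    "measurement_b": DataLoader(TensorDataset(x, 2.0 * target), batch_size=16),
}
observation_models = {
    "measurement_a": lambda d: d,        # identity observation model
    "measurement_b": lambda d: 2.0 * d,  # hypothetical linear observation model
}
trained, losses = multi_measurement_training_loop_sketch(
    dataloaders, observation_models, model, optimizer, torch.nn.MSELoss(), epochs=20
)
```

Because a single model feeds every observation model, each measurement type contributes gradients to the same shared parameters.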
- class kinoml.ml.torch_loops.EarlyStopping(patience=5, min_delta=0)
Stop training early when the validation loss has not improved by more than min_delta for patience consecutive epochs.
Taken from https://debuggercafe.com/using-learning-rate-scheduler-and-early-stopping-with-pytorch/
- __call__(val_loss)
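A minimal sketch of this early-stopping logic, assuming it tracks the best loss seen so far and raises a flag once patience non-improving epochs have accumulated. The class name EarlyStoppingSketch and the best_loss, counter and early_stop attributes are assumptions for illustration; kinoml's documented interface is only the constructor and __call__(val_loss).

```python
class EarlyStoppingSketch:
    """Hypothetical stand-in for kinoml.ml.torch_loops.EarlyStopping."""

    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None
        self.counter = 0         # epochs since the last sufficient improvement
        self.early_stop = False  # assumed flag the training loop polls

    def __call__(self, val_loss):
        if self.best_loss is None or self.best_loss - val_loss > self.min_delta:
            # First call, or the loss improved by more than min_delta: reset.
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


# Usage: improvements of only 0.005 and 0.008 count as "no improvement".
stopper = EarlyStoppingSketch(patience=2, min_delta=0.01)
stopped_at = None
for epoch, val_loss in enumerate([1.0, 0.8, 0.795, 0.792, 0.7]):
    stopper(val_loss)
    if stopper.early_stop:
        stopped_at = epoch
        break
```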
- class kinoml.ml.torch_loops.LRScheduler(optimizer, patience=5, min_lr=1e-06, factor=0.5)
Learning rate scheduler. If the validation loss does not decrease for patience consecutive epochs, the learning rate is multiplied by factor, but never drops below min_lr.
Taken from https://debuggercafe.com/using-learning-rate-scheduler-and-early-stopping-with-pytorch/
- __call__(val_loss)
- kinoml.ml.torch_loops._old_training_loop()
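A sketch of such a scheduler, assuming it simply wraps PyTorch's torch.optim.lr_scheduler.ReduceLROnPlateau in "min" mode (an assumption; check the kinoml source for the actual implementation). The class name LRSchedulerSketch and its lr_scheduler attribute are hypothetical.

```python
import torch


class LRSchedulerSketch:
    """Hypothetical stand-in, assumed to delegate to ReduceLROnPlateau."""

    def __init__(self, optimizer, patience=5, min_lr=1e-6, factor=0.5):
        self.optimizer = optimizer
        # "min" mode: a lower validation loss counts as an improvement.
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", patience=patience, factor=factor, min_lr=min_lr
        )

    def __call__(self, val_loss):
        self.lr_scheduler.step(val_loss)


# Usage: a flat validation loss eventually triggers one halving of the LR.
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([param], lr=0.1)
scheduler = LRSchedulerSketch(optimizer, patience=1, factor=0.5)
for val_loss in [1.0, 1.0, 1.0]:
    scheduler(val_loss)
```

Note that ReduceLROnPlateau reduces the rate only once the number of non-improving epochs exceeds patience, so with patience=1 the first reduction happens on the second flat epoch after the best one.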
Deprecated