multiml.agent.keras.callback module
- class multiml.agent.keras.callback.AlphaDumperCallback
Bases: Callback
Dump alpha values in DARTS training.
- alpha_history
list of alpha values on each epoch
- model
Instance of keras.models.Model. Reference to the model being trained (member variable of the Callback class).
- __init__()
- static formatting(var)
Format tensor for display.
- Parameters:
var (Tensor) – Tensor of alpha values to format.
- Returns:
formatted alpha values
- Return type:
str
- on_train_begin(logs=None)
Called at the beginning of training.
Subclasses should override for any actions to run.
- Parameters:
logs – Dict. Currently no data is passed to this argument for this method, but that may change in the future.
- on_epoch_end(epoch, logs=None)
Called at the end of an epoch.
Subclasses should override for any actions to run. This function should only be called during TRAIN mode.
- Parameters:
epoch – Integer, index of epoch.
logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For the training epoch, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'acc': 0.7}.
- on_train_end(logs=None)
Called at the end of training.
Subclasses should override for any actions to run.
- Parameters:
logs – Dict. Currently the output of the last call to on_epoch_end() is passed to this argument for this method, but that may change in the future.
- get_alpha_history()
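The hooks above can be illustrated with a minimal sketch of an alpha-recording callback. This is not the multiml implementation: the `Callback` stand-in, the `ToyModel` and `AlphaRecorder` classes, the `alphas` attribute on the model, and the output format of `formatting` are all assumptions, chosen so the example runs without Keras or TensorFlow. In real use the class would subclass `keras.callbacks.Callback`, and the alpha values would come from the model being trained.

```python
class Callback:
    """Stand-in for keras.callbacks.Callback (assumption, keeps the sketch dependency-free)."""

    def __init__(self):
        self.model = None

    def set_model(self, model):
        self.model = model


class ToyModel:
    """Hypothetical model exposing DARTS alphas as nested lists of floats."""

    def __init__(self):
        self.alphas = [[0.5, 0.5], [0.3, 0.7]]


class AlphaRecorder(Callback):
    """Illustrative callback that stores a snapshot of the alphas after each epoch."""

    def __init__(self):
        super().__init__()
        self.alpha_history = []

    @staticmethod
    def formatting(var):
        # Assumed display format: three decimals, comma separated.
        return "[" + ", ".join(f"{v:.3f}" for v in var) + "]"

    def on_train_begin(self, logs=None):
        # Start from a clean history each time training starts.
        self.alpha_history = []

    def on_epoch_end(self, epoch, logs=None):
        # Copy the values so later updates to the model do not alter the history.
        snapshot = [list(row) for row in self.model.alphas]
        self.alpha_history.append(snapshot)

    def on_train_end(self, logs=None):
        # Print one line per recorded epoch.
        for epoch, snapshot in enumerate(self.alpha_history):
            rows = " ".join(self.formatting(row) for row in snapshot)
            print(f"epoch {epoch}: {rows}")

    def get_alpha_history(self):
        return self.alpha_history


# Drive the hooks by hand in the order Keras calls them during fit().
model = ToyModel()
recorder = AlphaRecorder()
recorder.set_model(model)
recorder.on_train_begin()
for epoch in range(2):
    model.alphas[0][0] += 0.1  # pretend the optimizer updated an alpha
    recorder.on_epoch_end(epoch, logs={"loss": 0.2})
recorder.on_train_end()
history = recorder.get_alpha_history()
```

Copying the values in `on_epoch_end` is a deliberate choice in this sketch: appending a reference to the live values would make every history entry show the final state.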
- class multiml.agent.keras.callback.EachLossDumperCallback
Bases: Callback
Dump each loss value in DARTS training.
- loss_history
list of loss values on each epoch
- model
Instance of keras.models.Model. Reference to the model being trained (member variable of the Callback class).
- __init__()
- on_epoch_end(epoch, logs=None)
Called at the end of an epoch.
Subclasses should override for any actions to run. This function should only be called during TRAIN mode.
- Parameters:
epoch – Integer, index of epoch.
logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For the training epoch, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'acc': 0.7}.
- get_loss_history()
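As a rough illustration of per-epoch loss recording, the sketch below keeps every entry of `logs` whose key contains "loss". The `Callback` stand-in, the `LossRecorder` class, the key-matching rule, and the `task_a_loss` key are assumptions for illustration only; the source does not state how EachLossDumperCallback selects or stores its values.

```python
class Callback:
    """Stand-in for keras.callbacks.Callback (assumption, keeps the sketch dependency-free)."""

    def __init__(self):
        self.model = None


class LossRecorder(Callback):
    """Illustrative callback that stores the loss entries of `logs` after each epoch."""

    def __init__(self):
        super().__init__()
        self.loss_history = []

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Keep only entries whose key mentions "loss" (assumed naming convention).
        losses = {key: float(value) for key, value in logs.items() if "loss" in key}
        self.loss_history.append(losses)

    def get_loss_history(self):
        return self.loss_history


# Simulate two epochs with hypothetical metric names.
recorder = LossRecorder()
recorder.on_epoch_end(0, logs={"loss": 0.9, "task_a_loss": 0.5, "val_loss": 1.0, "acc": 0.6})
recorder.on_epoch_end(1, logs={"loss": 0.7, "task_a_loss": 0.4, "val_loss": 0.8, "acc": 0.7})
loss_history = recorder.get_loss_history()
```

Each history entry is a dict, so validation losses (prefixed with val_) stay distinguishable from training losses.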
- class multiml.agent.keras.callback.NaNKillerCallback
Bases: Callback
Stop training when NaN is found in alphas.
- on_epoch_end(epoch, logs=None)
Called at the end of an epoch.
Subclasses should override for any actions to run. This function should only be called during TRAIN mode.
- Parameters:
epoch – Integer, index of epoch.
logs – Dict, metric results for this training epoch, and for the validation epoch if validation is performed. Validation result keys are prefixed with val_. For the training epoch, the values of the Model's metrics are returned. Example: {'loss': 0.2, 'acc': 0.7}.
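One common way to stop Keras training from inside a callback is to set `stop_training` on the model, which the training loop checks after each epoch. The sketch below applies that idea to a NaN check on alphas. The `Callback` stand-in, the `ToyModel` and `NaNStopper` classes, and the flat `alphas` list are assumptions so that the example runs without Keras; how NaNKillerCallback actually reads the alphas is not stated in the source.

```python
import math


class Callback:
    """Stand-in for keras.callbacks.Callback (assumption, keeps the sketch dependency-free)."""

    def __init__(self):
        self.model = None

    def set_model(self, model):
        self.model = model


class ToyModel:
    """Hypothetical model with a flat list of alphas and a Keras-style stop flag."""

    def __init__(self, alphas):
        self.alphas = alphas
        self.stop_training = False


class NaNStopper(Callback):
    """Illustrative callback that requests a stop when any alpha is NaN."""

    def on_epoch_end(self, epoch, logs=None):
        if any(math.isnan(alpha) for alpha in self.model.alphas):
            print(f"NaN found in alphas at epoch {epoch}, stopping training")
            self.model.stop_training = True


# A model with finite alphas keeps training.
healthy = ToyModel([0.2, 0.8])
stopper = NaNStopper()
stopper.set_model(healthy)
stopper.on_epoch_end(0)

# A model with a NaN alpha is flagged to stop.
broken = ToyModel([0.2, float("nan")])
stopper.set_model(broken)
stopper.on_epoch_end(1)
```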