multiml.task.keras.keras_darts module
- class multiml.task.keras.keras_darts.DARTSTask(optimizer_alpha, optimizer_weight, learning_rate_alpha, learning_rate_weight, zeta=0.01, **kwargs)
Bases: ModelConnectionTask
- __init__(optimizer_alpha, optimizer_weight, learning_rate_alpha, learning_rate_weight, zeta=0.01, **kwargs)
- Parameters:
optimizer_alpha (str) – optimizer for the architecture parameters (alpha) in DARTS optimization
optimizer_weight (str) – optimizer for the model weights in DARTS optimization
learning_rate_alpha (float) – learning rate (epsilon) for alpha in DARTS optimization
learning_rate_weight (float) – learning rate (epsilon) for the weights in DARTS optimization
zeta (float) – zeta parameter in DARTS optimization (default: 0.01)
**kwargs – Arbitrary keyword arguments
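The docstring does not say what zeta controls. Assuming it plays the role of the virtual-step size ξ in the second-order DARTS approximation (an assumption, not confirmed by this page), the alpha gradient is grad_alpha L_val(w - zeta * grad_w L_train(w, alpha), alpha). The sketch below computes it on hypothetical one-dimensional toy losses, approximating the Hessian-vector product by finite differences as in the DARTS paper. With zeta = 0 it reduces to the first-order gradient. None of these function names come from multiml.

```python
# Hypothetical 1-D toy losses, chosen only to illustrate the update:
#   L_train(w, a) = (w - a)^2
#   L_val(w, a)   = (w - 1)^2 + (a - 2)^2


def loss_val(w: float, a: float) -> float:
    return (w - 1.0) ** 2 + (a - 2.0) ** 2


def grad_w_train(w: float, a: float) -> float:
    return 2.0 * (w - a)


def grad_a_train(w: float, a: float) -> float:
    return -2.0 * (w - a)


def grad_w_val(w: float, a: float) -> float:
    return 2.0 * (w - 1.0)


def grad_a_val(w: float, a: float) -> float:
    return 2.0 * (a - 2.0)


def unrolled_alpha_grad(w: float, a: float, zeta: float) -> float:
    """Second-order DARTS gradient w.r.t. alpha, ASSUMING zeta is the virtual step size."""
    # Virtual training step: w' = w - zeta * dL_train/dw
    w_prime = w - zeta * grad_w_train(w, a)
    g_w = grad_w_val(w_prime, a)  # dL_val/dw evaluated at w'
    g_a = grad_a_val(w_prime, a)  # direct dL_val/da evaluated at w'
    if zeta == 0.0 or g_w == 0.0:
        return g_a  # first-order approximation: no Hessian term
    # Finite-difference Hessian-vector product with eps = 0.01 / |g_w|
    eps = 0.01 / abs(g_w)
    hvp = (grad_a_train(w + eps * g_w, a) - grad_a_train(w - eps * g_w, a)) / (2.0 * eps)
    return g_a - zeta * hvp


def unrolled_val_loss(w: float, a: float, zeta: float) -> float:
    """L_val after one virtual training step; its exact derivative in a is the target."""
    return loss_val(w - zeta * grad_w_train(w, a), a)
```

Because both toy losses are quadratic, the finite-difference term is exact here, so the result can be checked against a numerical derivative of unrolled_val_loss.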
- fit(train_data=None, valid_data=None)
Train the model.
- Returns:
training results.
- Return type:
dict
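A minimal sketch of the alternating scheme a DARTS training loop typically follows, assuming first-order updates with plain SGD. On each iteration, alpha takes a step on a validation batch with learning_rate_alpha, and then the weights take a step on a training batch with learning_rate_weight. The function first_order_darts_fit, the gradient callables, and the batch format are hypothetical and not part of multiml. Plain SGD stands in for whatever optimizer_alpha and optimizer_weight select.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Batch = Tuple[float, float]  # hypothetical (input, target) pair
GradFn = Callable[[Vector, Vector, Batch], Vector]


def sgd_step(params: Vector, grads: Vector, lr: float) -> Vector:
    """Plain SGD update: p <- p - lr * g."""
    return [p - lr * g for p, g in zip(params, grads)]


def first_order_darts_fit(
    weight: Vector,
    alpha: Vector,
    train_batches: Sequence[Batch],
    valid_batches: Sequence[Batch],
    grad_weight: GradFn,
    grad_alpha: GradFn,
    learning_rate_weight: float,
    learning_rate_alpha: float,
) -> Tuple[Vector, Vector]:
    for train_batch, valid_batch in zip(train_batches, valid_batches):
        # 1) Architecture step: reduce the *validation* loss w.r.t. alpha,
        #    holding the current weights fixed (first-order approximation).
        alpha = sgd_step(alpha, grad_alpha(weight, alpha, valid_batch), learning_rate_alpha)
        # 2) Weight step: reduce the *training* loss w.r.t. the weights,
        #    using the alpha just updated above.
        weight = sgd_step(weight, grad_weight(weight, alpha, train_batch), learning_rate_weight)
    return weight, alpha
```

The order matters. The weight step sees the alpha that was updated earlier in the same iteration, and the architecture parameters are never fitted on training data.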
- load_metadata()
Load metadata.
- get_best_submodels()
Return the indices of the best submodels, as determined by the architecture parameters (alpha).
- Returns:
list of indices of the selected submodels
- Return type:
list (int)
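A minimal sketch of how "best submodels determined by the alpha" is usually resolved in DARTS. It assumes that alpha holds one vector per subtask, with one entry per candidate submodel, and that the winner is the candidate with the largest softmax weight. The page does not state this layout, and best_submodel_indices is a hypothetical name.

```python
import math
from typing import List, Sequence


def softmax(xs: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def best_submodel_indices(alphas: Sequence[Sequence[float]]) -> List[int]:
    """For each subtask, return the index of the candidate with the largest weight."""
    best = []
    for alpha in alphas:
        weights = softmax(alpha)
        best.append(max(range(len(weights)), key=weights.__getitem__))
    return best
```

Because softmax is monotonic, taking the argmax of the raw alpha gives the same indices. The softmax is kept only to mirror how the mixing weights are formed during the search.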
- build_model()
Build the model.
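DARTS relaxes the discrete choice among candidate submodels into a softmax-weighted sum controlled by alpha. The sketch below shows the forward pass of one such mixed stage. It assumes build_model assembles a mixture of this kind for each subtask, which this page does not say explicitly. mixed_output and the candidate callables are hypothetical.

```python
import math
from typing import Callable, List, Sequence


def softmax(xs: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def mixed_output(
    x: float,
    candidates: Sequence[Callable[[float], float]],
    alpha: Sequence[float],
) -> float:
    """Continuous relaxation: sum_k softmax(alpha)_k * candidate_k(x)."""
    weights = softmax(alpha)
    return sum(w * f(x) for w, f in zip(weights, candidates))
```

With equal alpha values every candidate contributes equally. As training pushes one alpha entry up, the mixture approaches that single submodel, which is the one a selection step would later keep.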
- dump_model(extra_args=None)
Dump the current DARTS model.