multiml.task.keras.keras_darts module

class multiml.task.keras.keras_darts.DARTSTask(optimizer_alpha, optimizer_weight, learning_rate_alpha, learning_rate_weight, zeta=0.01, **kwargs)

Bases: ModelConnectionTask

__init__(optimizer_alpha, optimizer_weight, learning_rate_alpha, learning_rate_weight, zeta=0.01, **kwargs)
Parameters:
  • optimizer_alpha (str) – optimizer for the architecture parameters (alpha) in DARTS optimization

  • optimizer_weight (str) – optimizer for the model weights in DARTS optimization

  • learning_rate_alpha (float) – learning rate (epsilon) for alpha in DARTS optimization

  • learning_rate_weight (float) – learning rate (epsilon) for the weights in DARTS optimization

  • zeta (float) – zeta parameter in DARTS optimization (default: 0.01)

  • **kwargs – Arbitrary keyword arguments
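For reference, the DARTS paper (Liu, Simonyan and Yang, "DARTS: Differentiable Architecture Search") poses the search as a bilevel problem and alternates the two updates below, written here with plain gradient descent. The mapping of symbols to this class's arguments is an assumption that this page does not confirm: $\eta_\alpha$ and $\eta_w$ stand for learning_rate_alpha and learning_rate_weight, the configured optimizers (optimizer_alpha, optimizer_weight) may replace plain gradient descent, and the virtual-step size $\xi$ may correspond to zeta. Setting $\xi = 0$ gives the first-order approximation.

```latex
\min_{\alpha}\; \mathcal{L}_{\mathrm{val}}\big(w^{*}(\alpha), \alpha\big)
\quad \text{s.t.} \quad
w^{*}(\alpha) = \arg\min_{w}\; \mathcal{L}_{\mathrm{train}}(w, \alpha)

\alpha \leftarrow \alpha - \eta_\alpha\, \nabla_\alpha \mathcal{L}_{\mathrm{val}}\big(w - \xi\, \nabla_w \mathcal{L}_{\mathrm{train}}(w, \alpha),\; \alpha\big)

w \leftarrow w - \eta_w\, \nabla_w \mathcal{L}_{\mathrm{train}}(w, \alpha)
```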

fit(train_data=None, valid_data=None)

Train the model.

Returns:

Training results.

Return type:

dict
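A minimal sketch of what a DARTS training loop does, under stated assumptions; it is not multiml's fit implementation. It uses only the Python standard library, two hypothetical toy submodels (f0(x) = w0 * x and f1(x) = w1), plain gradient descent in place of the configured optimizers, and the first-order approximation (no virtual step). Each epoch first updates alpha on the validation loss and then updates the weights on the training loss, and a dict of results is returned, loosely mirroring the documented return type. The names fit_first_order, grads and mixed_predict are hypothetical.

```python
import math


def softmax(logits):
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mixed_predict(x, w, alpha):
    # Continuous relaxation: softmax(alpha)-weighted sum of two toy submodels,
    # f0(x) = w[0] * x and f1(x) = w[1] (a constant).
    p = softmax(alpha)
    return p[0] * w[0] * x + p[1] * w[1]


def grads(data, w, alpha):
    # Mean-squared-error loss and its analytic gradients w.r.t. w and alpha.
    p = softmax(alpha)
    n = len(data)
    loss = 0.0
    g_w = [0.0, 0.0]
    g_p = [0.0, 0.0]  # dL/dp_i for the mixing weights p = softmax(alpha)
    for x, y in data:
        r = mixed_predict(x, w, alpha) - y
        loss += r * r / n
        g_w[0] += 2 * r * p[0] * x / n
        g_w[1] += 2 * r * p[1] / n
        g_p[0] += 2 * r * w[0] * x / n
        g_p[1] += 2 * r * w[1] / n
    # Chain rule through softmax: dL/dalpha_j = p_j * (dL/dp_j - sum_i p_i * dL/dp_i)
    mean_gp = sum(pi * gi for pi, gi in zip(p, g_p))
    g_alpha = [pj * (gj - mean_gp) for pj, gj in zip(p, g_p)]
    return loss, g_w, g_alpha


def fit_first_order(train_data, valid_data, learning_rate_alpha=0.05,
                    learning_rate_weight=0.1, epochs=100):
    w = [0.0, 0.0]
    alpha = [0.0, 0.0]
    history = {"train_loss": [], "valid_loss": []}
    for _ in range(epochs):
        # 1) Architecture step: descend the validation loss w.r.t. alpha.
        valid_loss, _, g_alpha = grads(valid_data, w, alpha)
        alpha = [a - learning_rate_alpha * g for a, g in zip(alpha, g_alpha)]
        # 2) Weight step: descend the training loss w.r.t. w.
        train_loss, g_w, _ = grads(train_data, w, alpha)
        w = [wi - learning_rate_weight * g for wi, g in zip(w, g_w)]
        history["train_loss"].append(train_loss)
        history["valid_loss"].append(valid_loss)
    return {"weights": w, "alpha": alpha, "history": history}


# Toy data drawn from y = 2x, which only the linear submodel f0 can fit.
train_data = [(x, 2.0 * x) for x in (-1.0, -0.5, 0.5, 1.0)]
valid_data = [(x, 2.0 * x) for x in (-0.75, 0.25, 0.75)]
result = fit_first_order(train_data, valid_data)
print(result["alpha"])  # alpha[0] should end up larger than alpha[1]
```

Alternating single steps, rather than training the weights to convergence before each alpha update, is the approximation DARTS relies on to keep the bilevel search affordable.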

load_metadata()

Load metadata.

get_best_submodels()

Return the indices of the best submodels, as determined by the architecture parameters (alpha).

Returns:

List of the indices of the selected submodels.

Return type:

list (int)
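Because softmax is monotonic, picking the submodel with the largest softmax(alpha) weight is the same as picking the largest raw alpha. A minimal sketch of that selection, assuming (hypothetically) that alpha is stored as one list of logits per task; get_best_submodels_sketch is an illustrative name, not part of multiml.

```python
from typing import List, Sequence


def get_best_submodels_sketch(alphas: Sequence[Sequence[float]]) -> List[int]:
    # alphas: hypothetical layout, one list of alpha logits per task.
    # argmax over raw alpha equals argmax over softmax(alpha); ties pick the first index.
    return [max(range(len(a)), key=lambda i: a[i]) for a in alphas]


print(get_best_submodels_sketch([[0.1, 1.3, -0.2], [2.0, 0.5]]))  # [1, 0]
```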

build_model()

Build the model.
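In DARTS, the discrete choice among candidate submodels is relaxed into a softmax(alpha)-weighted sum of their outputs, which makes alpha differentiable. Whether build_model assembles exactly this structure is an assumption; the standard-library sketch below only illustrates the relaxation with hypothetical scalar submodels and a hypothetical build_mixed_model helper.

```python
import math
from typing import Callable, Sequence


def softmax(logits: Sequence[float]) -> list:
    # Numerically stable softmax over a list of floats.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def build_mixed_model(submodels: Sequence[Callable[[float], float]],
                      alpha: Sequence[float]) -> Callable[[float], float]:
    # Return a function computing sum_i softmax(alpha)_i * submodel_i(x).
    def mixed(x: float) -> float:
        weights = softmax(alpha)
        return sum(wi * f(x) for wi, f in zip(weights, submodels))
    return mixed


model = build_mixed_model([lambda x: 2 * x, lambda x: x + 1], alpha=[0.0, 0.0])
print(model(3.0))  # 0.5 * 6.0 + 0.5 * 4.0 = 5.0
```

As one alpha entry grows relative to the others, the mixture approaches that single submodel, which is what makes the later argmax selection meaningful.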

dump_model(extra_args=None)

Dump the current DARTS model.