multiml.task.keras.KerasBaseTask

class multiml.task.keras.KerasBaseTask(run_eagerly=None, callbacks=['EarlyStopping', 'ModelCheckpoint'], save_tensorboard=False, **kwargs)

Base task for Keras models.

Examples

>>> from tensorflow.keras import Model
>>> from tensorflow.keras.layers import Dense, ReLU
>>> from multiml.task.keras import KerasBaseTask
>>>
>>> # your keras model
>>> class MyKerasModel(Model):
...     def __init__(self, units=1):
...         super().__init__()
...         self.dense = Dense(units)
...         self.relu = ReLU()
...
...     def call(self, x):
...         return self.relu(self.dense(x))
...
>>> # create task instance; `storegate` is an existing StoreGate holding x0, x1 and labels
>>> task = KerasBaseTask(storegate=storegate,
...                      model=MyKerasModel,
...                      input_var_names=('x0', 'x1'),
...                      output_var_names='outputs-keras',
...                      true_var_names='labels',
...                      optimizer='adam',
...                      optimizer_args=dict(learning_rate=0.1),
...                      loss='binary_crossentropy')
>>> task.set_hps({'num_epochs': 5})
>>> task.execute()
>>> task.finalize()
__init__(run_eagerly=None, callbacks=['EarlyStopping', 'ModelCheckpoint'], save_tensorboard=False, **kwargs)
Parameters:
  • run_eagerly (bool) – Run in eager execution mode (not graph mode).

  • callbacks (list(str or keras.Callback)) – Callbacks for Keras model training. The predefined callbacks (EarlyStopping, ModelCheckpoint, and TensorBoard) can be selected by name as a str; user-defined callbacks should be given as keras.Callback objects.

  • save_tensorboard (bool) – Use the TensorBoard callback in training.

  • **kwargs – Arbitrary keyword arguments.
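The callbacks parameter accepts a mix of predefined names and callback objects. The sketch below is a hypothetical illustration of how such a mixed list could be normalized; it is not the multiml implementation. To run without TensorFlow it uses pure-Python stand-ins for keras.callbacks.Callback, EarlyStopping, ModelCheckpoint and TensorBoard, and the helper name resolve_callbacks, the PREDEFINED registry, and the rule that save_tensorboard appends a TensorBoard callback when none is present are all assumptions.

```python
# Hypothetical sketch: normalize a mixed list of callback names and callback
# objects, as the `callbacks` parameter describes. Not multiml's code.


class Callback:
    """Stand-in for keras.callbacks.Callback (assumption for this sketch)."""


class EarlyStopping(Callback):
    """Stand-in for keras.callbacks.EarlyStopping."""


class ModelCheckpoint(Callback):
    """Stand-in for keras.callbacks.ModelCheckpoint."""


class TensorBoard(Callback):
    """Stand-in for keras.callbacks.TensorBoard."""


# Predefined callbacks that may be selected by name (hypothetical registry).
PREDEFINED = {
    'EarlyStopping': EarlyStopping,
    'ModelCheckpoint': ModelCheckpoint,
    'TensorBoard': TensorBoard,
}


def resolve_callbacks(callbacks, save_tensorboard=False):
    """Return a list of Callback objects built from names or objects."""
    resolved = []
    for cb in callbacks:
        if isinstance(cb, str):
            if cb not in PREDEFINED:
                raise ValueError(f'unknown callback name: {cb!r}')
            # Instantiate with defaults; the real Keras ModelCheckpoint
            # would additionally need a filepath argument.
            resolved.append(PREDEFINED[cb]())
        elif isinstance(cb, Callback):
            resolved.append(cb)  # user-defined object is used as-is
        else:
            raise TypeError(f'expected str or Callback, got {type(cb).__name__}')
    # Assumption: save_tensorboard appends TensorBoard unless already present.
    if save_tensorboard and not any(isinstance(cb, TensorBoard) for cb in resolved):
        resolved.append(TensorBoard())
    return resolved


class PrintEpoch(Callback):
    """Hypothetical user-defined callback."""


names = [type(cb).__name__
         for cb in resolve_callbacks(['EarlyStopping', PrintEpoch()],
                                     save_tensorboard=True)]
print(names)  # ['EarlyStopping', 'PrintEpoch', 'TensorBoard']
```

Keeping user-supplied objects untouched while building named ones from a registry lets the default list stay short (['EarlyStopping', 'ModelCheckpoint']) without preventing custom callbacks.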

Methods

__init__([run_eagerly, callbacks, ...])

Initialize the task with the parameters listed above.

build_model()

Build model.

compile()

Compile model, optimizer and loss.

compile_loss()

Compile loss.

compile_model()

Compile keras model.

compile_optimizer()

Compile optimizer.

compile_var_names()

Compile var_names.

do_test()

Whether to perform the test phase.

do_train()

Whether to perform the train phase.

do_valid()

Whether to perform the valid phase.

dump_model([extra_args])

Dump the current Keras model.

execute()

Execute a task.

finalize()

Finalize base task.

fit([train_data, valid_data])

Train the model.

fit_predict([fit_args, predict_args])

Fit the model, then predict with it.

get_input_true_data(phase)

Get input and true data for the given phase.

get_input_var_shapes([phase])

Get the shapes of input_var_names.

get_inputs()

Returns Keras Input objects built from input_var_names.

get_metadata(metadata_key)

Returns metadata.

get_pred_index()

Returns prediction index passed to loss calculation.

get_unique_id()

Returns unique identifier of task.

load_metadata()

Load metadata.

load_model()

Load pre-trained keras model weights.

predict([data, phase])

Compute model predictions.

predict_update([data, phase])

Predict and update data in StoreGate.

set_hps(params)

Set hyperparameters of this task.

show_info()

Print information.

update(data[, phase])

Update data in StoreGate.
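The methods above suggest a lifecycle of set_hps(), execute() and finalize(), with do_train(), do_valid() and do_test() gating the phases. The following is a hypothetical, self-contained stand-in (SketchTask) that mirrors those method names so the flow can be run without Keras or StoreGate. Everything in the method bodies is an assumption: that phases is a tuple of phase names, that set_hps() only accepts existing attributes, and that execute() compiles, trains when do_train() is true, then calls predict_update() for each enabled evaluation phase. The real KerasBaseTask may order or combine these steps differently.

```python
# Hypothetical stand-in mimicking the documented KerasBaseTask method flow.
# Method names mirror the reference; the bodies are assumptions, not multiml code.


class SketchTask:
    def __init__(self, phases=('train', 'valid', 'test')):
        self._phases = tuple(phases)  # assumption: phases are plain names
        self.num_epochs = 10          # hypothetical default hyperparameter
        self.log = []                 # records which steps ran

    @property
    def phases(self):
        return self._phases

    def set_hps(self, params):
        # Assumption: each key must match an existing attribute.
        for key, value in params.items():
            if not hasattr(self, key):
                raise AttributeError(f'unknown hyperparameter: {key}')
            setattr(self, key, value)

    def do_train(self):
        return 'train' in self.phases

    def do_valid(self):
        return 'valid' in self.phases

    def do_test(self):
        return 'test' in self.phases

    def compile(self):
        self.log.append('compile')  # would compile model, optimizer and loss

    def fit(self):
        self.log.append(f'fit:{self.num_epochs}')

    def predict_update(self, phase):
        self.log.append(f'predict_update:{phase}')

    def execute(self):
        self.compile()
        if self.do_train():
            self.fit()
        # Evaluate only the phases that are enabled.
        for phase, enabled in (('valid', self.do_valid()), ('test', self.do_test())):
            if enabled:
                self.predict_update(phase=phase)

    def finalize(self):
        self.log.append('finalize')


task = SketchTask(phases=('train', 'test'))
task.set_hps({'num_epochs': 5})
task.execute()
task.finalize()
print(task.log)  # ['compile', 'fit:5', 'predict_update:test', 'finalize']
```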

Attributes

input_saver_key

Return input_saver_key.

input_var_names

Returns input_var_names.

job_id

Return job_id of task.

ml

Returns ML data class.

name

Return name of task.

output_saver_key

Return output_saver_key.

output_var_names

Returns output_var_names.

phases

Returns ML phases.

pool_id

Return pool_id of task.

pred_var_names

Returns pred_var_names.

save_var_names

Returns save_var_names.

saver

Return saver of task.

storegate

Return storegate of task.

subtask_id

Return subtask_id of task.

task_id

Return task_id of task.

trial_id

Return trial_id of task.

true_var_names

Returns true_var_names.
