multiml.task.keras.keras_base module

KerasBaseTask module.

class multiml.task.keras.keras_base.KerasBaseTask(run_eagerly=None, callbacks=['EarlyStopping', 'ModelCheckpoint'], save_tensorboard=False, **kwargs)

Bases: MLBaseTask

Base task for Keras models.

Examples

>>> from tensorflow.keras import Model
>>> from tensorflow.keras.layers import Dense, ReLU
>>>
>>> # your Keras model
>>> class MyKerasModel(Model):
>>>     def __init__(self, units=1):
>>>         super().__init__()
>>>
>>>         self.dense = Dense(units)
>>>         self.relu = ReLU()
>>>
>>>     def call(self, x):
>>>         return self.relu(self.dense(x))
>>>
>>> # create task instance
>>> # (storegate: an existing StoreGate holding 'x0', 'x1' and 'labels')
>>> task = KerasBaseTask(storegate=storegate,
>>>                      model=MyKerasModel,
>>>                      input_var_names=('x0', 'x1'),
>>>                      output_var_names='outputs-keras',
>>>                      true_var_names='labels',
>>>                      optimizer='adam',
>>>                      optimizer_args=dict(learning_rate=0.1),
>>>                      loss='binary_crossentropy')
>>> task.set_hps({'num_epochs': 5})
>>> task.execute()
>>> task.finalize()
__init__(run_eagerly=None, callbacks=['EarlyStopping', 'ModelCheckpoint'], save_tensorboard=False, **kwargs)
Parameters:
  • run_eagerly (bool) – If True, run in eager execution mode instead of graph mode.

  • callbacks (list(str or keras.Callback)) – Callbacks for Keras model training. The predefined callbacks (EarlyStopping, ModelCheckpoint, and TensorBoard) can be selected by name as str; any other user-defined callback must be given as a keras.Callback object.

  • save_tensorboard (bool) – If True, use the TensorBoard callback during training.

  • **kwargs – Arbitrary keyword arguments.
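The following is a minimal, framework-free sketch of how a mixed callbacks list could be resolved: strings are looked up in a registry of factories, and anything else is assumed to already be a callback object and is passed through unchanged. The helper resolve_callbacks, its registry argument, and the rule that save_tensorboard=True appends 'TensorBoard' are hypothetical illustrations, not the multiml implementation.

```python
from typing import Any, Callable, Dict, List, Sequence


def resolve_callbacks(
    callbacks: Sequence[Any],
    registry: Dict[str, Callable[[], Any]],
    save_tensorboard: bool = False,
) -> List[Any]:
    """Turn a list of callback names/objects into callback objects.

    Hypothetical helper: strings are looked up in `registry`; any other
    item is assumed to be a ready-made callback (e.g. a keras Callback).
    """
    names = list(callbacks)
    # Assumption: save_tensorboard adds the predefined TensorBoard callback
    # unless it is already requested by name.
    if save_tensorboard and "TensorBoard" not in names:
        names.append("TensorBoard")

    resolved = []
    for item in names:
        if isinstance(item, str):
            if item not in registry:
                raise ValueError(f"unknown predefined callback: {item!r}")
            resolved.append(registry[item]())  # build a fresh instance
        else:
            resolved.append(item)  # user-defined callback object, kept as-is
    return resolved
```

With Keras installed, the registry could map 'EarlyStopping' to a factory such as lambda: keras.callbacks.EarlyStopping(patience=10), and similarly for the other predefined names; a user-defined keras.Callback instance would simply be appended unchanged.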

compile_model()

Compile the Keras model.

compile_loss()

Compile the loss function.

load_model()

Load pre-trained Keras model weights.

dump_model(extra_args=None)

Dump the current Keras model.

fit(train_data=None, valid_data=None)

Train the model.

Returns:

Training results.

Return type:

dict
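The return value of fit() is documented only as a dict. Assuming it resembles Keras's History.history (each metric name mapped to a list of per-epoch values, e.g. 'loss' and 'val_loss'), a sketch of picking the best epoch from it could look like this; best_epoch is a hypothetical helper, not part of multiml.

```python
from typing import Dict, List, Tuple


def best_epoch(history: Dict[str, List[float]], monitor: str = "val_loss",
               mode: str = "min") -> Tuple[int, float]:
    """Return (0-based epoch index, value) of the best monitored value.

    Assumes `history` maps metric names to per-epoch value lists.
    """
    if mode not in ("min", "max"):
        raise ValueError("mode must be 'min' or 'max'")
    values = history[monitor]
    if not values:
        raise ValueError(f"no values recorded for {monitor!r}")
    pick = min if mode == "min" else max
    # Compare epoch indices by their recorded value; ties keep the earliest.
    index = pick(range(len(values)), key=values.__getitem__)
    return index, values[index]
```

Use mode="max" for metrics where larger is better, such as accuracy.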

predict(data=None, phase=None)

Compute the model prediction.

Parameters:

phase (str) – data phase to predict on (train, valid, test, or None).

Returns:

  • prediction by the model (ndarray)

  • target (ndarray)

Return type:

ndarray
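The Returns entry lists both the model prediction and the target as ndarrays. Assuming you hold the two as arrays with the same number of elements, they could be compared for a binary classifier like the one in the class example as follows; binary_accuracy and the 0.5 threshold are hypothetical choices for illustration.

```python
import numpy as np


def binary_accuracy(y_pred: np.ndarray, y_true: np.ndarray,
                    threshold: float = 0.5) -> float:
    """Fraction of samples whose thresholded prediction matches the label."""
    # Flatten so that (N, 1) model outputs line up with (N,) labels.
    y_pred = np.asarray(y_pred).reshape(-1)
    y_true = np.asarray(y_true).reshape(-1)
    if y_pred.shape != y_true.shape:
        raise ValueError("prediction and target must have the same size")
    predicted_labels = (y_pred >= threshold).astype(y_true.dtype)
    return float(np.mean(predicted_labels == y_true))
```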

get_inputs()

Return the Keras Input(s) created from input_var_names.