multiml.task.keras.keras_base module
KerasBaseTask module.
- class multiml.task.keras.keras_base.KerasBaseTask(run_eagerly=None, callbacks=['EarlyStopping', 'ModelCheckpoint'], save_tensorboard=False, **kwargs)
Bases: MLBaseTask
Base task for Keras models.
Examples
>>> from tensorflow.keras import Model
>>> from tensorflow.keras.layers import Dense, ReLU
>>> from multiml.task.keras.keras_base import KerasBaseTask
>>>
>>> # your keras model
>>> class MyKerasModel(Model):
...     def __init__(self, units=1):
...         super(MyKerasModel, self).__init__()
...
...         self.dense = Dense(units)
...         self.relu = ReLU()
...
...     def call(self, x):
...         return self.relu(self.dense(x))
>>>
>>> # create task instance (storegate: an existing StoreGate instance)
>>> task = KerasBaseTask(storegate=storegate,
...                      model=MyKerasModel,
...                      input_var_names=('x0', 'x1'),
...                      output_var_names='outputs-keras',
...                      true_var_names='labels',
...                      optimizer='adam',
...                      optimizer_args=dict(lr=0.1),
...                      loss='binary_crossentropy')
>>> task.set_hps({'num_epochs': 5})
>>> task.execute()
>>> task.finalize()
- __init__(run_eagerly=None, callbacks=['EarlyStopping', 'ModelCheckpoint'], save_tensorboard=False, **kwargs)
- Parameters:
run_eagerly (bool) – Run in eager execution mode instead of graph mode.
callbacks (list(str or keras.Callback)) – Callbacks for Keras model training. The predefined callbacks (EarlyStopping, ModelCheckpoint, and TensorBoard) can be selected by name as a str; other, user-defined callbacks should be given as keras.Callback objects.
save_tensorboard (bool) – If True, use the TensorBoard callback during training.
**kwargs – Arbitrary keyword arguments.
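The callbacks parameter mixes two kinds of entries: names of predefined callbacks and ready-made callback objects. The sketch below shows one way such a list can be resolved; it is not the multiml implementation. The helper `resolve_callbacks` and the name-to-factory `registry` are hypothetical, and treating `save_tensorboard=True` as "append 'TensorBoard' if not already listed" is an assumption.

```python
from typing import Any, Callable, Dict, List, Sequence, Union

# Hypothetical registry: maps a predefined callback name to a
# zero-argument factory that builds the corresponding callback object.
Registry = Dict[str, Callable[[], Any]]


def resolve_callbacks(
    callbacks: Sequence[Union[str, Any]],
    registry: Registry,
    save_tensorboard: bool = False,
) -> List[Any]:
    """Turn a mixed list of callback names and objects into objects.

    Strings are looked up in `registry`; anything else is assumed to be
    a user-defined callback object and is passed through unchanged.
    """
    names = list(callbacks)
    # Assumption: save_tensorboard=True acts like listing 'TensorBoard',
    # without adding it a second time.
    if save_tensorboard and 'TensorBoard' not in names:
        names.append('TensorBoard')

    resolved: List[Any] = []
    for cb in names:
        if isinstance(cb, str):
            if cb not in registry:
                raise ValueError(f'unknown predefined callback: {cb!r}')
            resolved.append(registry[cb]())
        else:
            resolved.append(cb)
    return resolved
```

With TensorFlow installed, the registry values could wrap real classes such as `tf.keras.callbacks.EarlyStopping`, `tf.keras.callbacks.ModelCheckpoint(filepath=...)` and `tf.keras.callbacks.TensorBoard(log_dir=...)`; using factories rather than instances keeps each task from sharing callback state.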
- compile_model()
Compile keras model.
- compile_loss()
Compile the loss function for the Keras model.
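The class example compiles with `loss='binary_crossentropy'`. As a reference for what that loss computes, here is a plain-Python sketch of mean binary cross-entropy over samples with one output each. It assumes predictions are probabilities (the Keras default `from_logits=False`) and that clipping uses Keras' default backend epsilon of 1e-7; it is an illustration, not the code path `compile_loss` runs.

```python
import math
from typing import Sequence

# Assumption: Keras' default backend epsilon is used for clipping.
EPSILON = 1e-7


def binary_crossentropy(
    y_true: Sequence[float], y_pred: Sequence[float], eps: float = EPSILON
) -> float:
    """Mean binary cross-entropy over samples with one output each."""
    if len(y_true) != len(y_pred) or len(y_true) == 0:
        raise ValueError('y_true and y_pred must be non-empty and equal in length')
    total = 0.0
    for t, p in zip(y_true, y_pred):
        # Clip so that log() never receives exactly 0.
        p = min(max(p, eps), 1.0 - eps)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(y_true)
```

For example, predicting 0.5 for every sample gives a loss of log 2 ≈ 0.693 regardless of the labels.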
- load_model()
Load pre-trained keras model weights.
- dump_model(extra_args=None)
Dump current keras model.
- fit(train_data=None, valid_data=None)
Train the model.
- Returns:
training results.
- Return type:
dict
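Because EarlyStopping is among the default callbacks, fit may end before the requested number of epochs. The sketch below models the patience rule of `keras.callbacks.EarlyStopping` in "min" mode (for example, monitoring a validation loss): an epoch counts as an improvement only if its loss is below `best - min_delta`, and training stops once `patience` consecutive epochs fail to improve. Which metric and patience multiml actually configures is not stated here, so both are assumptions; the function name is hypothetical.

```python
import math
from typing import Sequence, Tuple


def early_stopping_epochs(
    val_losses: Sequence[float], patience: int, min_delta: float = 0.0
) -> Tuple[int, int]:
    """Return (epochs_run, best_epoch) under a patience-based stopping rule."""
    best = math.inf
    best_epoch = -1
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement: remember it and reset the patience counter.
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                # Training stops after this (0-indexed) epoch.
                return epoch + 1, best_epoch
    return len(val_losses), best_epoch
```

With losses `[1.0, 0.8, 0.9, 0.85, 0.95, 0.7]` and `patience=2`, training stops after 4 epochs and the best epoch is 1, so the later improvement at epoch 5 is never seen. A ModelCheckpoint created with `save_best_only=True` would keep the weights from that best epoch.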
- predict(data=None, phase=None)
Compute predictions with the model.
- Parameters:
phase (str) – data phase to predict on (train, valid, test, or None)
- Returns:
Prediction by the model.
- Return type:
ndarray
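For the `MyKerasModel` in the class example, each row of the predicted ndarray is the output of a Dense layer followed by ReLU, i.e. max(0, x·W + b), with shape (n_samples, units). The plain-Python sketch below computes that forward pass for given weights; the weight values are made up for illustration, and how multiml assembles `x` from several input variables is not covered here.

```python
from typing import List, Sequence


def dense_relu(
    x: Sequence[Sequence[float]],
    weights: Sequence[Sequence[float]],
    bias: Sequence[float],
) -> List[List[float]]:
    """Forward pass of Dense(units) + ReLU.

    x has shape (n_samples, n_features), weights has shape
    (n_features, units), and bias has length units.
    """
    outputs: List[List[float]] = []
    for row in x:
        out: List[float] = []
        for j in range(len(bias)):
            # Affine part: dot product of the sample with column j, plus bias.
            z = sum(xi * weights[i][j] for i, xi in enumerate(row)) + bias[j]
            # ReLU clamps negative activations to zero.
            out.append(max(0.0, z))
        outputs.append(out)
    return outputs
```

With `weights=[[0.5], [0.25]]` and `bias=[0.1]`, the sample `[1.0, 2.0]` maps to about 1.1 and `[-1.0, -2.0]` maps to 0.0.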
- get_inputs()
Return Keras Input objects built from input_var_names.
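A Keras `Input` is declared with a per-sample shape, that is, without the leading batch dimension. The sketch below derives one (name, shape) pair per entry of input_var_names from the full shapes of the stored data. The helper `input_specs`, the `data_shapes` mapping, and the handling of a single str name are assumptions for illustration, not the multiml implementation.

```python
from typing import Dict, List, Sequence, Tuple, Union

Shape = Tuple[int, ...]


def input_specs(
    input_var_names: Union[str, Sequence[str]],
    data_shapes: Dict[str, Shape],
) -> List[Tuple[str, Shape]]:
    """Map each input variable name to its per-sample shape."""
    # Accept either a single name or a sequence of names.
    if isinstance(input_var_names, str):
        names = [input_var_names]
    else:
        names = list(input_var_names)
    specs: List[Tuple[str, Shape]] = []
    for name in names:
        full_shape = data_shapes[name]
        # Drop the leading sample (batch) dimension, as keras Input expects.
        specs.append((name, tuple(full_shape[1:])))
    return specs
```

Each resulting pair could then be passed to `tf.keras.Input(shape=shape, name=name)` to create the corresponding input tensor.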