multiml.task.keras.modules.functional_model module

class multiml.task.keras.modules.functional_model.FunctionalModel(*args, **kwargs)

Bases: Functional

__init__(*args, **kwargs)

Base model that overrides train_step().

TODO: this class exists to avoid mixing the functional API with model subclassing.

set_pred_index(pred_index)
train_step(data)

The logic for one training step.

This method can be overridden to support custom training logic. This method is called by Model.make_train_function.

This method should contain the mathematical logic for one step of training. This typically includes the forward pass, loss calculation, backpropagation, and metric updates.

Configuration details for how this logic is run (e.g. tf.function and tf.distribute.Strategy settings) should be left to Model.make_train_function, which can also be overridden.

Parameters:

data – A nested structure of Tensors.

Returns:

A dict containing values that will be passed to tf.keras.callbacks.CallbackList.on_train_batch_end. Typically, the values of the Model’s metrics are returned. Example: {‘loss’: 0.2, ‘accuracy’: 0.7}.
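The four steps listed above (forward pass, loss calculation, backpropagation, metric updates) can be sketched as a custom train_step(). This is a minimal illustration of the general Keras pattern, not multiml's actual FunctionalModel.train_step. It assumes TensorFlow 2.x with tf.keras, a tf.keras.Model subclass constructed from (inputs, outputs) in the functional style, and a hypothetical class name TrainStepSketch; the mean-squared-error loss, the MAE metric, and the module-level metric trackers are arbitrary choices for the example.

```python
import numpy as np
import tensorflow as tf

# Module-level trackers (an assumption of this sketch). Exposing them through
# the `metrics` property lets fit() reset them at the start of every epoch.
loss_tracker = tf.keras.metrics.Mean(name="loss")
mae_metric = tf.keras.metrics.MeanAbsoluteError(name="mae")
mse = tf.keras.losses.MeanSquaredError()


class TrainStepSketch(tf.keras.Model):
    """Hypothetical Model subclass built with the functional API (inputs, outputs)."""

    def train_step(self, data):
        x, y = data  # assumes fit() received (x, y) and no sample weights
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # forward pass
            loss = mse(y, y_pred)            # loss calculation
        # Backpropagation: compute gradients and let the optimizer apply them.
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        # Metric updates; the returned dict becomes the batch logs for callbacks.
        loss_tracker.update_state(loss)
        mae_metric.update_state(y, y_pred)
        return {"loss": loss_tracker.result(), "mae": mae_metric.result()}

    @property
    def metrics(self):
        # Metrics listed here are reset by fit() before each epoch.
        return [loss_tracker, mae_metric]


inputs = tf.keras.Input(shape=(4,))
outputs = tf.keras.layers.Dense(1)(inputs)
model = TrainStepSketch(inputs, outputs)
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.1))

# Toy linear-regression data.
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4)).astype("float32")
y = x @ rng.normal(size=(4, 1)).astype("float32")
history = model.fit(x, y, batch_size=32, epochs=5, verbose=0)
```

Because this train_step returns a dict with "loss" and "mae", those are the keys passed to on_train_batch_end and recorded in history.history. No loss is passed to compile(), since the loss is computed inside train_step itself.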

test_step(data)

The logic for one evaluation step.

This method can be overridden to support custom evaluation logic. This method is called by Model.make_test_function.

This function should contain the mathematical logic for one step of evaluation. This typically includes the forward pass, loss calculation, and metrics updates.

Configuration details for how this logic is run (e.g. tf.function and tf.distribute.Strategy settings) should be left to Model.make_test_function, which can also be overridden.

Parameters:

data – A nested structure of Tensors.

Returns:

A dict containing values that will be passed to tf.keras.callbacks.CallbackList.on_test_batch_end. Typically, the values of the Model’s metrics are returned.
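An evaluation step follows the same shape as a training step minus the gradient computation and optimizer update. The sketch below is an illustration of the general Keras pattern under stated assumptions, not multiml's implementation: it assumes TensorFlow 2.x with tf.keras, a hypothetical functional-style subclass named TestStepSketch that overrides only test_step(), and module-level metric trackers chosen for the example.

```python
import numpy as np
import tensorflow as tf

# Module-level trackers (an assumption of this sketch); evaluate() resets
# everything returned by the `metrics` property before it starts.
eval_loss = tf.keras.metrics.Mean(name="loss")
eval_mae = tf.keras.metrics.MeanAbsoluteError(name="mae")
mse = tf.keras.losses.MeanSquaredError()


class TestStepSketch(tf.keras.Model):
    """Hypothetical functional-API model that overrides only test_step()."""

    def test_step(self, data):
        x, y = data
        y_pred = self(x, training=False)        # forward pass in inference mode
        eval_loss.update_state(mse(y, y_pred))  # loss calculation
        eval_mae.update_state(y, y_pred)        # metric update
        # No GradientTape and no optimizer call: evaluation never changes weights.
        return {"loss": eval_loss.result(), "mae": eval_mae.result()}

    @property
    def metrics(self):
        return [eval_loss, eval_mae]


inputs = tf.keras.Input(shape=(4,))
# Zero-initialized kernel (bias defaults to zeros), so every prediction is 0.
outputs = tf.keras.layers.Dense(1, kernel_initializer="zeros")(inputs)
model = TestStepSketch(inputs, outputs)
model.compile(optimizer="sgd")  # evaluate() requires a compiled model

x = np.ones((8, 4), dtype="float32")
y = np.full((8, 1), 2.0, dtype="float32")
results = model.evaluate(x, y, batch_size=4, verbose=0, return_dict=True)
```

With every prediction equal to 0 and every target equal to 2.0, the returned dict should be {"loss": 4.0, "mae": 2.0}; the same dict is what on_test_batch_end receives per batch.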

select_pred_data(y_pred)