multiml.task.basic.ml_model_connection module

ModelConnectionTask module.

class multiml.task.basic.ml_model_connection.ModelConnectionTask(subtasks, loss_weights=None, variable_mapping=None, **kwargs)

Bases: MLBaseTask

Build a single task that connects multiple tasks.

ModelConnectionTask connects multiple ML tasks according to their input/output variables and dependencies, then builds a single task from them. The ML models of the component tasks are trained differentiably, so every ML model must be implemented with the same deep learning library, i.e. Keras or PyTorch. Each subtask must contain

  • input_var_names, output_var_names and true_var_names,

  • loss function,

to compile subtask dependencies and data I/O formats. The following example shows a workflow and its attributes, which are compiled automatically:

Examples

>>> '''
>>> (input0, input1, input2)
>>>      |   |        |
>>>   [subtask0]      |
>>>       |           |
>>>   (output0)       |
>>>       |           |
>>>   [subtask1]------+
>>>       |
>>>   (output1)
>>> '''
>>>
>>> input_var_names = ['input0', 'input1', 'input2']
>>> output_var_names = ['output0', 'output1']
>>> input_var_index = [[0, 1], [-1, 2]]
>>> output_var_index = [[0], [1]]
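The index lists above can be reproduced with a small standalone sketch. This is not the multiml implementation: the helper name compile_var_index and the (input_names, output_names) pair representation are hypothetical, and the rule that a negative index -(k + 1) points at entry k of output_var_names is an assumption that merely matches the example shown here.

```python
def compile_var_index(subtasks):
    """Derive connected-task variable names and per-subtask indices.

    subtasks: ordered list of (input_names, output_names) pairs.
    Hypothetical helper for illustration, not part of multiml.
    """
    # Collect every variable produced by some subtask, in order.
    output_var_names = []
    for _, outputs in subtasks:
        for name in outputs:
            if name not in output_var_names:
                output_var_names.append(name)

    # Inputs that no subtask produces must come from outside.
    input_var_names = []
    for inputs, _ in subtasks:
        for name in inputs:
            if name not in output_var_names and name not in input_var_names:
                input_var_names.append(name)

    input_var_index = []
    output_var_index = []
    for inputs, outputs in subtasks:
        # Assumed encoding: external inputs use their position in
        # input_var_names; intermediate variables use -(k + 1), where k
        # is their position in output_var_names.
        input_var_index.append([
            input_var_names.index(name) if name in input_var_names
            else -(output_var_names.index(name) + 1)
            for name in inputs
        ])
        output_var_index.append(
            [output_var_names.index(name) for name in outputs])

    return input_var_names, output_var_names, input_var_index, output_var_index


workflow = [
    (['input0', 'input1'], ['output0']),   # subtask0
    (['output0', 'input2'], ['output1']),  # subtask1
]
names_in, names_out, index_in, index_out = compile_var_index(workflow)
```

The sketch does not check subtask ordering, so a variable consumed before the subtask that produces it would still be treated as intermediate.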

Examples

>>> task = ModelConnectionTask(subtasks=[your_subtask0, your_subtask1],
>>>                            optimizer='SGD')
>>> task.execute()

__init__(subtasks, loss_weights=None, variable_mapping=None, **kwargs)

Constructor of ModelConnectionTask.

Parameters:
  • subtasks (list) – ordered list of instances of classes inherited from MLBaseTask.

  • loss_weights (list or dict or str) – loss weights for each task, given as a list or dict. The strings last_loss and flat_loss are also allowed.

  • variable_mapping (list(str, str)) – list of (str, str) pairs used to replace input variables. Use this when the input variables change between pre-training and main training (with connected models).

  • **kwargs – Arbitrary keyword arguments passed to MLBaseTask.
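As an illustration of the variable_mapping parameter, the following sketch renames input variables from a list of pairs. The helper apply_variable_mapping is hypothetical, and the (old_name, new_name) ordering within each pair is an assumption; the source does not state which element is replaced by which.

```python
def apply_variable_mapping(var_names, variable_mapping):
    """Return var_names with mapped names substituted.

    variable_mapping: list of (old_name, new_name) pairs, or None.
    Hypothetical helper; the pair ordering is an assumption.
    """
    mapping = dict(variable_mapping or [])
    # Names without an entry in the mapping are kept unchanged.
    return [mapping.get(name, name) for name in var_names]


# A subtask pre-trained on 'pretrain_output0' reads 'output0' once connected.
renamed = apply_variable_mapping(
    ['pretrain_output0', 'input2'],
    [('pretrain_output0', 'output0')],
)
```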

compile()

Compile subtasks and this task.

compile_loss()

Compile loss and loss_weights.

Loss functions are retrieved from the subtasks, so each subtask must define a loss.
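One plausible way to combine per-subtask losses with loss_weights is a weighted sum, sketched below. The functions resolve_loss_weights and total_loss are hypothetical, and the meanings given to last_loss (only the final subtask contributes), flat_loss (all subtasks weighted equally), None (treated like flat_loss) and dict keys (task identifiers) are assumptions inferred from the names, not confirmed by this documentation.

```python
def resolve_loss_weights(loss_weights, task_ids):
    """Turn a list, dict, or str specification into a list of floats.

    Hypothetical helper; the string presets and the None default are
    assumed semantics, not the documented multiml behaviour.
    """
    n_tasks = len(task_ids)
    if loss_weights is None or loss_weights == 'flat_loss':
        return [1.0] * n_tasks                 # assumed: equal weights
    if loss_weights == 'last_loss':
        return [0.0] * (n_tasks - 1) + [1.0]   # assumed: last task only
    if isinstance(loss_weights, str):
        raise ValueError(f'unknown loss_weights preset: {loss_weights}')
    if isinstance(loss_weights, dict):
        # Assumed: dict keys are task identifiers.
        return [float(loss_weights[task_id]) for task_id in task_ids]
    weights = [float(weight) for weight in loss_weights]
    if len(weights) != n_tasks:
        raise ValueError('loss_weights must have one entry per subtask')
    return weights


def total_loss(losses, weights):
    """Weighted sum of per-subtask loss values."""
    return sum(weight * loss for weight, loss in zip(weights, losses))


weights = resolve_loss_weights('last_loss', ['subtask0', 'subtask1'])
combined = total_loss([2.0, 3.0], weights)
```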

compile_var_names()

Compile subtask dependencies and I/O variables.

set_output_var_index()

Set output_var_names and output_var_index.

set_input_var_index()

Set input_var_names and input_var_index.