multiml.agent.SequentialAgent
- class multiml.agent.SequentialAgent(differentiable=None, diff_pretrain=False, diff_task_args=None, num_trials=None, **kwargs)
Agent that executes sequential tasks.
Examples
>>> task0 = your_task0
>>> task1 = your_task1
>>> task2 = your_task2
>>>
>>> agent = SequentialAgent(storegate=storegate,
...                         task_scheduler=[task0, task1, task2],
...                         metric=your_metric)
>>> agent.execute()
>>> agent.finalize()
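Here is a plain-Python sketch of the flow in the example that does not depend on multiml. It runs the tasks in the order given by `task_scheduler`, lets each task read what earlier tasks wrote to a shared store, and then scores the result with `metric`. The sketch assumes a dict stands in for the storegate and each task is a plain callable. `ToySequentialAgent`, `scale`, `shift` and `mean_metric` are hypothetical names, not part of multiml's API.

```python
from typing import Callable, Dict, List

# Hypothetical stand-in for a storegate: a plain dict of named data lists.
Storegate = Dict[str, List[float]]


class ToySequentialAgent:
    """Hypothetical illustration of sequential execution; not multiml's API."""

    def __init__(self, storegate: Storegate,
                 task_scheduler: List[Callable[[Storegate], None]],
                 metric: Callable[[Storegate], float]) -> None:
        self.storegate = storegate
        self.task_scheduler = task_scheduler
        self.metric = metric
        self.result = None

    def execute(self) -> None:
        # Run each task in order; every task reads from and writes to the shared store.
        for task in self.task_scheduler:
            task(self.storegate)
        # Score the output of the last task with the metric.
        self.result = self.metric(self.storegate)

    def finalize(self) -> None:
        print(f"metric: {self.result}")


def scale(sg: Storegate) -> None:
    # First task: consumes "input", produces "scaled".
    sg["scaled"] = [2.0 * x for x in sg["input"]]


def shift(sg: Storegate) -> None:
    # Second task: consumes the first task's output, produces "pred".
    sg["pred"] = [x + 1.0 for x in sg["scaled"]]


def mean_metric(sg: Storegate) -> float:
    return sum(sg["pred"]) / len(sg["pred"])


agent = ToySequentialAgent({"input": [1.0, 2.0, 3.0]}, [scale, shift], mean_metric)
agent.execute()
agent.finalize()  # prints "metric: 5.0"
```

The ordering matters because `shift` can only run once `scale` has written `"scaled"`. This dependency is what makes the tasks sequential rather than independent.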
- __init__(differentiable=None, diff_pretrain=False, diff_task_args=None, num_trials=None, **kwargs)
Initialize sequential agent.
- Parameters:
differentiable (str) – keras or pytorch. If differentiable is given, ConnectionTask() is created from the sequential tasks. If differentiable is None (default), the sequential tasks are executed step by step.
diff_pretrain (bool) – If True, each subtask is trained before ConnectionTask() is created.
diff_task_args (dict) – Arbitrary arguments passed to ConnectionTask().
num_trials (int) – Number of trials. The average metric over all trials is used as the final metric.
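When `num_trials` is set, the run is repeated and the mean of the per-trial metric values becomes the final metric. The following is a minimal sketch of that averaging. It assumes a hypothetical `run_once(trial)` callable that returns one trial's metric. Treating `num_trials=None` as a single run is also an assumption, not documented behaviour.

```python
from statistics import mean
from typing import Callable, Optional


def averaged_metric(run_once: Callable[[int], float],
                    num_trials: Optional[int] = None) -> float:
    """Hypothetical helper: average a metric over repeated trials."""
    # Assumption: num_trials=None means a single run.
    n = 1 if num_trials is None else num_trials
    if n < 1:
        raise ValueError("num_trials must be a positive integer")
    # Each trial index is passed in so a trial can, for example, vary its random seed.
    return mean(run_once(trial) for trial in range(n))


# Usage: three trials returning 0.5, 0.75 and 1.0 average to 0.75.
scores = [0.5, 0.75, 1.0]
final = averaged_metric(lambda trial: scores[trial], num_trials=3)
print(final)  # 0.75
```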
Methods
__init__([differentiable, diff_pretrain, ...]) – Initialize sequential agent.
execute() – Execute sequential agent.
execute_differentiable(subtasktuples, counter) – Execute connection model.
execute_finalize() – Execute and finalize base agent.
execute_pipeline(subtasktuples, counter[, trial]) – Execute pipeline.
execute_subtasktuples(subtasktuples, counter) – Execute given subtasktuples.
finalize() – Finalize sequential agent.
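The `differentiable` parameter presumably determines which of the methods listed above does the work. With `keras` or `pytorch`, the subtasks are combined into a `ConnectionTask()` and run as one connection model (`execute_differentiable`). With `None`, they run step by step (`execute_pipeline`). The sketch below only plans the steps implied by the documented parameters and does not call multiml. `plan_execution` and the step labels are hypothetical.

```python
from typing import List, Optional

SUPPORTED_BACKENDS = ("keras", "pytorch")


def plan_execution(differentiable: Optional[str] = None,
                   diff_pretrain: bool = False) -> List[str]:
    """Hypothetical planner mirroring the documented `differentiable` behaviour."""
    if differentiable is None:
        # Default: execute the sequential subtasks one after another.
        return ["execute_pipeline"]
    if differentiable not in SUPPORTED_BACKENDS:
        raise ValueError(
            f"differentiable must be one of {SUPPORTED_BACKENDS} or None, "
            f"got {differentiable!r}")
    steps = []
    if diff_pretrain:
        # diff_pretrain=True: train each subtask before building the connection.
        steps.append("pretrain_subtasks")
    # Build one ConnectionTask from the subtasks and execute it as a single model.
    steps += ["build_connection_task", "execute_differentiable"]
    return steps


print(plan_execution())                  # ['execute_pipeline']
print(plan_execution("pytorch", True))   # ['pretrain_subtasks', 'build_connection_task', 'execute_differentiable']
```

Validating the backend string up front turns a typo such as `"torch"` into an immediate error. Without this check, the typo might silently fall through to step-by-step execution.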
Attributes
metric – Return metric of base agent.
Return result of execution.
saver – Return saver of base agent.
storegate – Return storegate of base agent.
task_scheduler – Return task_scheduler of base agent.