multiml.task.pytorch.pytorch_asngnas_task module
- class multiml.task.pytorch.pytorch_asngnas_task.PytorchASNGNASTask(asng_args, **kwargs)
Bases: ModelConnectionTask, PytorchBaseTask
- __init__(asng_args, **kwargs)
- Parameters:
subtasks (list) – list of task instances.
**kwargs – Arbitrary keyword arguments.
- build_model()
Build model.
- set_most_likely()
- best_model()
- get_most_likely()
- get_thetas()
- fit(train_data=None, valid_data=None, dataloaders=None, valid_step=1, sampler=None, rank=None, **kwargs)
Train the model over epochs.
This method trains and validates the model over epochs by calling the train_model() method. Training and validation data must be provided either through the train_data and valid_data options or through the dataloaders option.
- Parameters:
train_data (ndarray) – If train_data is given, the data are converted to a TensorDataset and set to dataloaders['train'].
valid_data (ndarray) – If valid_data is given, the data are converted to a TensorDataset and set to dataloaders['valid'].
dataloaders (dict) – dict of dataloaders, dict(train=xxx, valid=yyy).
valid_step (int) – step interval at which validation is processed.
sampler (obj) – sampler on which set_epoch() is executed.
kwargs (dict) – arbitrary args passed to train_model().
- Returns:
history data of train and valid.
- Return type:
list
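The way valid_step and sampler interact with the epoch loop can be sketched in plain Python. This is only an illustration under stated assumptions, not the library's implementation: run_epochs and RecordingSampler are hypothetical names, epochs are assumed to be numbered from 1, and validation is assumed to run on every epoch divisible by valid_step. The only requirement taken from the documentation above is that the sampler exposes set_epoch(), as PyTorch's DistributedSampler does.

```python
def run_epochs(num_epochs, valid_step=1, sampler=None):
    """Hypothetical stand-in for the epoch loop inside fit()."""
    history = []
    for epoch in range(1, num_epochs + 1):
        if sampler is not None:
            # Inform the sampler of the current epoch (assumed to happen once per epoch).
            sampler.set_epoch(epoch)
        record = {"epoch": epoch, "phases": ["train"]}
        # Assumption: validation runs only every valid_step epochs.
        if epoch % valid_step == 0:
            record["phases"].append("valid")
        history.append(record)
    return history


class RecordingSampler:
    """Hypothetical sampler that records the epochs it was told about."""

    def __init__(self):
        self.epochs = []

    def set_epoch(self, epoch):
        self.epochs.append(epoch)


sampler = RecordingSampler()
history = run_epochs(4, valid_step=2, sampler=sampler)
for record in history:
    print(record["epoch"], record["phases"])
```

With valid_step=2 the sketch validates on epochs 2 and 4 only, while the sampler is notified on every epoch.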
- step_epoch(epoch, phase, dataloader, label)
Process model for given epoch and phase.
ml.model, ml.optimizer and ml.loss need to be set before calling this method; see the compile() method.
- Parameters:
epoch (int) – epoch number.
phase (str) – train mode or valid mode.
dataloader (obj) – dataloader instance.
label (bool) – If True, returns metric results based on labels.
- Returns:
dict of result.
- Return type:
dict
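The precondition on ml.model, ml.optimizer and ml.loss can be checked before calling step_epoch(). The helper below is a minimal sketch: missing_components is a hypothetical name that is not part of multiml, and a SimpleNamespace stands in for the task's ml attribute, whose real type is not shown here.

```python
from types import SimpleNamespace

# Attributes the documentation says must be set before step_epoch() is called.
REQUIRED = ("model", "optimizer", "loss")


def missing_components(ml):
    """Return the names of required attributes that are absent or None."""
    return [name for name in REQUIRED if getattr(ml, name, None) is None]


# Stand-in object: the optimizer has not been set yet.
ml = SimpleNamespace(model=object(), optimizer=None, loss=object())
missing = missing_components(ml)
if missing:
    print("call compile() first; unset:", missing)
```

A check like this surfaces a forgotten compile() call as a clear message instead of an attribute error deep inside the training loop.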
- finalize()
Finalize base task.
Users implement their algorithms.
- get_submodel_names()
- get_submodel(i_models)
- asng()