multiml.task.pytorch.pytorch_asngnas_task module

class multiml.task.pytorch.pytorch_asngnas_task.PytorchASNGNASTask(asng_args, **kwargs)

Bases: ModelConnectionTask, PytorchBaseTask

__init__(asng_args, **kwargs)
Parameters:
  • subtasks (list) – list of task instances.

  • **kwargs – Arbitrary keyword arguments.

build_model()

Build model.

set_most_likely()
best_model()
get_most_likely()
get_thetas()
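get_thetas(), get_most_likely() and set_most_likely() carry no docstrings here. In ASNG-based architecture search, the architecture is parameterized by categorical distributions (the thetas), and the "most likely" architecture picks the highest-probability option of each categorical variable. The sketch below shows only that selection step. It assumes the thetas are a list of probability vectors, one per searchable choice. This layout and the helpers most_likely and one_hot are hypothetical and are not the actual return format of get_thetas().

```python
from typing import List

# Hypothetical layout: one categorical probability vector per searchable choice
# (e.g. one vector per group of candidate submodels). Not multiml's real format.
Thetas = List[List[float]]


def most_likely(thetas: Thetas) -> List[int]:
    """Return, for each categorical variable, the index with the highest probability."""
    return [max(range(len(theta)), key=theta.__getitem__) for theta in thetas]


def one_hot(indices: List[int], sizes: List[int]) -> List[List[int]]:
    """Encode the chosen indices as one-hot rows (the form sampled architectures take)."""
    return [[1 if j == idx else 0 for j in range(size)] for idx, size in zip(indices, sizes)]


thetas = [
    [0.1, 0.7, 0.2],  # choice 0: three candidate submodels
    [0.6, 0.4],       # choice 1: two candidate submodels
]
best = most_likely(thetas)
print(best)                                     # [1, 0]
print(one_hot(best, [len(t) for t in thetas]))  # [[0, 1, 0], [1, 0]]
```

On ties, max() keeps the first index, so the selection is deterministic.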
fit(train_data=None, valid_data=None, dataloaders=None, valid_step=1, sampler=None, rank=None, **kwargs)

Train the model over epochs.

This method trains and validates the model over epochs by calling the train_model() method. Training and validation data must be provided either through the train_data and valid_data options or through the dataloaders option.

Parameters:
  • train_data (ndarray) – If train_data is given, the data are converted to a TensorDataset and set as dataloaders['train'].

  • valid_data (ndarray) – If valid_data is given, the data are converted to a TensorDataset and set as dataloaders['valid'].

  • dataloaders (dict) – dict of dataloaders, dict(train=xxx, valid=yyy).

  • valid_step (int) – step interval at which validation is processed.

  • sampler (obj) – sampler on which set_epoch() is called.

  • kwargs (dict) – arbitrary args passed to train_model().

Returns:

History of training and validation results.

Return type:

list
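The epoch loop that fit() describes can be pictured with the pure-Python sketch below. It is not multiml's implementation. fit_sketch and dummy_step_epoch are hypothetical names, and the sketch assumes three things: epochs are counted from 1, validation runs when the epoch number is a multiple of valid_step, and the returned history is a list with one record per epoch.

```python
from typing import Callable, Dict, List, Optional


def fit_sketch(
    step_epoch: Callable[[int, str, object], Dict[str, float]],
    dataloaders: Dict[str, object],
    num_epochs: int,
    valid_step: int = 1,
    sampler: Optional[object] = None,
) -> List[Dict[str, object]]:
    """Hypothetical epoch loop mirroring the behaviour described for fit()."""
    history = []
    for epoch in range(1, num_epochs + 1):
        # Samplers such as DistributedSampler reshuffle per epoch via set_epoch().
        if sampler is not None:
            sampler.set_epoch(epoch)
        record = {"epoch": epoch, "train": step_epoch(epoch, "train", dataloaders["train"])}
        # Assumption: validation runs only every `valid_step` epochs.
        if "valid" in dataloaders and epoch % valid_step == 0:
            record["valid"] = step_epoch(epoch, "valid", dataloaders["valid"])
        history.append(record)
    return history


def dummy_step_epoch(epoch: int, phase: str, dataloader: object) -> Dict[str, float]:
    """Stand-in for a real step_epoch(): reports a loss that shrinks with the epoch."""
    return {"loss": 1.0 / epoch}


history = fit_sketch(dummy_step_epoch, {"train": [], "valid": []}, num_epochs=4, valid_step=2)
print([sorted(r) for r in history])
# [['epoch', 'train'], ['epoch', 'train', 'valid'], ['epoch', 'train'], ['epoch', 'train', 'valid']]
```

Passing step_epoch as a callable keeps the loop independent of any model or framework, so the sketch only illustrates the control flow around validation and the sampler.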

step_epoch(epoch, phase, dataloader, label)

Process the model for the given epoch and phase.

ml.model, ml.optimizer, and ml.loss must be set before calling this method; see the compile() method.

Parameters:
  • epoch (int) – epoch number.

  • phase (str) – train mode or valid mode.

  • dataloader (obj) – dataloader instance.

  • label (bool) – If True, returns metric results based on labels.

Returns:

dict of results.

Return type:

dict
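The train/valid distinction in step_epoch() can be illustrated with a toy, framework-free sketch: in both phases the loss is measured and averaged over batches, but parameters change only in the training phase. ScalarModel and step_epoch_sketch are hypothetical. A one-weight model with squared error and plain SGD stands in for ml.model, ml.loss and ml.optimizer. The epoch and label arguments are omitted for brevity.

```python
from typing import Dict, Iterable, Tuple


class ScalarModel:
    """Hypothetical one-parameter model y = w * x, standing in for ml.model."""

    def __init__(self, w: float = 0.0, lr: float = 0.1):
        self.w = w
        self.lr = lr


def step_epoch_sketch(model: ScalarModel, phase: str,
                      dataloader: Iterable[Tuple[float, float]]) -> Dict[str, float]:
    """One pass over `dataloader`; parameters are updated only when phase == 'train'."""
    total_loss, n_batches = 0.0, 0
    for x, y in dataloader:
        pred = model.w * x
        loss = (pred - y) ** 2  # squared error stands in for ml.loss
        if phase == "train":
            grad = 2.0 * (pred - y) * x  # d(loss)/dw
            model.w -= model.lr * grad   # plain SGD stands in for ml.optimizer
        total_loss += loss
        n_batches += 1
    return {"loss": total_loss / max(n_batches, 1)}


model = ScalarModel()
print(step_epoch_sketch(model, "valid", [(1.0, 2.0)]), model.w)  # {'loss': 4.0} 0.0
print(step_epoch_sketch(model, "train", [(1.0, 2.0)]), model.w)  # {'loss': 4.0} 0.4
```

The reported training loss is computed before each update, which is why both calls return the same loss while only the second changes w.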

finalize()

Finalize base task.

Users implement their own algorithms by overriding this method.

get_submodel_names()
get_submodel(i_models)
asng()
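asng() is undocumented here. In general, ASNG (adaptive stochastic natural gradient) samples architectures from categorical distributions and moves the distribution parameters toward the better-performing samples. The sketch below shows a simplified version of that idea for two samples per step (lambda = 2). The better sample gets utility +1 and the worse gets -1. For a categorical distribution, the natural-gradient direction for a sample x is onehot(x) - theta. The sketch deliberately uses a fixed step size eta, so it is a plain stochastic natural-gradient step without ASNG's adaptive step-size mechanism. It also uses a simple clip-and-renormalize margin. sample, sng_update, eta and theta_min are hypothetical names, and none of this claims to be what asng() does internally.

```python
import random
from typing import List, Sequence

Theta = List[List[float]]  # one probability vector per categorical variable


def sample(theta: Theta, rng: random.Random) -> List[int]:
    """Draw one architecture: an index per categorical variable."""
    return [rng.choices(range(len(p)), weights=p)[0] for p in theta]


def sng_update(theta: Theta, samples: Sequence[List[int]], losses: Sequence[float],
               eta: float = 0.1, theta_min: float = 0.01) -> Theta:
    """One simplified natural-gradient step for exactly two samples (lambda = 2)."""
    # Ranking-based utilities: +1 for the lower loss, -1 for the higher, 0 on ties.
    if losses[0] == losses[1]:
        utils = (0.0, 0.0)
    elif losses[0] < losses[1]:
        utils = (1.0, -1.0)
    else:
        utils = (-1.0, 1.0)
    new_theta = []
    for k, p in enumerate(theta):
        row = []
        for j, pj in enumerate(p):
            # Natural gradient of log p(x) for a categorical variable: onehot(x) - theta.
            grad = sum(u * ((1.0 if s[k] == j else 0.0) - pj) for u, s in zip(utils, samples))
            row.append(pj + eta * grad / len(samples))
        # Clip from below at theta_min, then renormalise so the row sums to 1.
        row = [max(v, theta_min) for v in row]
        total = sum(row)
        new_theta.append([v / total for v in row])
    return new_theta


# Toy objective: choosing option 0 of the single categorical variable is always better.
rng = random.Random(0)
theta = [[0.5, 0.5]]
for _ in range(50):
    xs = (sample(theta, rng), sample(theta, rng))
    losses = tuple(0.0 if x[0] == 0 else 1.0 for x in xs)
    theta = sng_update(theta, xs, losses)
print(theta)
```

Because the two utilities sum to zero, each unclipped update keeps every row summing to 1. Only samples that disagree change theta, and they move probability mass toward the option with the lower loss.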
class multiml.task.pytorch.pytorch_asngnas_task.training_results(_metrics, is_multi_loss, len_true_var_names)

Bases: object

__init__(_metrics, is_multi_loss, len_true_var_names)
get_results()
get_running_loss()
get_subloss()
update_results(batch_result)
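training_results exposes update_results(), get_running_loss(), get_subloss() and get_results(), but their behaviour is not documented here. One plausible reading is a per-epoch accumulator that sums batch losses and, when is_multi_loss is set, one sub-loss per target variable. The sketch below shows that pattern. TrainingResultsSketch, the "loss"/"subloss" keys and the averaging over batches are all assumptions, not the real class's behaviour. Metric computation from _metrics is omitted.

```python
from typing import Dict, List


class TrainingResultsSketch:
    """Hypothetical accumulator mirroring the public surface of training_results."""

    def __init__(self, metrics: List[str], is_multi_loss: bool, len_true_var_names: int):
        self.metrics = metrics  # kept for signature parity; metric computation omitted
        self.is_multi_loss = is_multi_loss
        self.running_loss = 0.0
        # One running sub-loss per target variable when several losses are combined.
        self.subloss = [0.0] * len_true_var_names if is_multi_loss else []
        self.n_batches = 0

    def update_results(self, batch_result: Dict[str, object]) -> None:
        """Add one batch's total loss (and per-target sub-losses) to the running sums."""
        self.running_loss += batch_result["loss"]
        if self.is_multi_loss:
            for i, sub in enumerate(batch_result["subloss"]):
                self.subloss[i] += sub
        self.n_batches += 1

    def get_running_loss(self) -> float:
        """Mean loss over the batches seen so far."""
        return self.running_loss / max(self.n_batches, 1)

    def get_subloss(self) -> List[float]:
        """Mean sub-loss per target variable over the batches seen so far."""
        return [s / max(self.n_batches, 1) for s in self.subloss]

    def get_results(self) -> Dict[str, object]:
        result: Dict[str, object] = {"loss": self.get_running_loss()}
        if self.is_multi_loss:
            result["subloss"] = self.get_subloss()
        return result


acc = TrainingResultsSketch(metrics=["loss"], is_multi_loss=True, len_true_var_names=2)
acc.update_results({"loss": 3.0, "subloss": [1.0, 2.0]})
acc.update_results({"loss": 1.0, "subloss": [0.5, 0.5]})
print(acc.get_results())  # {'loss': 2.0, 'subloss': [0.75, 1.25]}
```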