multiml.task.pytorch.PytorchBaseTask

class multiml.task.pytorch.PytorchBaseTask(device='cpu', gpu_ids=None, torchinfo=False, amp=False, torch_compile=False, dataset_args=None, dataloader_args=None, batch_sampler=False, metric_sample=1, **kwargs)

Base task for PyTorch model.

Examples

>>> import torch.nn as nn
>>> from multiml.task.pytorch import PytorchBaseTask
>>>
>>> # your pytorch model
>>> class MyPytorchModel(nn.Module):
...     def __init__(self, inputs=2, outputs=2):
...         super().__init__()
...
...         self.fc1 = nn.Linear(inputs, outputs)
...         self.relu = nn.ReLU()
...
...     def forward(self, x):
...         return self.relu(self.fc1(x))
...
>>> # create task instance
>>> # (storegate is assumed to be an existing StoreGate holding x0, x1 and labels)
>>> task = PytorchBaseTask(storegate=storegate,
...                        model=MyPytorchModel,
...                        input_var_names=('x0', 'x1'),
...                        output_var_names='outputs-pytorch',
...                        true_var_names='labels',
...                        optimizer='SGD',
...                        optimizer_args=dict(lr=0.1),
...                        loss='CrossEntropyLoss')
>>> task.set_hps({'num_epochs': 5})
>>> task.execute()
>>> task.finalize()

__init__(device='cpu', gpu_ids=None, torchinfo=False, amp=False, torch_compile=False, dataset_args=None, dataloader_args=None, batch_sampler=False, metric_sample=1, **kwargs)

Initialize the pytorch base task.

Parameters:
  • device (str or obj) – PyTorch device, e.g. ‘cpu’, ‘cuda’.

  • gpu_ids (list) – GPU identifiers, e.g. [0, 1, 2]. data_parallel mode is enabled if gpu_ids is given.

  • torchinfo (bool) – show a torchinfo summary after the model is compiled.

  • amp (bool) – (expert option) enable automatic mixed precision (AMP) mode.

  • torch_compile (bool) – (expert option) enable torch.compile.

  • dataset_args (dict) – args passed to the default dataset creation.

  • dataloader_args (dict) – args passed to the default DataLoader creation.

  • batch_sampler (bool) – use a batch_sampler or not.

  • metric_sample (float or int) – sampling ratio for running metrics.
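The device and gpu_ids parameters interact: supplying gpu_ids switches on data_parallel mode. The snippet below is a minimal pure-Python sketch of one way that decision could be resolved. `resolve_device` is a hypothetical helper, not part of multiml. It assumes that gpu_ids, when given, takes precedence over device, and that the first listed GPU becomes the primary device. That matches how `torch.nn.DataParallel` treats `device_ids[0]`.

```python
def resolve_device(device="cpu", gpu_ids=None):
    """Hypothetical helper (not multiml API): return (primary_device, data_parallel).

    Assumption: a non-empty gpu_ids overrides `device`.
    """
    if gpu_ids:
        # torch.nn.DataParallel expects the module's parameters on device_ids[0]
        # and gathers outputs there by default, so that GPU acts as the primary device.
        return f"cuda:{gpu_ids[0]}", True
    # Without gpu_ids, the given device (a string such as 'cpu' or 'cuda',
    # or a torch.device object converted with str()) is used as-is.
    return str(device), False


print(resolve_device())                # ('cpu', False)
print(resolve_device("cuda"))          # ('cuda', False)
print(resolve_device(gpu_ids=[2, 3]))  # ('cuda:2', True)
```

In real PyTorch code, the data-parallel branch would typically wrap the model as `nn.DataParallel(model, device_ids=gpu_ids)` after moving it to the primary device.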

Methods

__init__([device, gpu_ids, torchinfo, amp, ...])

Initialize the pytorch base task.

add_device(data, device)

Add data to device.

build_model()

Build model.

compile()

Compile pytorch ml objects.

compile_device()

Compile device.

compile_loss()

Compile pytorch loss.

compile_model()

Compile pytorch model.

compile_optimizer()

Compile pytorch optimizer and scheduler.

compile_var_names()

Compile var_names.

do_test()

Whether to perform the test phase.

do_train()

Whether to perform the train phase.

do_valid()

Whether to perform the validation phase.

dump_model([extra_args])

Dump current pytorch model.

execute()

Execute a task.

finalize()

Finalize base task.

fit([train_data, valid_data, dataloaders, ...])

Train the model over epochs.

fit_predict([fit_args, predict_args])

Fit the model, then predict with it.

fix_submodule(target)

Fix the parameters of the given target submodule of the model.

get_batch_sampler(phase, dataset)

Returns batch sampler.

get_dataset([data, phase, preload, callbacks])

Returns dataset from given ndarray data.

get_input_true_data(phase)

Get input and true data.

get_input_var_shapes([phase])

Get shape of input_var_names.

get_metadata(metadata_key)

Returns metadata.

get_pred_index()

Returns prediction index passed to loss calculation.

get_storegate_dataset(phase[, preload, ...])

Returns storegate dataset.

get_tensor_dataset(data[, callbacks])

Returns tensor dataset from given ndarray data.

get_unique_id()

Returns unique identifier of task.

load_metadata()

Load metadata.

load_model()

Load pre-trained pytorch model weights.

predict([data, dataloader, phase, label])

Predict with the model.

predict_update([data, phase])

Predict and update data in StoreGate.

prepare_dataloader([data, phase, ...])

Prepare dataloader.

prepare_dataloaders([phases, dataset_args, ...])

Prepare dataloaders for all phases.

set_hps(params)

Set hyperparameters to this task.

show_info()

Print information.

step_batch(data, phase[, label])

Process batch data and update weights.

step_epoch(epoch, phase, dataloader[, label])

Process model for given epoch and phase.

step_loss(outputs, labels)

Process loss function.

step_model(inputs)

Process model.

step_optimizer(loss)

Process optimizer.

update(data[, phase])

Update data in storegate.
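The step_* methods listed above suggest a nested training loop. In that loop, fit drives step_epoch per phase, step_epoch drives step_batch per batch, and step_batch calls step_model, step_loss and, in the train phase, step_optimizer. do_train() and do_valid() decide which phases run.

The snippet below is a minimal pure-Python sketch under that assumed nesting. `ToyTask` is hypothetical and not the multiml implementation. It fits a single weight `w` in `y = w * x` and computes the gradient by hand, where a PyTorch task would call `loss.backward()`.

```python
class ToyTask:
    """Hypothetical stand-in that mirrors the step_* hook names; not multiml code."""

    def __init__(self, lr=0.1, num_epochs=3):
        self.w = 0.0            # single trainable weight: prediction = w * x
        self.lr = lr
        self.num_epochs = num_epochs
        self.trace = []         # records which hooks ran, in order
        self._batch = None      # current (inputs, labels), used for the manual gradient

    def do_train(self):
        return True

    def do_valid(self):
        return True

    def step_model(self, inputs):
        self.trace.append("step_model")
        return [self.w * x for x in inputs]

    def step_loss(self, outputs, labels):
        self.trace.append("step_loss")
        # mean squared error over the batch
        return sum((o - t) ** 2 for o, t in zip(outputs, labels)) / len(labels)

    def step_optimizer(self, loss):
        self.trace.append("step_optimizer")
        inputs, labels = self._batch
        # d/dw mean((w*x - t)^2) = mean(2 * (w*x - t) * x); plain gradient descent
        grad = sum(2 * (self.w * x - t) * x for x, t in zip(inputs, labels)) / len(labels)
        self.w -= self.lr * grad

    def step_batch(self, data, phase):
        inputs, labels = data
        self._batch = data
        outputs = self.step_model(inputs)
        loss = self.step_loss(outputs, labels)
        if phase == "train":    # weights are updated only in the train phase
            self.step_optimizer(loss)
        return loss

    def step_epoch(self, epoch, phase, dataloader):
        losses = [self.step_batch(batch, phase) for batch in dataloader]
        return sum(losses) / len(losses)

    def fit(self, train_data, valid_data):
        history = []            # mean validation loss per epoch
        for epoch in range(1, self.num_epochs + 1):
            if self.do_train():
                self.step_epoch(epoch, "train", train_data)
            if self.do_valid():
                history.append(self.step_epoch(epoch, "valid", valid_data))
        return history


data = [([1.0, 2.0], [2.0, 4.0])]   # one batch: (inputs, labels), true w = 2
task = ToyTask(lr=0.1, num_epochs=3)
print(task.fit(data, data))         # [2.5, 0.625, 0.15625]
```

Splitting the loop into small hooks like this lets a subclass override a single step, such as a custom loss in step_loss, without rewriting the whole training loop.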

Attributes

input_saver_key

Return input_saver_key.

input_var_names

Returns input_var_names.

job_id

Return job_id of task.

ml

Returns ML data class.

name

Return name of task.

output_saver_key

Return output_saver_key.

output_var_names

Returns output_var_names.

phases

Returns ML phases.

pool_id

Return pool_id of task.

pred_var_names

Returns pred_var_names.

save_var_names

Returns save_var_names.

saver

Return saver of task.

storegate

Return storegate of task.

subtask_id

Return subtask_id of task.

task_id

Return task_id of task.

trial_id

Return trial_id of task.

true_var_names

Returns true_var_names.
