multiml.task.pytorch.PytorchBaseTask
- class multiml.task.pytorch.PytorchBaseTask(device='cpu', gpu_ids=None, torchinfo=False, amp=False, torch_compile=False, dataset_args=None, dataloader_args=None, batch_sampler=False, metric_sample=1, **kwargs)
Base task for PyTorch model.
Examples
>>> import torch.nn as nn
>>> from multiml.task.pytorch import PytorchBaseTask
>>>
>>> # your pytorch model
>>> class MyPytorchModel(nn.Module):
>>>     def __init__(self, inputs=2, outputs=2):
>>>         super(MyPytorchModel, self).__init__()
>>>
>>>         self.fc1 = nn.Linear(inputs, outputs)
>>>         self.relu = nn.ReLU()
>>>
>>>     def forward(self, x):
>>>         return self.relu(self.fc1(x))
>>>
>>> # create task instance (storegate is a StoreGate instance prepared beforehand)
>>> task = PytorchBaseTask(storegate=storegate,
>>>                        model=MyPytorchModel,
>>>                        input_var_names=('x0', 'x1'),
>>>                        output_var_names='outputs-pytorch',
>>>                        true_var_names='labels',
>>>                        optimizer='SGD',
>>>                        optimizer_args=dict(lr=0.1),
>>>                        loss='CrossEntropyLoss')
>>> task.set_hps({'num_epochs': 5})
>>> task.execute()
>>> task.finalize()
- __init__(device='cpu', gpu_ids=None, torchinfo=False, amp=False, torch_compile=False, dataset_args=None, dataloader_args=None, batch_sampler=False, metric_sample=1, **kwargs)
Initialize the pytorch base task.
- Parameters:
device (str or obj) – pytorch device, e.g. 'cpu', 'cuda'.
gpu_ids (list) – GPU identifiers, e.g. [0, 1, 2]. data_parallel mode is enabled if gpu_ids is given.
torchinfo (bool) – show torchinfo summary after model compile.
amp (bool) – (expert option) enable automatic mixed precision (AMP) mode.
torch_compile (bool) – (expert option) enable torch.compile.
dataset_args (dict) – args passed to default DataSet creation.
dataloader_args (dict) – args passed to default DataLoader creation.
batch_sampler (bool) – use batch_sampler or not.
metric_sample (float or int) – sampling ratio for running metrics.
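The constructor options above can be collected in a plain dict and unpacked into the task. The following is a minimal sketch, not a recommended configuration: the values are illustrative, and the `num_workers` and `pin_memory` keys assume that `dataloader_args` is forwarded to `torch.utils.data.DataLoader`, as the parameter description suggests.

```python
# Illustrative keyword arguments for PytorchBaseTask.__init__.
# All values are assumptions chosen for the example, not library defaults.
task_args = dict(
    device='cuda',      # one of the documented example values
    gpu_ids=[0, 1],     # per the docs, giving gpu_ids enables data_parallel mode
    torchinfo=True,     # show the torchinfo summary after model compile
    # Assumed to be forwarded to torch.utils.data.DataLoader:
    dataloader_args=dict(num_workers=2, pin_memory=True),
)

# With multiml installed and a StoreGate instance at hand, the dict would be
# unpacked alongside the usual task arguments, for example:
# task = PytorchBaseTask(storegate=storegate, model=MyPytorchModel, **task_args)
```

Keeping the options in a dict makes it easy to switch between a CPU and a multi-GPU setup without touching the task construction itself.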
Methods
__init__([device, gpu_ids, torchinfo, amp, ...]) – Initialize the pytorch base task.
add_device(data, device) – Add data to device.
build_model() – Build model.
compile() – Compile pytorch ml objects.
Compile device.
Compile pytorch loss.
Compile pytorch model.
Compile pytorch optimizer and scheduler.
compile_var_names() – Compile var_names.
do_test() – Perform test phase or not.
do_train() – Perform train phase or not.
do_valid() – Perform valid phase or not.
dump_model([extra_args]) – Dump current pytorch model.
execute() – Execute a task.
finalize() – Finalize base task.
fit([train_data, valid_data, dataloaders, ...]) – Train model over epoch.
fit_predict([fit_args, predict_args]) – Fit and predict model.
fix_submodule(target) – Fix given parameters of model.
get_batch_sampler(phase, dataset) – Returns batch sampler.
get_dataset([data, phase, preload, callbacks]) – Returns dataset from given ndarray data.
get_input_true_data(phase) – Get input and true data.
get_input_var_shapes([phase]) – Get shape of input_var_names.
get_metadata(metadata_key) – Returns metadata.
get_pred_index() – Returns prediction index passed to loss calculation.
get_storegate_dataset(phase[, preload, ...]) – Returns storegate dataset.
get_tensor_dataset(data[, callbacks]) – Returns tensor dataset from given ndarray data.
get_unique_id() – Returns unique identifier of task.
load_metadata() – Load metadata.
Load pre-trained pytorch model weights.
predict([data, dataloader, phase, label]) – Predict model.
predict_update([data, phase]) – Predict and update data in StoreGate.
prepare_dataloader([data, phase, ...]) – Prepare dataloader.
prepare_dataloaders([phases, dataset_args, ...]) – Prepare dataloaders for all phases.
set_hps(params) – Set hyperparameters to this task.
show_info() – Print information.
step_batch(data, phase[, label]) – Process batch data and update weights.
step_epoch(epoch, phase, dataloader[, label]) – Process model for given epoch and phase.
step_loss(outputs, labels) – Process loss function.
step_model(inputs) – Process model.
step_optimizer(loss) – Process optimizer.
update(data[, phase]) – Update data in storegate.
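The step_* hooks listed above suggest a per-batch pipeline: model, then loss, then optimizer. The sketch below shows one plausible way they could compose. `ToyTask` is a hypothetical pure-Python stand-in (one weight, mean squared error), not multiml code; its simplified signatures and the assumption that weights are updated only in the 'train' phase are illustrative, and the actual implementation may differ.

```python
class ToyTask:
    """Hypothetical stand-in for the step_* hooks; not part of multiml."""

    def __init__(self, lr=0.1):
        self.weight = 0.0
        self.lr = lr
        self.calls = []  # records which hooks ran, in order

    def step_model(self, inputs):
        self.calls.append('step_model')
        self._inputs = inputs  # kept so step_loss can form the gradient
        return [self.weight * x for x in inputs]

    def step_loss(self, outputs, labels):
        self.calls.append('step_loss')
        n = len(labels)
        loss = sum((o - y) ** 2 for o, y in zip(outputs, labels)) / n
        # Gradient of the mean squared error with respect to the weight.
        grad = sum(2 * (o - y) * x
                   for o, y, x in zip(outputs, labels, self._inputs)) / n
        return {'loss': loss, 'grad': grad}

    def step_optimizer(self, loss):
        self.calls.append('step_optimizer')
        self.weight -= self.lr * loss['grad']  # plain gradient-descent step

    def step_batch(self, data, phase):
        inputs, labels = data
        outputs = self.step_model(inputs)
        loss = self.step_loss(outputs, labels)
        if phase == 'train':  # assumption: no weight update outside training
            self.step_optimizer(loss)
        return loss['loss']


task = ToyTask(lr=0.1)
batch = ([1.0, 2.0], [2.0, 4.0])
train_loss = task.step_batch(batch, 'train')
valid_loss = task.step_batch(batch, 'valid')
```

Splitting the batch step into separate hooks means a subclass can override a single stage, for example a custom loss, without rewriting the whole loop.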
Attributes
input_saver_key – Return input_saver_key.
input_var_names – Returns input_var_names.
job_id – Return job_id of task.
ml – Returns ML data class.
name – Return name of task.
output_saver_key – Return output_saver_key.
output_var_names – Returns output_var_names.
phases – Returns ML phases.
pool_id – Return pool_id of task.
pred_var_names – Returns pred_var_names.
save_var_names – Returns save_var_names.
saver – Return saver of task.
storegate – Return storegate of task.
subtask_id – Return subtask_id of task.
task_id – Return task_id of task.
trial_id – Return trial_id of task.
true_var_names – Returns true_var_names.