multiml.task.pytorch.PytorchBaseTask
- class multiml.task.pytorch.PytorchBaseTask(device='cpu', gpu_ids=None, torchinfo=False, amp=False, torch_compile=False, dataset_args=None, dataloader_args=None, batch_sampler=False, metric_sample=1, **kwargs)
Base task for PyTorch models.
Examples
>>> import torch.nn as nn
>>> from multiml.task.pytorch import PytorchBaseTask
>>> # your pytorch model
>>> class MyPytorchModel(nn.Module):
...     def __init__(self, inputs=2, outputs=2):
...         super(MyPytorchModel, self).__init__()
...         self.fc1 = nn.Linear(inputs, outputs)
...         self.relu = nn.ReLU()
...
...     def forward(self, x):
...         return self.relu(self.fc1(x))
>>> # create task instance (storegate: an existing StoreGate instance)
>>> task = PytorchBaseTask(storegate=storegate,
...                        model=MyPytorchModel,
...                        input_var_names=('x0', 'x1'),
...                        output_var_names='outputs-pytorch',
...                        true_var_names='labels',
...                        optimizer='SGD',
...                        optimizer_args=dict(lr=0.1),
...                        loss='CrossEntropyLoss')
>>> task.set_hps({'num_epochs': 5})
>>> task.execute()
>>> task.finalize()
- __init__(device='cpu', gpu_ids=None, torchinfo=False, amp=False, torch_compile=False, dataset_args=None, dataloader_args=None, batch_sampler=False, metric_sample=1, **kwargs)
Initialize the pytorch base task.
- Parameters:
device (str or obj) – pytorch device, e.g. ‘cpu’, ‘cuda’.
gpu_ids (list) – GPU identifiers, e.g. [0, 1, 2]. data_parallel mode is enabled if gpu_ids is given.
torchinfo (bool) – show torchinfo summary after model compile.
amp (bool) – (expert option) enable AMP (automatic mixed precision) mode.
torch_compile (bool) – (expert option) enable torch.compile.
dataset_args (dict) – args passed to default DataSet creation.
dataloader_args (dict) – args passed to default DataLoader creation.
batch_sampler (bool) – use batch_sampler or not.
metric_sample (float or int) – sampling ratio for running metrics.
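The gpu_ids entry states that data_parallel mode is enabled when GPU ids are given. In plain PyTorch that usually means wrapping the model in torch.nn.DataParallel. The sketch below assumes that mapping; place_model is a hypothetical helper, not part of multiml, and using gpu_ids[0] as the primary device follows DataParallel's own convention (the wrapped module must live on device_ids[0], where outputs are gathered by default).

```python
import torch
from torch import nn


def place_model(model, device='cpu', gpu_ids=None):
    """Hypothetical helper: move a model to its device, wrapping it in
    nn.DataParallel when a list of GPU ids is given."""
    if gpu_ids:
        # DataParallel expects the module on device_ids[0] before forward.
        device = torch.device(f'cuda:{gpu_ids[0]}')
        model = nn.DataParallel(model, device_ids=gpu_ids)
    else:
        device = torch.device(device)
    return model.to(device), device


# CPU usage: no gpu_ids, so no DataParallel wrapper is added.
model, device = place_model(nn.Linear(2, 2), device='cpu')
print(type(model).__name__, device)  # Linear cpu
```

With gpu_ids=[0, 1] the same call would return a DataParallel wrapper on cuda:0, which requires at least two visible CUDA devices.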
Methods
__init__([device, gpu_ids, torchinfo, amp, ...]) – Initialize the pytorch base task.
add_device(data, device) – Add data to device.
build_model() – Build model.
compile() – Compile pytorch ml objects.
Compile device.
Compile pytorch loss.
Compile pytorch model.
Compile pytorch optimizer and scheduler.
compile_var_names() – Compile var_names.
do_test() – Perform test phase or not.
do_train() – Perform train phase or not.
do_valid() – Perform valid phase or not.
dump_model([extra_args]) – Dump current pytorch model.
execute() – Execute a task.
finalize() – Finalize base task.
fit([train_data, valid_data, dataloaders, ...]) – Train model over epoch.
fit_predict([fit_args, predict_args]) – Fit and predict model.
fix_submodule(target) – Fix given parameters of model.
get_batch_sampler(phase, dataset) – Returns batch sampler.
get_dataset([data, phase, preload, callbacks]) – Returns dataset from given ndarray data.
get_input_true_data(phase) – Get input and true data.
get_input_var_shapes([phase]) – Get shape of input_var_names.
get_metadata(metadata_key) – Returns metadata.
get_pred_index() – Returns prediction index passed to loss calculation.
get_storegate_dataset(phase[, preload, ...]) – Returns storegate dataset.
get_tensor_dataset(data[, callbacks]) – Returns tensor dataset from given ndarray data.
get_unique_id() – Returns unique identifier of task.
load_metadata() – Load metadata.
Load pre-trained pytorch model weights.
predict([data, dataloader, phase, label]) – Predict model.
predict_update([data, phase]) – Predict and update data in StoreGate.
prepare_dataloader([data, phase, ...]) – Prepare dataloader.
prepare_dataloaders([phases, dataset_args, ...]) – Prepare dataloaders for all phases.
set_hps(params) – Set hyperparameters to this task.
show_info() – Print information.
step_batch(data, phase[, label]) – Process batch data and update weights.
step_epoch(epoch, phase, dataloader[, label]) – Process model for given epoch and phase.
step_loss(outputs, labels) – Process loss function.
step_model(inputs) – Process model.
step_optimizer(loss) – Process optimizer.
update(data[, phase]) – Update data in storegate.
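The step_* entries in the methods table describe a layered loop: step_epoch processes one epoch for a phase, step_batch processes one batch and updates weights, and step_model, step_loss and step_optimizer cover the forward pass, the loss and the weight update. The sketch below reproduces that layering in plain PyTorch using free functions that borrow the method names. It is an illustration under assumptions, not multiml's implementation: the signatures differ from the table (model, loss and optimizer are passed explicitly because there is no task object), the phase strings 'train'/'valid' are assumed, and TensorDataset/DataLoader stand in for get_tensor_dataset/prepare_dataloader.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def step_model(model, inputs):
    # Forward pass only (cf. step_model(inputs)).
    return model(inputs)


def step_loss(loss_fn, outputs, labels):
    # Loss between predictions and true labels (cf. step_loss(outputs, labels)).
    return loss_fn(outputs, labels)


def step_optimizer(optimizer, loss):
    # Backward pass and weight update (cf. step_optimizer(loss)).
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


def step_batch(model, loss_fn, optimizer, data, phase):
    # One batch: forward and loss always, weight update only in 'train'.
    inputs, labels = data
    with torch.set_grad_enabled(phase == 'train'):
        outputs = step_model(model, inputs)
        loss = step_loss(loss_fn, outputs, labels)
        if phase == 'train':
            step_optimizer(optimizer, loss)
    return loss.item()


def step_epoch(model, loss_fn, optimizer, dataloader, phase):
    # One pass over the dataloader; returns the mean batch loss.
    model.train(phase == 'train')
    losses = [step_batch(model, loss_fn, optimizer, data, phase)
              for data in dataloader]
    return sum(losses) / len(losses)


# Toy two-class data standing in for the StoreGate inputs and labels.
torch.manual_seed(0)
x = torch.randn(256, 2)
y = (x[:, 0] + x[:, 1] > 0).long()
loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

model = nn.Linear(2, 2)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 'train' epochs update weights; a 'valid' epoch only evaluates.
history = [step_epoch(model, loss_fn, optimizer, loader, 'train')
           for _ in range(5)]
before = [p.detach().clone() for p in model.parameters()]
valid_loss = step_epoch(model, loss_fn, optimizer, loader, 'valid')
print(history, valid_loss)
```

Splitting the loop this way lets each stage be replaced independently, for instance changing the loss in step_loss without touching step_batch or step_epoch.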
Attributes
input_saver_key – Return input_saver_key.
input_var_names – Returns input_var_names.
job_id – Return job_id of task.
ml – Returns ML data class.
name – Return name of task.
output_saver_key – Return output_saver_key.
output_var_names – Returns output_var_names.
phases – Returns ML phases.
pool_id – Return pool_id of task.
pred_var_names – Returns pred_var_names.
save_var_names – Returns save_var_names.
saver – Return saver of task.
storegate – Return storegate of task.
subtask_id – Return subtask_id of task.
task_id – Return task_id of task.
trial_id – Return trial_id of task.
true_var_names – Returns true_var_names.