multiml.task.pytorch.pytorch_classification module

PytorchClassificationTask module.

class multiml.task.pytorch.pytorch_classification.PytorchClassificationTask(device='cpu', gpu_ids=None, torchinfo=False, amp=False, torch_compile=False, dataset_args=None, dataloader_args=None, batch_sampler=False, metric_sample=1, **kwargs)

Bases: PytorchBaseTask

PyTorch task for classification.

predict(**kwargs)

Run prediction with the model.

This method runs prediction and returns the results. Data must be provided either through the data option or by setting the dataloaders property directly.

Parameters:
  • data (ndarray) – If data is given, it is converted to a TensorDataset and set as dataloaders['test'].

  • dataloader (obj) – DataLoader instance to run prediction on.

  • phase (str) – ‘all’, ‘train’, ‘valid’, or ‘test’; selects which dataloaders to use.

  • label (bool) – If True, returns metric results based on labels.

Returns:

Prediction results.

Return type:

ndarray or list
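The data path of predict (an ndarray wrapped in a TensorDataset and served as the 'test' dataloader) can be pictured with plain PyTorch. This is a minimal sketch under stated assumptions: make_test_dataloaders and run_inference are hypothetical helpers, not part of multiml, and the batch size, float32 dtype, and argmax step are illustrative choices. The real PytorchClassificationTask.predict may differ in batching, device handling, and output format.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_test_dataloaders(data, batch_size=32):
    """Hypothetical helper: wrap an ndarray as dataloaders['test'].

    Mirrors the documented behaviour of predict(data=...); the dict layout
    and batch_size are assumptions for illustration only.
    """
    tensor = torch.from_numpy(np.asarray(data, dtype=np.float32))
    dataset = TensorDataset(tensor)
    # shuffle=False keeps predictions in the same order as the input rows.
    return {'test': DataLoader(dataset, batch_size=batch_size, shuffle=False)}


def run_inference(model, dataloader):
    """Hypothetical loop: collect model outputs batch by batch."""
    model.eval()
    outputs = []
    with torch.no_grad():
        # Each batch from a single-tensor TensorDataset is a one-element list.
        for (inputs,) in dataloader:
            outputs.append(model(inputs))
    return torch.cat(outputs).numpy()


data = np.random.rand(10, 4)
dataloaders = make_test_dataloaders(data, batch_size=4)
model = torch.nn.Linear(4, 3)  # stand-in classifier with 3 classes
preds = run_inference(model, dataloaders['test'])
labels = preds.argmax(axis=1)  # predicted class index per sample
print(preds.shape)  # (10, 3)
```

Disabling shuffling is the key design choice here: prediction results are only useful if row i of the output corresponds to row i of the input array.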