multiml.task.pytorch.pytorch_classification module
PytorchClassificationTask module.
- class multiml.task.pytorch.pytorch_classification.PytorchClassificationTask(device='cpu', gpu_ids=None, torchinfo=False, amp=False, torch_compile=False, dataset_args=None, dataloader_args=None, batch_sampler=False, metric_sample=1, **kwargs)
Bases: PytorchBaseTask
PyTorch task for classification.
- predict(**kwargs)
Run prediction with the model.
This method runs prediction and returns the results. Input data must be provided either through the data option or by setting the dataloaders property directly.
- Parameters:
data (ndarray) – If data is given, it is converted to a TensorDataset and set to dataloaders['test'].
dataloader (obj) – dataloader instance.
phase (str) – which dataloaders to use: ‘all’, ‘train’, ‘valid’ or ‘test’.
label (bool) – If True, returns metric results based on labels.
- Returns:
results of prediction.
- Return type:
ndarray or list
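The data option described above wraps an ndarray in a TensorDataset and stores it as dataloaders['test']. The sketch below shows roughly what that path amounts to, written directly against torch.utils.data rather than multiml. The helper name predict_from_ndarray, the float32 cast, the default batch_size, and the inputs-only dataset (no labels, so no metric computation) are assumptions for illustration, not the task's actual implementation.

```python
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def predict_from_ndarray(model: nn.Module, data: np.ndarray, batch_size: int = 32) -> np.ndarray:
    """Hypothetical helper, not part of multiml: mimics the documented `data` path."""
    # Wrap the raw array in a TensorDataset (float32 cast is an assumption).
    dataset = TensorDataset(torch.as_tensor(data, dtype=torch.float32))
    # Register it under the 'test' phase; this dict layout is assumed, not taken from multiml.
    dataloaders = {'test': DataLoader(dataset, batch_size=batch_size, shuffle=False)}

    model.eval()  # inference mode: disables dropout and batch-norm statistic updates
    outputs = []
    with torch.no_grad():  # no gradients are needed for prediction
        for (inputs,) in dataloaders['test']:  # each batch holds a single tensor
            outputs.append(model(inputs).cpu().numpy())
    # Stitch the per-batch outputs back into one ndarray, preserving input order.
    return np.concatenate(outputs, axis=0)


if __name__ == '__main__':
    model = nn.Sequential(nn.Linear(4, 3), nn.Softmax(dim=1))
    x = np.random.rand(10, 4).astype(np.float32)
    preds = predict_from_ndarray(model, x, batch_size=4)
    print(preds.shape)  # (10, 3)
```

Using shuffle=False keeps the predictions aligned row by row with the input array, which matters when the results are later compared against labels.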