multiml.task.pytorch.pytorch_ddp module

PytorchDDPTask module.

class multiml.task.pytorch.pytorch_ddp.PytorchDDPTask(ddp=True, addr='localhost', port='12355', backend='nccl', find_unused_parameters=False, **kwargs)

Bases: PytorchBaseTask

Distributed data parallel (DDP) task for PyTorch models.

__init__(ddp=True, addr='localhost', port='12355', backend='nccl', find_unused_parameters=False, **kwargs)

Initialize the PyTorch DDP task.

compile_model(rank=None, world_size=None)

Build the model.

compile_device()

Compile device.

dump_model(extra_args=None)

Dump the current PyTorch model.

prepare_dataloader(rank, world_size, data=None, phase=None, dataset_args=None, dataloader_args=None)

Prepare dataloader.

get_distributed_sampler(phase, dataset, rank, world_size, batch=False)

Get the distributed sampler.
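
As a rough illustration of how a distributed sampler uses rank and world_size, the sketch below shards dataset indices the way torch.utils.data.DistributedSampler does when shuffling is off and drop_last is False: the index list is padded by repeating leading indices until it divides evenly, then each rank takes every world_size-th index starting at its own rank. shard_indices is a hypothetical helper, not part of multiml, and it is an assumption that this method relies on DistributedSampler internally.

```python
import math


def shard_indices(dataset_len, rank, world_size):
    """Hypothetical helper: indices one rank would see, without shuffling."""
    if dataset_len == 0:
        return []
    # Every rank receives the same number of samples.
    num_samples = math.ceil(dataset_len / world_size)
    total_size = num_samples * world_size
    indices = list(range(dataset_len))
    # Pad by repeating leading indices so the list divides evenly.
    padding = total_size - len(indices)
    if padding > 0:
        repeats = math.ceil(padding / len(indices))
        indices += (indices * repeats)[:padding]
    # Each rank takes a strided slice starting at its own rank.
    return indices[rank:total_size:world_size]


shards = [shard_indices(10, rank, 4) for rank in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]
```

With 10 samples and 4 processes, two indices are repeated so that every rank gets 3 samples.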

fix_submodule(target)

Fix the given parameters of the model.

execute()

Execute the PyTorch DDP task.

Multiple processes are launched.

abstract execute_mp(rank=None, world_size=None)

User-defined algorithms.

Examples

>>> setup(rank, world_size)
>>> # your algorithms
>>> # ...
>>> cleanup()
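
The setup/cleanup bracket in the example above can be sketched as a subclass that overrides execute_mp. The code below uses a hypothetical stand-in base class (SketchDDPTask) so that it runs without multiml or PyTorch; calling setup and cleanup as methods on self, and wrapping the algorithm in try/finally, are assumptions about reasonable usage rather than documented behavior.

```python
from abc import ABC, abstractmethod


class SketchDDPTask(ABC):
    """Hypothetical stand-in for PytorchDDPTask; not the multiml class."""

    def __init__(self):
        self.events = []  # records the call order for illustration

    def setup(self, rank, world_size):
        self.events.append(("setup", rank, world_size))

    def cleanup(self):
        self.events.append(("cleanup",))

    @abstractmethod
    def execute_mp(self, rank=None, world_size=None):
        """User-defined algorithm, run once per process."""


class MyTask(SketchDDPTask):
    def execute_mp(self, rank=None, world_size=None):
        self.setup(rank, world_size)
        try:
            # Your algorithm goes here.
            self.events.append(("train", rank))
        finally:
            # Runs even if the algorithm raises.
            self.cleanup()


task = MyTask()
task.execute_mp(rank=0, world_size=2)
print(task.events)
```

In the real class, execute() is described as launching multiple processes, so execute_mp would normally not be called by hand as it is here.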

setup(rank, world_size)

Set up multiprocessing.

cleanup()

Clean up multiprocessing.
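
For orientation, a setup step in PyTorch DDP code commonly exports a rendezvous address and port through the MASTER_ADDR and MASTER_PORT environment variables before initializing the process group. The sketch below shows only that environment step, with defaults mirroring the addr and port arguments of the constructor. set_rendezvous_env is a hypothetical helper, and whether this class does exactly this is an assumption.

```python
import os


def set_rendezvous_env(addr="localhost", port="12355"):
    """Hypothetical helper: export the rendezvous address and port."""
    os.environ["MASTER_ADDR"] = addr
    os.environ["MASTER_PORT"] = str(port)
    # A real setup would typically continue with:
    #   torch.distributed.init_process_group(backend, rank=rank, world_size=world_size)
    # and the matching cleanup would typically call:
    #   torch.distributed.destroy_process_group()
    return os.environ["MASTER_ADDR"], os.environ["MASTER_PORT"]


endpoint = set_rendezvous_env()
print(endpoint)  # ('localhost', '12355')
```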