multiml.task.pytorch.pytorch_ddp module
PytorchDDPTask module.
- class multiml.task.pytorch.pytorch_ddp.PytorchDDPTask(ddp=True, addr='localhost', port='12355', backend='nccl', find_unused_parameters=False, **kwargs)
Bases: PytorchBaseTask
Distributed data parallel (DDP) task for PyTorch models.
- __init__(ddp=True, addr='localhost', port='12355', backend='nccl', find_unused_parameters=False, **kwargs)
Initialize the PyTorch DDP task.
- compile_model(rank=None, world_size=None)
Build model.
- compile_device()
Compile device.
- dump_model(extra_args=None)
Dump the current PyTorch model.
- prepare_dataloader(rank, world_size, data=None, phase=None, dataset_args=None, dataloader_args=None)
Prepare dataloader.
- get_distributed_sampler(phase, dataset, rank, world_size, batch=False)
Get batch sampler.
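How a distributed sampler splits a dataset across ranks can be sketched in plain Python. The helper below is hypothetical and is not part of multiml. It mirrors the unshuffled, non-dropping behaviour of `torch.utils.data.DistributedSampler`. Whether `get_distributed_sampler` wraps that class is an assumption.

```python
import math


def shard_indices(num_samples, rank, world_size):
    """Return the dataset indices that one rank would see (hypothetical helper)."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must satisfy 0 <= rank < world_size")
    # Every rank must receive the same number of samples, so round up.
    per_rank = math.ceil(num_samples / world_size)
    total = per_rank * world_size
    # Pad by wrapping around to the start of the dataset.
    indices = [i % num_samples for i in range(total)]
    # Rank r takes every world_size-th index, starting at r.
    return indices[rank:total:world_size]


# 10 samples split over 4 processes: 3 indices each, 2 of them padded.
shards = [shard_indices(10, rank, 4) for rank in range(4)]
```

Each rank receives a disjoint, strided slice, and the last two ranks reuse indices 0 and 1 so that all shards have equal length.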
- fix_submodule(target)
Fix the given parameters of the model.
- execute()
Execute the PyTorch DDP task.
Multiple processes are launched.
- abstract execute_mp(rank=None, world_size=None)
User defined algorithms.
Examples
>>> setup(rank, world_size)
>>> # your algorithms
>>> # ...
>>> cleanup()
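The setup/cleanup pattern from the example can be sketched as a subclass that overrides `execute_mp`. So that the sketch runs without multiml or a GPU, `StandInDDPTask` is a hypothetical stand-in that only records calls. In real code the base class would be `PytorchDDPTask`. Wrapping the algorithm in `try`/`finally` is a design choice made here, not something the source prescribes.

```python
from abc import ABC, abstractmethod


class StandInDDPTask(ABC):
    """Hypothetical stand-in for PytorchDDPTask; records calls only."""

    def __init__(self):
        self.calls = []

    def setup(self, rank, world_size):
        self.calls.append(("setup", rank, world_size))

    def cleanup(self):
        self.calls.append(("cleanup",))

    @abstractmethod
    def execute_mp(self, rank=None, world_size=None):
        """User defined algorithms."""


class MyTask(StandInDDPTask):
    def execute_mp(self, rank=None, world_size=None):
        self.setup(rank, world_size)
        try:
            # your algorithms (placeholder step for illustration)
            self.calls.append(("train", rank))
        finally:
            # Run cleanup even if the algorithm raises.
            self.cleanup()


task = MyTask()
task.execute_mp(rank=0, world_size=2)
```

With this structure, `cleanup()` still runs on a rank whose training step fails.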
- setup(rank, world_size)
Setup multi processing.
- cleanup()
Cleanup multi processing.
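The `addr` and `port` constructor arguments suggest that `setup` prepares a rendezvous point for the processes. As a minimal sketch, assuming the environment-variable initialization that `torch.distributed` supports, the hypothetical helper below builds the variables a process would need. That multiml's `setup` does this, and that `cleanup` calls `torch.distributed.destroy_process_group()`, are assumptions.

```python
import os


def ddp_environment(addr="localhost", port="12355", rank=0, world_size=1):
    """Build the variables read by torch.distributed's env:// init (hypothetical helper)."""
    return {
        "MASTER_ADDR": addr,        # host of the rank-0 process
        "MASTER_PORT": str(port),   # free port on that host
        "RANK": str(rank),          # index of this process
        "WORLD_SIZE": str(world_size),  # total number of processes
    }


env = ddp_environment(rank=1, world_size=4)
os.environ.update(env)
# A real setup would then typically call
# torch.distributed.init_process_group(backend, rank=rank, world_size=world_size).
```

The defaults mirror the constructor signature above (`addr='localhost'`, `port='12355'`).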