github.com/kubeflow/training-operator@v1.7.0/examples/pytorch/smoke-dist/dist_sendrecv.py

import logging
import os

import torch
import torch.distributed as dist


def run():
    """Simple send/recv for testing master <--> workers communication."""
    rank = dist.get_rank()
    size = dist.get_world_size()
    inp = torch.randn(2, 2)
    result = torch.zeros(2, 2)
    if rank == 0:
        # Master: send the input tensor to each worker in turn, then
        # receive that worker's result tensor before moving on.
        for i in range(1, size):
            dist.send(tensor=inp, dst=i)
            dist.recv(tensor=result, src=i)
            logging.info("Result from worker %d : %s", i, result)
    else:
        # Worker: receive the input tensor from the master ...
        dist.recv(tensor=inp, src=0)
        # ... square it elementwise ...
        result = torch.mul(inp, inp)
        # ... and send the result tensor back to the master.
        dist.send(tensor=result, dst=0)


def init_processes(fn, backend="gloo"):
    """Initialize the distributed environment, then run fn."""
    # With no init_method argument, init_process_group defaults to the
    # env:// rendezvous, reading MASTER_ADDR, MASTER_PORT, WORLD_SIZE,
    # and RANK from the environment.
    dist.init_process_group(backend)
    fn()


def main():
    logging.info("Torch version: %s", torch.__version__)

    # Log the rendezvous variables injected into each replica.
    for var in ("MASTER_PORT", "MASTER_ADDR", "WORLD_SIZE", "RANK"):
        logging.info("%s: %s", var, os.environ.get(var, ""))

    init_processes(run)


if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    main()
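
# --- Local usage sketch (not part of the upstream file) ---
# A minimal way to exercise the script above on one machine, assuming it is
# saved as dist_sendrecv.py on the import path. Because init_process_group
# uses the env:// rendezvous, each process only needs MASTER_ADDR,
# MASTER_PORT, WORLD_SIZE, and RANK set; torch.multiprocessing.spawn stands
# in here for the Kubeflow PyTorchJob controller that would normally inject
# those variables into each replica.

import logging
import os

import torch.multiprocessing as mp

import dist_sendrecv  # assumed module name for the file above

WORLD_SIZE = 3  # rank 0 is the master; ranks 1 and 2 are workers


def _launch(rank):
    # Hand-set the variables a PyTorchJob would provide per replica.
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29500"  # any free local port
    os.environ["WORLD_SIZE"] = str(WORLD_SIZE)
    os.environ["RANK"] = str(rank)
    logging.getLogger().setLevel(logging.INFO)
    dist_sendrecv.main()


if __name__ == "__main__":
    # spawn calls _launch(i) in a fresh process for each i in [0, WORLD_SIZE).
    mp.spawn(_launch, nprocs=WORLD_SIZE)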