github.com/kubeflow/training-operator@v1.7.0/examples/pytorch/smoke-dist/dist_sendrecv.py (about)

     1  import logging
     2  import os
     3  import json
     4  import torch
     5  import torch.distributed as dist
     6  import torch.nn as nn
     7  import torch.nn.functional as F
     8  import torch.optim as optim
     9  
    10  from math import ceil
    11  from random import Random
    12  from torch.autograd import Variable
    13  from torchvision import datasets, transforms
    14  
    15  def run():
    16      """ Simple Send/Recv for testing Master <--> Workers communication """
    17      rank = dist.get_rank()
    18      size = dist.get_world_size()
    19      inp = torch.randn(2,2)
    20      result = torch.zeros(2,2)
    21      if rank == 0:
    22          # Send the input tensor to all workers
    23          for i in range(1, size):
    24              dist.send(tensor=inp, dst=i)
    25              # Receive the result tensor from all workers
    26              dist.recv(tensor=result, src=i)
    27              logging.info("Result from worker %d : %s", i, result)
    28      else:
    29          # Receive input tensor from master
    30          dist.recv(tensor=inp, src=0)
    31          # Elementwise tensor multiplication
    32          result = torch.mul(inp,inp)
    33          # Send the result tensor back to master
    34          dist.send(tensor=result, dst=0)
    35  
    36  def init_processes(fn, backend='gloo'):
    37      """ Initialize the distributed environment. """
    38      dist.init_process_group(backend)
    39      fn()
    40  
    41  def main():
    42      logging.info("Torch version: %s", torch.__version__)
    43      
    44      port = os.environ.get("MASTER_PORT", "{}")
    45      logging.info("MASTER_PORT: %s", port)
    46      
    47      addr = os.environ.get("MASTER_ADDR", "{}")
    48      logging.info("MASTER_ADDR: %s", addr)
    49  
    50      world_size = os.environ.get("WORLD_SIZE", "{}")
    51      logging.info("WORLD_SIZE: %s", world_size)
    52      
    53      rank = os.environ.get("RANK", "{}")
    54      logging.info("RANK: %s", rank)
    55      
    56      init_processes(run)
    57  
    58  
    59  if __name__ == "__main__":
    60      logging.getLogger().setLevel(logging.INFO)
    61      main()