github.com/kubeflow/training-operator@v1.7.0/examples/xgboost/xgboost-dist/tracker.py (about)

     1  """
     2  Tracker script for DMLC
     3  Implements the tracker control protocol
     4   - start dmlc jobs
     5   - start ps scheduler and rabit tracker
     6   - help nodes to establish links with each other
     7  Tianqi Chen
     8  --------------------------
     9  This was taken from
    10  https://github.com/dmlc/dmlc-core/blob/master/tracker/dmlc_tracker/tracker.py
    11  See LICENSE here
    12  https://github.com/dmlc/dmlc-core/blob/master/LICENSE
    13  No code modified or added except for this explanatory comment.
    14  """
    15  # pylint: disable=invalid-name, missing-docstring, too-many-arguments
    16  # pylint: disable=too-many-locals
    17  # pylint: disable=too-many-branches, too-many-statements
    18  from __future__ import absolute_import
    19  
    20  import os
    21  import sys
    22  import socket
    23  import struct
    24  import subprocess
    25  import argparse
    26  import time
    27  import logging
    28  from threading import Thread
    29  
    30  
    31  class ExSocket(object):
    32      """
    33      Extension of socket to handle recv and send of special data
    34      """
    35      def __init__(self, sock):
    36          self.sock = sock
    37  
    38      def recvall(self, nbytes):
    39          res = []
    40          nread = 0
    41          while nread < nbytes:
    42              chunk = self.sock.recv(min(nbytes - nread, 1024))
    43              nread += len(chunk)
    44              res.append(chunk)
    45          return b''.join(res)
    46  
    47      def recvint(self):
    48          return struct.unpack('@i', self.recvall(4))[0]
    49  
    50      def sendint(self, n):
    51          self.sock.sendall(struct.pack('@i', n))
    52  
    53      def sendstr(self, s):
    54          self.sendint(len(s))
    55          self.sock.sendall(s.encode())
    56  
    57      def recvstr(self):
    58          slen = self.recvint()
    59          return self.recvall(slen).decode()
    60  
    61  
# Magic number used to verify existence of data: SlaveEntry reads this value
# first from every new connection, asserts that it matches, and sends it back.
kMagic = 0xff99
    64  
    65  
    66  def get_some_ip(host):
    67      return socket.getaddrinfo(host, None)[0][4][0]
    68  
    69  
    70  def get_family(addr):
    71      return socket.getaddrinfo(addr, None)[0][0]
    72  
    73  
    74  class SlaveEntry(object):
    75      def __init__(self, sock, s_addr):
    76          slave = ExSocket(sock)
    77          self.sock = slave
    78          self.host = get_some_ip(s_addr[0])
    79          magic = slave.recvint()
    80          assert magic == kMagic, 'invalid magic number=%d from %s' % (
    81              magic, self.host)
    82          slave.sendint(kMagic)
    83          self.rank = slave.recvint()
    84          self.world_size = slave.recvint()
    85          self.jobid = slave.recvstr()
    86          self.cmd = slave.recvstr()
    87          self.wait_accept = 0
    88          self.port = None
    89  
    90      def decide_rank(self, job_map):
    91          if self.rank >= 0:
    92              return self.rank
    93          if self.jobid != 'NULL' and self.jobid in job_map:
    94              return job_map[self.jobid]
    95          return -1
    96  
    97      def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
    98          self.rank = rank
    99          nnset = set(tree_map[rank])
   100          rprev, rnext = ring_map[rank]
   101          self.sock.sendint(rank)
   102          # send parent rank
   103          self.sock.sendint(parent_map[rank])
   104          # send world size
   105          self.sock.sendint(len(tree_map))
   106          self.sock.sendint(len(nnset))
   107          # send the rprev and next link
   108          for r in nnset:
   109              self.sock.sendint(r)
   110          # send prev link
   111          if rprev != -1 and rprev != rank:
   112              nnset.add(rprev)
   113              self.sock.sendint(rprev)
   114          else:
   115              self.sock.sendint(-1)
   116          # send next link
   117          if rnext != -1 and rnext != rank:
   118              nnset.add(rnext)
   119              self.sock.sendint(rnext)
   120          else:
   121              self.sock.sendint(-1)
   122          while True:
   123              ngood = self.sock.recvint()
   124              goodset = set([])
   125              for _ in range(ngood):
   126                  goodset.add(self.sock.recvint())
   127              assert goodset.issubset(nnset)
   128              badset = nnset - goodset
   129              conset = []
   130              for r in badset:
   131                  if r in wait_conn:
   132                      conset.append(r)
   133              self.sock.sendint(len(conset))
   134              self.sock.sendint(len(badset) - len(conset))
   135              for r in conset:
   136                  self.sock.sendstr(wait_conn[r].host)
   137                  self.sock.sendint(wait_conn[r].port)
   138                  self.sock.sendint(r)
   139              nerr = self.sock.recvint()
   140              if nerr != 0:
   141                  continue
   142              self.port = self.sock.recvint()
   143              rmset = []
   144              # all connection was successuly setup
   145              for r in conset:
   146                  wait_conn[r].wait_accept -= 1
   147                  if wait_conn[r].wait_accept == 0:
   148                      rmset.append(r)
   149              for r in rmset:
   150                  wait_conn.pop(r, None)
   151              self.wait_accept = len(badset) - len(conset)
   152              return rmset
   153  
   154  
   155  class RabitTracker(object):
   156      """
   157      tracker for rabit
   158      """
   159      def __init__(self, hostIP, nslave, port=9091, port_end=9999):
   160          sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
   161          for port in range(port, port_end):
   162              try:
   163                  sock.bind((hostIP, port))
   164                  self.port = port
   165                  break
   166              except socket.error as e:
   167                  if e.errno in [98, 48]:
   168                      continue
   169                  else:
   170                      raise
   171          sock.listen(256)
   172          self.sock = sock
   173          self.hostIP = hostIP
   174          self.thread = None
   175          self.start_time = None
   176          self.end_time = None
   177          self.nslave = nslave
   178          logging.info('start listen on %s:%d', hostIP, self.port)
   179  
   180      def __del__(self):
   181          self.sock.close()
   182  
   183      @staticmethod
   184      def get_neighbor(rank, nslave):
   185          rank = rank + 1
   186          ret = []
   187          if rank > 1:
   188              ret.append(rank // 2 - 1)
   189          if rank * 2 - 1 < nslave:
   190              ret.append(rank * 2 - 1)
   191          if rank * 2 < nslave:
   192              ret.append(rank * 2)
   193          return ret
   194  
   195      def slave_envs(self):
   196          """
   197          get enviroment variables for slaves
   198          can be passed in as args or envs
   199          """
   200          return {'DMLC_TRACKER_URI': self.hostIP,
   201                  'DMLC_TRACKER_PORT': self.port}
   202  
   203      def get_tree(self, nslave):
   204          tree_map = {}
   205          parent_map = {}
   206          for r in range(nslave):
   207              tree_map[r] = self.get_neighbor(r, nslave)
   208              parent_map[r] = (r + 1) // 2 - 1
   209          return tree_map, parent_map
   210  
   211      def find_share_ring(self, tree_map, parent_map, r):
   212          """
   213          get a ring structure that tends to share nodes with the tree
   214          return a list starting from r
   215          """
   216          nset = set(tree_map[r])
   217          cset = nset - set([parent_map[r]])
   218          if len(cset) == 0:
   219              return [r]
   220          rlst = [r]
   221          cnt = 0
   222          for v in cset:
   223              vlst = self.find_share_ring(tree_map, parent_map, v)
   224              cnt += 1
   225              if cnt == len(cset):
   226                  vlst.reverse()
   227              rlst += vlst
   228          return rlst
   229  
   230      def get_ring(self, tree_map, parent_map):
   231          """
   232          get a ring connection used to recover local data
   233          """
   234          assert parent_map[0] == -1
   235          rlst = self.find_share_ring(tree_map, parent_map, 0)
   236          assert len(rlst) == len(tree_map)
   237          ring_map = {}
   238          nslave = len(tree_map)
   239          for r in range(nslave):
   240              rprev = (r + nslave - 1) % nslave
   241              rnext = (r + 1) % nslave
   242              ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
   243          return ring_map
   244  
   245      def get_link_map(self, nslave):
   246          """
   247          get the link map, this is a bit hacky, call for better algorithm
   248          to place similar nodes together
   249          """
   250          tree_map, parent_map = self.get_tree(nslave)
   251          ring_map = self.get_ring(tree_map, parent_map)
   252          rmap = {0: 0}
   253          k = 0
   254          for i in range(nslave - 1):
   255              k = ring_map[k][1]
   256              rmap[k] = i + 1
   257  
   258          ring_map_ = {}
   259          tree_map_ = {}
   260          parent_map_ = {}
   261          for k, v in ring_map.items():
   262              ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
   263          for k, v in tree_map.items():
   264              tree_map_[rmap[k]] = [rmap[x] for x in v]
   265          for k, v in parent_map.items():
   266              if k != 0:
   267                  parent_map_[rmap[k]] = rmap[v]
   268              else:
   269                  parent_map_[rmap[k]] = -1
   270          return tree_map_, parent_map_, ring_map_
   271  
   272      def accept_slaves(self, nslave):
   273          # set of nodes that finishs the job
   274          shutdown = {}
   275          # set of nodes that is waiting for connections
   276          wait_conn = {}
   277          # maps job id to rank
   278          job_map = {}
   279          # list of workers that is pending to be assigned rank
   280          pending = []
   281          # lazy initialize tree_map
   282          tree_map = None
   283  
   284          while len(shutdown) != nslave:
   285              fd, s_addr = self.sock.accept()
   286              s = SlaveEntry(fd, s_addr)
   287              if s.cmd == 'print':
   288                  msg = s.sock.recvstr()
   289                  logging.info(msg.strip())
   290                  continue
   291              if s.cmd == 'shutdown':
   292                  assert s.rank >= 0 and s.rank not in shutdown
   293                  assert s.rank not in wait_conn
   294                  shutdown[s.rank] = s
   295                  logging.debug('Recieve %s signal from %d', s.cmd, s.rank)
   296                  continue
   297              assert s.cmd == 'start' or s.cmd == 'recover'
   298              # lazily initialize the slaves
   299              if tree_map is None:
   300                  assert s.cmd == 'start'
   301                  if s.world_size > 0:
   302                      nslave = s.world_size
   303                  tree_map, parent_map, ring_map = self.get_link_map(nslave)
   304                  # set of nodes that is pending for getting up
   305                  todo_nodes = list(range(nslave))
   306              else:
   307                  assert s.world_size == -1 or s.world_size == nslave
   308              if s.cmd == 'recover':
   309                  assert s.rank >= 0
   310  
   311              rank = s.decide_rank(job_map)
   312              # batch assignment of ranks
   313              if rank == -1:
   314                  assert len(todo_nodes) != 0
   315                  pending.append(s)
   316                  if len(pending) == len(todo_nodes):
   317                      pending.sort(key=lambda x: x.host)
   318                      for s in pending:
   319                          rank = todo_nodes.pop(0)
   320                          if s.jobid != 'NULL':
   321                              job_map[s.jobid] = rank
   322                          s.assign_rank(rank, wait_conn, tree_map, parent_map,
   323                                        ring_map)
   324                          if s.wait_accept > 0:
   325                              wait_conn[rank] = s
   326                          logging.debug('Recieve %s signal from %s; '
   327                                        'assign rank %d', s.cmd, s.host, s.rank)
   328                  if len(todo_nodes) == 0:
   329                      logging.info('@tracker All of %d nodes getting started',
   330                                   nslave)
   331                      self.start_time = time.time()
   332              else:
   333                  s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
   334                  logging.debug('Recieve %s signal from %d', s.cmd, s.rank)
   335                  if s.wait_accept > 0:
   336                      wait_conn[rank] = s
   337  
   338              logging.info("worker(ip_address=%s) connected!" % get_some_ip(s_addr[0]))
   339  
   340          logging.info('@tracker All nodes finishes job')
   341          self.end_time = time.time()
   342          logging.info('@tracker %s secs between node start and job finish',
   343                       str(self.end_time - self.start_time))
   344  
   345      def start(self, nslave):
   346          def run():
   347              self.accept_slaves(nslave)
   348          self.thread = Thread(target=run, args=())
   349          self.thread.setDaemon(True)
   350          self.thread.start()
   351  
   352      def join(self):
   353          while self.thread.isAlive():
   354              self.thread.join(100)
   355  
   356  
   357  class PSTracker(object):
   358      """
   359      Tracker module for PS
   360      """
   361      def __init__(self, hostIP, cmd, port=9091, port_end=9999, envs=None):
   362          """
   363          Starts the PS scheduler
   364          """
   365          self.cmd = cmd
   366          if cmd is None:
   367              return
   368          envs = {} if envs is None else envs
   369          self.hostIP = hostIP
   370          sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
   371          for port in range(port, port_end):
   372              try:
   373                  sock.bind(('', port))
   374                  self.port = port
   375                  sock.close()
   376                  break
   377              except socket.error:
   378                  continue
   379          env = os.environ.copy()
   380  
   381          env['DMLC_ROLE'] = 'scheduler'
   382          env['DMLC_PS_ROOT_URI'] = str(self.hostIP)
   383          env['DMLC_PS_ROOT_PORT'] = str(self.port)
   384          for k, v in envs.items():
   385              env[k] = str(v)
   386          self.thread = Thread(
   387              target=(lambda: subprocess.check_call(self.cmd, env=env,
   388                                                    shell=True)), args=())
   389          self.thread.setDaemon(True)
   390          self.thread.start()
   391  
   392      def join(self):
   393          if self.cmd is not None:
   394              while self.thread.isAlive():
   395                  self.thread.join(100)
   396  
   397      def slave_envs(self):
   398          if self.cmd is None:
   399              return {}
   400          else:
   401              return {'DMLC_PS_ROOT_URI': self.hostIP,
   402                      'DMLC_PS_ROOT_PORT': self.port}
   403  
   404  
   405  def get_host_ip(hostIP=None):
   406      if hostIP is None or hostIP == 'auto':
   407          hostIP = 'ip'
   408  
   409      if hostIP == 'dns':
   410          hostIP = socket.getfqdn()
   411      elif hostIP == 'ip':
   412          from socket import gaierror
   413          try:
   414              hostIP = socket.gethostbyname(socket.getfqdn())
   415          except gaierror:
   416              logging.warn('gethostbyname(socket.getfqdn()) failed... trying on '
   417                           'hostname()')
   418              hostIP = socket.gethostbyname(socket.gethostname())
   419          if hostIP.startswith("127."):
   420              s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
   421              # doesn't have to be reachable
   422              s.connect(('10.255.255.255', 1))
   423              hostIP = s.getsockname()[0]
   424      return hostIP
   425  
   426  
   427  def submit(nworker, nserver, fun_submit, hostIP='auto', pscmd=None):
   428      if nserver == 0:
   429          pscmd = None
   430  
   431      envs = {'DMLC_NUM_WORKER': nworker,
   432              'DMLC_NUM_SERVER': nserver}
   433      hostIP = get_host_ip(hostIP)
   434  
   435      if nserver == 0:
   436          rabit = RabitTracker(hostIP=hostIP, nslave=nworker)
   437          envs.update(rabit.slave_envs())
   438          rabit.start(nworker)
   439      else:
   440          pserver = PSTracker(hostIP=hostIP, cmd=pscmd, envs=envs)
   441          envs.update(pserver.slave_envs())
   442      fun_submit(nworker, nserver, envs)
   443  
   444      if nserver == 0:
   445          rabit.join()
   446      else:
   447          pserver.join()
   448  
   449  
   450  def start_rabit_tracker(args):
   451      """Standalone function to start rabit tracker.
   452      Parameters
   453      ----------
   454      args: arguments to start the rabit tracker.
   455      """
   456      envs = {'DMLC_NUM_WORKER': args.num_workers,
   457              'DMLC_NUM_SERVER': args.num_servers}
   458      rabit = RabitTracker(hostIP=get_host_ip(args.host_ip),
   459                           nslave=args.num_workers)
   460      envs.update(rabit.slave_envs())
   461      rabit.start(args.num_workers)
   462      sys.stdout.write('DMLC_TRACKER_ENV_START\n')
   463      # simply write configuration to stdout
   464      for k, v in envs.items():
   465          sys.stdout.write('%s=%s\n' % (k, str(v)))
   466      sys.stdout.write('DMLC_TRACKER_ENV_END\n')
   467      sys.stdout.flush()
   468      rabit.join()
   469  
   470  
   471  def main():
   472      """Main function if tracker is executed in standalone mode."""
   473      parser = argparse.ArgumentParser(description='Rabit Tracker start.')
   474      parser.add_argument('--num-workers', required=True, type=int,
   475                          help='Number of worker proccess to be launched.')
   476      parser.add_argument('--num-servers', default=0, type=int,
   477                          help='Number of server process to be launched. Only '
   478                               'used in PS jobs.')
   479      parser.add_argument('--host-ip', default=None, type=str,
   480                          help=('Host IP addressed, this is only needed ' +
   481                                'if the host IP cannot be automatically guessed.'
   482                                ))
   483      parser.add_argument('--log-level', default='INFO', type=str,
   484                          choices=['INFO', 'DEBUG'],
   485                          help='Logging level of the logger.')
   486      args = parser.parse_args()
   487  
   488      fmt = '%(asctime)s %(levelname)s %(message)s'
   489      if args.log_level == 'INFO':
   490          level = logging.INFO
   491      elif args.log_level == 'DEBUG':
   492          level = logging.DEBUG
   493      else:
   494          raise RuntimeError("Unknown logging level %s" % args.log_level)
   495  
   496      logging.basicConfig(format=fmt, level=level)
   497  
   498      if args.num_servers == 0:
   499          start_rabit_tracker(args)
   500      else:
   501          raise RuntimeError("Do not yet support start ps tracker in standalone "
   502                             "mode.")
   503  
   504  
# Run the standalone tracker only when this file is executed as a script.
if __name__ == "__main__":
    main()