# Source: github.com/kubeflow/training-operator@v1.7.0/examples/xgboost/smoke-dist/tracker.py
"""
Tracker script for DMLC
Implements the tracker control protocol
 - start dmlc jobs
 - start ps scheduler and rabit tracker
 - help nodes to establish links with each other
Tianqi Chen
--------------------------
This was taken from
https://github.com/dmlc/dmlc-core/blob/master/tracker/dmlc_tracker/tracker.py
See LICENSE here
https://github.com/dmlc/dmlc-core/blob/master/LICENSE
Local changes on top of that file: this explanatory comment, docstrings,
and a handful of robustness fixes (end-of-stream detection, byte-accurate
string framing, socket cleanup, non-deprecated threading/logging calls).
"""
# pylint: disable=invalid-name, missing-docstring, too-many-arguments
# pylint: disable=too-many-locals
# pylint: disable=too-many-branches, too-many-statements
from __future__ import absolute_import

import os
import sys
import errno
import socket
import struct
import subprocess
import argparse
import time
import logging
from threading import Thread


class ExSocket(object):
    """
    Extension of socket to handle recv and send of special data

    Integers are packed with the struct format '@i' and read back as
    4 bytes; strings are sent as such an integer holding the byte length,
    followed by the encoded bytes.
    """
    def __init__(self, sock):
        self.sock = sock

    def recvall(self, nbytes):
        """Read exactly ``nbytes`` bytes from the socket and return them.

        Raises
        ------
        socket.error
            If the peer closes the connection before ``nbytes`` bytes
            have arrived.
        """
        res = []
        nread = 0
        while nread < nbytes:
            chunk = self.sock.recv(min(nbytes - nread, 1024))
            if not chunk:
                # recv() returns b'' once the peer has closed; without this
                # check nread never advances and the loop spins forever.
                raise socket.error(
                    'connection closed by peer after %d of %d bytes'
                    % (nread, nbytes))
            nread += len(chunk)
            res.append(chunk)
        return b''.join(res)

    def recvint(self):
        """Receive one integer."""
        return struct.unpack('@i', self.recvall(4))[0]

    def sendint(self, n):
        """Send one integer."""
        self.sock.sendall(struct.pack('@i', n))

    def sendstr(self, s):
        """Send a length-prefixed string."""
        data = s.encode()
        # The prefix must count encoded bytes, not characters: recvstr()
        # reads exactly that many bytes before decoding.
        self.sendint(len(data))
        self.sock.sendall(data)

    def recvstr(self):
        """Receive a length-prefixed string."""
        slen = self.recvint()
        return self.recvall(slen).decode()


# magic number used to verify existence of data
kMagic = 0xff99

# errno values treated as "address already in use" while probing for a free
# port. 98 and 48 are the literals the original code checked (presumably the
# Linux and BSD/macOS values); errno.EADDRINUSE covers the running platform.
_ADDR_IN_USE = frozenset([errno.EADDRINUSE, 98, 48])


def get_some_ip(host):
    """Return the address string of the first getaddrinfo() result."""
    return socket.getaddrinfo(host, None)[0][4][0]


def get_family(addr):
    """Return the address family of the first getaddrinfo() result."""
    return socket.getaddrinfo(addr, None)[0][0]


class SlaveEntry(object):
    """
    One accepted worker connection.

    The constructor runs the handshake: it reads and checks the magic
    number, echoes it back, then reads the rank, world size, job id and
    command sent by the worker.
    """
    def __init__(self, sock, s_addr):
        slave = ExSocket(sock)
        self.sock = slave
        self.host = get_some_ip(s_addr[0])
        magic = slave.recvint()
        assert magic == kMagic, 'invalid magic number=%d from %s' % (
            magic, self.host)
        slave.sendint(kMagic)
        self.rank = slave.recvint()
        self.world_size = slave.recvint()
        self.jobid = slave.recvstr()
        self.cmd = slave.recvstr()
        # number of neighbours this worker still has to accept; set by
        # assign_rank
        self.wait_accept = 0
        # port reported by the worker at the end of assign_rank
        self.port = None

    def decide_rank(self, job_map):
        """Return the rank this worker should get, or -1 if undecided.

        A non-negative rank sent by the worker wins; otherwise the rank
        recorded for its job id in ``job_map`` is used when there is one.
        """
        if self.rank >= 0:
            return self.rank
        if self.jobid != 'NULL' and self.jobid in job_map:
            return job_map[self.jobid]
        return -1

    def assign_rank(self, rank, wait_conn, tree_map, parent_map, ring_map):
        """Send ``rank`` and its link layout to the worker and broker links.

        Parameters
        ----------
        rank : int
            Rank assigned to this worker.
        wait_conn : dict
            rank -> SlaveEntry of workers still waiting for incoming
            connections; updated in place.
        tree_map, parent_map, ring_map : dict
            Link maps as returned by ``RabitTracker.get_link_map``.

        Returns
        -------
        list
            Ranks removed from ``wait_conn`` because their ``wait_accept``
            counter reached zero.
        """
        self.rank = rank
        nnset = set(tree_map[rank])
        rprev, rnext = ring_map[rank]
        self.sock.sendint(rank)
        # send parent rank
        self.sock.sendint(parent_map[rank])
        # send world size
        self.sock.sendint(len(tree_map))
        self.sock.sendint(len(nnset))
        # send the tree neighbours
        for r in nnset:
            self.sock.sendint(r)
        # send prev link
        if rprev != -1 and rprev != rank:
            nnset.add(rprev)
            self.sock.sendint(rprev)
        else:
            self.sock.sendint(-1)
        # send next link
        if rnext != -1 and rnext != rank:
            nnset.add(rnext)
            self.sock.sendint(rnext)
        else:
            self.sock.sendint(-1)
        while True:
            # ranks reported back by the worker; must be a subset of its
            # neighbours (presumably the links it already holds)
            ngood = self.sock.recvint()
            goodset = set([])
            for _ in range(ngood):
                goodset.add(self.sock.recvint())
            assert goodset.issubset(nnset)
            badset = nnset - goodset
            # missing neighbours that are already registered in wait_conn
            conset = []
            for r in badset:
                if r in wait_conn:
                    conset.append(r)
            self.sock.sendint(len(conset))
            self.sock.sendint(len(badset) - len(conset))
            for r in conset:
                self.sock.sendstr(wait_conn[r].host)
                self.sock.sendint(wait_conn[r].port)
                self.sock.sendint(r)
            nerr = self.sock.recvint()
            if nerr != 0:
                # worker reported errors; run another round
                continue
            self.port = self.sock.recvint()
            rmset = []
            # all connections were successfully set up
            for r in conset:
                wait_conn[r].wait_accept -= 1
                if wait_conn[r].wait_accept == 0:
                    rmset.append(r)
            for r in rmset:
                wait_conn.pop(r, None)
            self.wait_accept = len(badset) - len(conset)
            return rmset


class RabitTracker(object):
    """
    tracker for rabit
    """
    def __init__(self, hostIP, nslave, port=9091, port_end=9999):
        """Bind a listening socket on the first free port in the range.

        Parameters
        ----------
        hostIP : str
            Address to bind to.
        nslave : int
            Number of workers.
        port, port_end : int
            Ports ``port`` .. ``port_end - 1`` are tried in order.

        Raises
        ------
        socket.error
            If binding fails for a reason other than "address in use", or
            if every port in the range is in use.
        """
        # Pick the family from the address (as PSTracker does) instead of
        # hard-coding IPv4.
        sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
        try:
            for candidate in range(port, port_end):
                try:
                    sock.bind((hostIP, candidate))
                except socket.error as e:
                    if e.errno in _ADDR_IN_USE:
                        continue
                    raise
                self.port = candidate
                break
            else:
                raise socket.error(
                    errno.EADDRINUSE,
                    'no free port for the tracker in range [%d, %d)'
                    % (port, port_end))
            sock.listen(256)
        except Exception:
            # do not leak the descriptor when construction fails
            sock.close()
            raise
        self.sock = sock
        self.hostIP = hostIP
        self.thread = None
        self.start_time = None
        self.end_time = None
        self.nslave = nslave
        logging.info('start listen on %s:%d', hostIP, self.port)

    def __del__(self):
        # self.sock is missing when __init__ raised before assigning it
        sock = getattr(self, 'sock', None)
        if sock is not None:
            sock.close()

    @staticmethod
    def get_neighbor(rank, nslave):
        """Return the tree neighbours (parent first, then children) of rank.

        The tree is a binary heap over ranks 0..nslave-1, computed with
        1-based indices internally.
        """
        rank = rank + 1
        ret = []
        if rank > 1:
            ret.append(rank // 2 - 1)
        if rank * 2 - 1 < nslave:
            ret.append(rank * 2 - 1)
        if rank * 2 < nslave:
            ret.append(rank * 2)
        return ret

    def slave_envs(self):
        """
        get environment variables for slaves
        can be passed in as args or envs
        """
        return {'DMLC_TRACKER_URI': self.hostIP,
                'DMLC_TRACKER_PORT': self.port}

    def get_tree(self, nslave):
        """Return (tree_map, parent_map) for ``nslave`` ranks.

        tree_map maps a rank to its neighbour list, parent_map maps a rank
        to its parent rank (-1 for rank 0).
        """
        tree_map = {}
        parent_map = {}
        for r in range(nslave):
            tree_map[r] = self.get_neighbor(r, nslave)
            parent_map[r] = (r + 1) // 2 - 1
        return tree_map, parent_map

    def find_share_ring(self, tree_map, parent_map, r):
        """
        get a ring structure that tends to share nodes with the tree
        return a list starting from r
        """
        nset = set(tree_map[r])
        cset = nset - set([parent_map[r]])
        if len(cset) == 0:
            return [r]
        rlst = [r]
        cnt = 0
        for v in cset:
            vlst = self.find_share_ring(tree_map, parent_map, v)
            cnt += 1
            # the last child's sub-list is appended in reverse order
            if cnt == len(cset):
                vlst.reverse()
            rlst += vlst
        return rlst

    def get_ring(self, tree_map, parent_map):
        """
        get a ring connection used to recover local data

        Returns a dict mapping each rank to its (previous, next) ranks.
        """
        assert parent_map[0] == -1
        rlst = self.find_share_ring(tree_map, parent_map, 0)
        assert len(rlst) == len(tree_map)
        ring_map = {}
        nslave = len(tree_map)
        for r in range(nslave):
            rprev = (r + nslave - 1) % nslave
            rnext = (r + 1) % nslave
            ring_map[rlst[r]] = (rlst[rprev], rlst[rnext])
        return ring_map

    def get_link_map(self, nslave):
        """
        get the link map, this is a bit hacky, call for better algorithm
        to place similar nodes together

        Returns (tree_map, parent_map, ring_map) with ranks renumbered in
        ring order, so that following the ring from rank 0 visits
        0, 1, 2, ...
        """
        tree_map, parent_map = self.get_tree(nslave)
        ring_map = self.get_ring(tree_map, parent_map)
        # rmap: old rank -> new rank, numbered by walking the ring forward
        rmap = {0: 0}
        k = 0
        for i in range(nslave - 1):
            k = ring_map[k][1]
            rmap[k] = i + 1

        ring_map_ = {}
        tree_map_ = {}
        parent_map_ = {}
        for k, v in ring_map.items():
            ring_map_[rmap[k]] = (rmap[v[0]], rmap[v[1]])
        for k, v in tree_map.items():
            tree_map_[rmap[k]] = [rmap[x] for x in v]
        for k, v in parent_map.items():
            if k != 0:
                parent_map_[rmap[k]] = rmap[v]
            else:
                parent_map_[rmap[k]] = -1
        return tree_map_, parent_map_, ring_map_

    def accept_slaves(self, nslave):
        """Serve worker connections until ``nslave`` workers have shut down.

        Each accepted connection carries one command: 'print' (log a
        message), 'shutdown' (record the worker as finished), or
        'start' / 'recover' (assign a rank and broker its links). If the
        first 'start' carries a positive world size, it replaces
        ``nslave``.
        """
        # set of nodes that finished the job
        shutdown = {}
        # set of nodes that are waiting for connections
        wait_conn = {}
        # maps job id to rank
        job_map = {}
        # list of workers that are pending to be assigned a rank
        pending = []
        # lazy initialize tree_map
        tree_map = None

        while len(shutdown) != nslave:
            fd, s_addr = self.sock.accept()
            try:
                s = SlaveEntry(fd, s_addr)
            except socket.error as err:
                # A peer that connects and goes away during the handshake
                # must not stop the tracker for everybody else.
                logging.warning('dropping connection from %s: %s',
                                s_addr[0], err)
                fd.close()
                continue
            if s.cmd == 'print':
                msg = s.sock.recvstr()
                logging.info(msg.strip())
                continue
            if s.cmd == 'shutdown':
                assert s.rank >= 0 and s.rank not in shutdown
                assert s.rank not in wait_conn
                shutdown[s.rank] = s
                logging.debug('Recieve %s signal from %d', s.cmd, s.rank)
                continue
            assert s.cmd == 'start' or s.cmd == 'recover'
            # lazily initialize the slaves
            if tree_map is None:
                assert s.cmd == 'start'
                if s.world_size > 0:
                    nslave = s.world_size
                tree_map, parent_map, ring_map = self.get_link_map(nslave)
                # set of nodes that are pending for getting up
                todo_nodes = list(range(nslave))
            else:
                assert s.world_size == -1 or s.world_size == nslave
            if s.cmd == 'recover':
                assert s.rank >= 0

            rank = s.decide_rank(job_map)
            # batch assignment of ranks
            if rank == -1:
                assert len(todo_nodes) != 0
                pending.append(s)
                if len(pending) == len(todo_nodes):
                    # every remaining worker has arrived: hand out ranks in
                    # host order
                    pending.sort(key=lambda x: x.host)
                    for s in pending:
                        rank = todo_nodes.pop(0)
                        if s.jobid != 'NULL':
                            job_map[s.jobid] = rank
                        s.assign_rank(rank, wait_conn, tree_map, parent_map,
                                      ring_map)
                        if s.wait_accept > 0:
                            wait_conn[rank] = s
                        logging.debug('Recieve %s signal from %s; '
                                      'assign rank %d', s.cmd, s.host, s.rank)
                if len(todo_nodes) == 0:
                    logging.info('@tracker All of %d nodes getting started',
                                 nslave)
                    self.start_time = time.time()
            else:
                s.assign_rank(rank, wait_conn, tree_map, parent_map, ring_map)
                logging.debug('Recieve %s signal from %d', s.cmd, s.rank)
                if s.wait_accept > 0:
                    wait_conn[rank] = s

            logging.info("worker(ip_address=%s) connected!",
                         get_some_ip(s_addr[0]))

        logging.info('@tracker All nodes finishes job')
        self.end_time = time.time()
        # start_time is only set by the batch-assignment path; it is still
        # None when every worker supplied its own rank.
        if self.start_time is not None:
            logging.info('@tracker %s secs between node start and job finish',
                         str(self.end_time - self.start_time))

    def start(self, nslave):
        """Run accept_slaves(nslave) in a daemon thread."""
        def run():
            self.accept_slaves(nslave)
        self.thread = Thread(target=run, args=())
        self.thread.daemon = True
        self.thread.start()

    def join(self):
        """Block until the tracker thread has finished."""
        # join with a timeout in a loop rather than one unbounded join
        while self.thread.is_alive():
            self.thread.join(100)


class PSTracker(object):
    """
    Tracker module for PS
    """
    def __init__(self, hostIP, cmd, port=9091, port_end=9999, envs=None):
        """
        Starts the PS scheduler

        Parameters
        ----------
        hostIP : str
            Address exported to the scheduler as DMLC_PS_ROOT_URI.
        cmd : str or None
            Shell command that runs the scheduler; None disables the
            tracker (nothing is started).
        port, port_end : int
            Ports ``port`` .. ``port_end - 1`` are probed in order.
        envs : dict, optional
            Extra environment variables for the scheduler process.

        Raises
        ------
        socket.error
            If no port in the range could be bound.
        """
        self.cmd = cmd
        if cmd is None:
            return
        envs = {} if envs is None else envs
        self.hostIP = hostIP
        # The socket is only used to find a bindable port; it is closed
        # again before the scheduler is launched.
        sock = socket.socket(get_family(hostIP), socket.SOCK_STREAM)
        try:
            for candidate in range(port, port_end):
                try:
                    sock.bind(('', candidate))
                except socket.error:
                    continue
                self.port = candidate
                break
            else:
                raise socket.error(
                    errno.EADDRINUSE,
                    'no free port for the PS scheduler in range [%d, %d)'
                    % (port, port_end))
        finally:
            sock.close()
        env = os.environ.copy()

        env['DMLC_ROLE'] = 'scheduler'
        env['DMLC_PS_ROOT_URI'] = str(self.hostIP)
        env['DMLC_PS_ROOT_PORT'] = str(self.port)
        for k, v in envs.items():
            env[k] = str(v)

        def run_scheduler():
            # NOTE(review): cmd is handed to the shell verbatim; callers
            # must not build it from untrusted input.
            subprocess.check_call(self.cmd, env=env, shell=True)

        self.thread = Thread(target=run_scheduler, args=())
        self.thread.daemon = True
        self.thread.start()

    def join(self):
        """Block until the scheduler thread has finished (no-op if none)."""
        if self.cmd is not None:
            while self.thread.is_alive():
                self.thread.join(100)

    def slave_envs(self):
        """Return the environment variables workers need to find the PS."""
        if self.cmd is None:
            return {}
        else:
            return {'DMLC_PS_ROOT_URI': self.hostIP,
                    'DMLC_PS_ROOT_PORT': self.port}


def get_host_ip(hostIP=None):
    """Resolve the address the tracker should advertise.

    ``None`` and 'auto' behave like 'ip'. 'dns' returns the fully
    qualified host name. 'ip' resolves the host name to an address and,
    if that address starts with "127.", uses the local address of a UDP
    socket connected towards 10.255.255.255 instead. Any other value is
    returned unchanged.
    """
    if hostIP is None or hostIP == 'auto':
        hostIP = 'ip'

    if hostIP == 'dns':
        hostIP = socket.getfqdn()
    elif hostIP == 'ip':
        try:
            hostIP = socket.gethostbyname(socket.getfqdn())
        except socket.gaierror:
            logging.warning('gethostbyname(socket.getfqdn()) failed... '
                            'trying on hostname()')
            hostIP = socket.gethostbyname(socket.gethostname())
        if hostIP.startswith("127."):
            s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
            try:
                # doesn't have to be reachable
                s.connect(('10.255.255.255', 1))
                hostIP = s.getsockname()[0]
            finally:
                s.close()
    return hostIP


def submit(nworker, nserver, fun_submit, hostIP='auto', pscmd=None):
    """Start the matching tracker, launch the job and wait for it.

    Parameters
    ----------
    nworker, nserver : int
        Number of workers and servers. With ``nserver == 0`` a
        RabitTracker is used, otherwise a PSTracker running ``pscmd``.
    fun_submit : callable
        Called as ``fun_submit(nworker, nserver, envs)`` to launch the job.
    hostIP : str
        Passed through ``get_host_ip``.
    pscmd : str, optional
        Scheduler command for the PS case.
    """
    if nserver == 0:
        pscmd = None

    envs = {'DMLC_NUM_WORKER': nworker,
            'DMLC_NUM_SERVER': nserver}
    hostIP = get_host_ip(hostIP)

    if nserver == 0:
        rabit = RabitTracker(hostIP=hostIP, nslave=nworker)
        envs.update(rabit.slave_envs())
        rabit.start(nworker)
    else:
        pserver = PSTracker(hostIP=hostIP, cmd=pscmd, envs=envs)
        envs.update(pserver.slave_envs())
    fun_submit(nworker, nserver, envs)

    if nserver == 0:
        rabit.join()
    else:
        pserver.join()


def start_rabit_tracker(args):
    """Standalone function to start rabit tracker.

    The tracker environment is written to stdout between the
    DMLC_TRACKER_ENV_START and DMLC_TRACKER_ENV_END marker lines, one
    KEY=VALUE per line; the call then blocks until the tracker finishes.

    Parameters
    ----------
    args: arguments to start the rabit tracker.
    """
    envs = {'DMLC_NUM_WORKER': args.num_workers,
            'DMLC_NUM_SERVER': args.num_servers}
    rabit = RabitTracker(hostIP=get_host_ip(args.host_ip),
                         nslave=args.num_workers)
    envs.update(rabit.slave_envs())
    rabit.start(args.num_workers)
    sys.stdout.write('DMLC_TRACKER_ENV_START\n')
    # simply write configuration to stdout
    for k, v in envs.items():
        sys.stdout.write('%s=%s\n' % (k, str(v)))
    sys.stdout.write('DMLC_TRACKER_ENV_END\n')
    sys.stdout.flush()
    rabit.join()


def main():
    """Main function if tracker is executed in standalone mode."""
    parser = argparse.ArgumentParser(description='Rabit Tracker start.')
    parser.add_argument('--num-workers', required=True, type=int,
                        help='Number of worker proccess to be launched.')
    parser.add_argument('--num-servers', default=0, type=int,
                        help='Number of server process to be launched. Only '
                        'used in PS jobs.')
    parser.add_argument('--host-ip', default=None, type=str,
                        help=('Host IP addressed, this is only needed ' +
                              'if the host IP cannot be automatically guessed.'
                              ))
    parser.add_argument('--log-level', default='INFO', type=str,
                        choices=['INFO', 'DEBUG'],
                        help='Logging level of the logger.')
    args = parser.parse_args()

    fmt = '%(asctime)s %(levelname)s %(message)s'
    if args.log_level == 'INFO':
        level = logging.INFO
    elif args.log_level == 'DEBUG':
        level = logging.DEBUG
    else:
        raise RuntimeError("Unknown logging level %s" % args.log_level)

    logging.basicConfig(format=fmt, level=level)

    if args.num_servers == 0:
        start_rabit_tracker(args)
    else:
        raise RuntimeError("Do not yet support start ps tracker in standalone "
                           "mode.")


if __name__ == "__main__":
    main()