lab.nexedi.com/kirr/go123@v0.0.0-20240207185015-8299741fa871/xnet/lonet/__init__.py (about) 1 # -*- coding: utf-8 -*- 2 # Copyright (C) 2018-2020 Nexedi SA and Contributors. 3 # Kirill Smelkov <kirr@nexedi.com> 4 # 5 # This program is free software: you can Use, Study, Modify and Redistribute 6 # it under the terms of the GNU General Public License version 3, or (at your 7 # option) any later version, as published by the Free Software Foundation. 8 # 9 # You can also Link and Combine this program with other software covered by 10 # the terms of any of the Free Software licenses or any of the Open Source 11 # Initiative approved licenses and Convey the resulting work. Corresponding 12 # source of such a combination shall include the source code for all other 13 # software used. 14 # 15 # This program is distributed WITHOUT ANY WARRANTY; without even the implied 16 # warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 17 # 18 # See COPYING file for full licensing terms. 19 # See https://www.nexedi.com/licensing for rationale and options. 20 """Package lonet provides TCP network simulated on top of localhost TCP loopback. 21 22 See lonet.go for lonet description, organization and protocol. 23 """ 24 25 # NOTE this package is deliberately concise and follows lonet.go structure, 26 # which is more well documented. 27 28 import sys, os, stat, errno, tempfile, re 29 import socket as net 30 31 import sqlite3 32 import functools 33 import threading 34 import logging as log 35 36 from golang import func, defer, go, chan, select, default, panic, gimport 37 from golang import sync 38 from golang.gcompat import qq 39 40 xerr = gimport('lab.nexedi.com/kirr/go123/xerr') 41 Error = xerr.Error 42 errctx = xerr.context 43 errcause= xerr.cause 44 45 46 # set_once sets threading.Event, but only once. 47 # 48 # it returns whether event was set. 49 # 50 # if set_once(down_once): 51 # ... 52 # 53 # is analog of 54 # 55 # downOnce.Do(...) 56 # 57 # in Go. 
#
# TODO just use sync.Once from pygolang.
_oncemu = threading.Lock()
def set_once(event):
    with _oncemu:
        if event.is_set():
            return False
        event.set()
        return True



# -------- virtnet --------
#
# See ../virtnet/virtnet.go for details.

# neterror creates net.error and registers it as WDE to xerr.
def neterror(*argv):
    err = net.error(*argv)
    xerr.register_wde_object(err)
    return err

ErrNetDown          = neterror(errno.EBADFD,        "network is down")
ErrHostDown         = neterror(errno.EBADFD,        "host is down")
ErrSockDown         = neterror(errno.EBADFD,        "socket is down")
ErrAddrAlreadyUsed  = neterror(errno.EADDRINUSE,    "address already in use")
ErrAddrNoListen     = neterror(errno.EADDRNOTAVAIL, "cannot listen on requested address")
ErrConnRefused      = neterror(errno.ECONNREFUSED,  "connection refused")

ErrNoHost           = neterror("no such host")
ErrHostDup          = neterror("host already registered")


# addrstr4 formats host:port as if for TCP4 network.
def addrstr4(host, port):
    return "%s:%d" % (host, port)

# Addr represents address of a virtnet endpoint.
class Addr(object):
    # .net      str
    # .host     str
    # .port     int

    def __init__(self, net, host, port):
        self.net, self.host, self.port = net, host, port

    # netaddr returns address as net.AF_INET (host, port) pair.
    def netaddr(self):
        return (self.host, self.port)

    def __str__(self):
        return addrstr4(*self.netaddr())

    # __eq__ compares addresses by value: network, host and port must all match.
    def __eq__(a, b):
        return isinstance(b, Addr) and a.net == b.net and a.host == b.host and a.port == b.port

    # __ne__ is the negation of __eq__.
    #
    # python2 does not derive != from __eq__: without this, `a != b` would
    # compare object identities and could disagree with `a == b`.
    def __ne__(a, b):
        return not (a == b)


# VirtSubNetwork represents one subnetwork of a virtnet network.
class VirtSubNetwork(object):
    # Fields:
    #
    #   ._network       str
    #   ._registry      Registry
    #   ._hostmu        μ
    #   ._hostmap       {}  hostname -> Host
    #   ._nopenhosts    int
    #   ._autoclose     bool
    #   ._down          chan ø
    #   ._down_once     threading.Event

    def __init__(self, network, registry):
        self._network  = network
        self._registry = registry

        self._hostmu     = threading.Lock()
        self._hostmap    = {}
        self._nopenhosts = 0
        self._autoclose  = False

        self._down      = chan()
        self._down_once = threading.Event()

    # Engine hooks: every particular virtnet implementation has to override
    # all of the methods below.

    def _vnet_newhost(self, hostname, registry):
        raise NotImplementedError()

    def _vnet_dial(self, src, dst, dstosladdr):
        raise NotImplementedError()

    def _vnet_close(self):
        raise NotImplementedError()


# Host represents named access point on a virtnet network.
class Host(object):
    # Fields:
    #
    #   ._subnet        VirtSubNetwork
    #   ._name          str
    #   ._sockmu        μ
    #   ._socketv       []socket ; port -> listener | conn ; [0] is always None
    #   ._down          chan ø
    #   ._down_once     threading.Event
    #   ._close_once    sync.Once

    def __init__(self, subnet, name):
        self._subnet = subnet
        self._name   = name

        self._sockmu  = threading.Lock()
        self._socketv = []

        self._down       = chan()
        self._down_once  = threading.Event()
        self._close_once = sync.Once()


# socket represents one endpoint entry on Host.
class socket(object):
    # Fields:
    #
    #   ._host          Host
    #   ._port          int
    #
    #   ._conn          conn | None
    #   ._listener      listener | None

    def __init__(self, host, port):
        self._host = host
        self._port = port
        # both start empty; they are attached later by listen / accept / dial
        self._conn     = None
        self._listener = None


# conn represents one endpoint of a virtnet connection.
class conn(object):
    # ._socket      socket
    # ._peerAddr    Addr
    # ._netsk       net.socket  (embedded)
    # ._down        chan()
    # ._down_once   threading.Event
    # ._close_once  threading.Event

    def __init__(self, sk, peerAddr, netsk):
        self._socket, self._peerAddr, self._netsk = sk, peerAddr, netsk
        self._down       = chan()
        self._down_once  = threading.Event()
        self._close_once = threading.Event()

    # ._netsk embedded: attributes not found on conn are looked up on ._netsk.
    #
    # __getattr__ is consulted only when normal lookup fails. ._netsk is read
    # straight from __dict__, so an instance whose __init__ has not run (e.g. one
    # created via copy/pickle or __new__) raises AttributeError. Going through
    # self._netsk would instead re-enter __getattr__ and recurse forever.
    def __getattr__(self, name):
        try:
            netsk = self.__dict__['_netsk']
        except KeyError:
            raise AttributeError(name)
        return getattr(netsk, name)


# listener implements net.Listener for Host.
class listener(object):
    # ._socket      socket
    # ._dialq       chan dialReq
    # ._down        chan ø
    # ._down_once   threading.Event
    # ._close_once  threading.Event

    def __init__(self, sk):
        self._socket = sk
        self._dialq  = chan()
        self._down   = chan()
        self._down_once  = threading.Event()
        self._close_once = threading.Event()


# dialReq represents one dial request to listener from acceptor.
class dialReq(object):
    # ._from        Addr
    # ._netsk       net.socket
    # ._resp        chan Accept

    def __init__(self, from_, netsk, resp):
        self._from, self._netsk, self._resp = from_, netsk, resp


# Accept represents successful acceptance decision from VirtSubNetwork._vnet_accept .
class Accept(object):
    # .addr         Addr
    # .ack          chan error
    def __init__(self, addr, ack):
        self.addr, self.ack = addr, ack


# ----------------------------------------

# _shutdown is worker for close and _vnet_down.
@func(VirtSubNetwork)
def _shutdown(n, exc):
    n.__shutdown(exc, True)

# __shutdown marks the subnetwork down (only once) and releases its resources.
#
# withHosts=True additionally shuts down every registered host. exc, if not
# None, is the engine error that caused the shutdown; it is only logged.
@func(VirtSubNetwork)
def __shutdown(n, exc, withHosts):
    if not set_once(n._down_once):
        return

    n._down.close()

    if withHosts:
        with n._hostmu:
            for host in n._hostmap.values():
                host._shutdown()

    # XXX py: we don't collect / remember .downErr
    if exc is not None:
        log.error(exc)
    try:
        n._vnet_close()
    finally:
        # release the registry even if the engine-specific close fails
        n._registry.close()


# close shutdowns subnetwork.
@func(VirtSubNetwork)
def close(n):
    n.__close(True)

# _closeWithoutHosts shutdowns subnetwork without touching its hosts.
@func(VirtSubNetwork)
def _closeWithoutHosts(n):
    n.__close(False)

@func(VirtSubNetwork)
def __close(n, withHosts):
    with errctx("virtnet %s: close" % qq(n._network)):
        n.__shutdown(None, withHosts)

# _vnet_down shutdowns subnetwork upon engine error.
@func(VirtSubNetwork)
def _vnet_down(n, exc):
    # XXX py: errctx here (go does not have) because we do not reraise .downErr in close
    with errctx("virtnet %s: shutdown" % qq(n._network)):
        n._shutdown(exc)


# new_host creates new Host with given name.
#
# The name is first announced via the engine (_vnet_newhost); only then is the
# Host object created and registered in ._hostmap.
@func(VirtSubNetwork)
def new_host(n, name):
    with errctx("virtnet %s: new host %s" % (qq(n._network), qq(name))):
        n._vnet_newhost(name, n._registry)
        # XXX check err due to subnet down

        with n._hostmu:
            if name in n._hostmap:
                # the format string must consume both arguments - otherwise
                # building the message raises TypeError instead of panicking.
                panic("virtnet %s: new host %s: announced ok but .hostMap already !empty" % (qq(n._network), qq(name)))

            host = Host(n, name)
            n._hostmap[name] = host
            n._nopenhosts += 1
            return host


# host returns host on the subnetwork by name.
@func(VirtSubNetwork)
def host(n, name):
    with n._hostmu:
        return n._hostmap.get(name)


# _shutdown is underlying worker for close.
@func(Host)
def _shutdown(h):
    if not set_once(h._down_once):
        return

    h._down.close()

    with h._sockmu:
        for sk in h._socketv:
            if sk is None:
                continue
            if sk._conn is not None:
                sk._conn._shutdown()
            if sk._listener is not None:
                sk._listener._shutdown()

# close shutdowns host.
#
# After the first close the subnetwork's open-host counter is decremented (once),
# and if autoclose was requested and this was the last open host, the
# subnetwork is closed too.
@func(Host)
def close(h):
    def autoclose():
        def _():
            n = h._subnet
            with n._hostmu:
                # the counter is VirtSubNetwork._nopenhosts (as set up in its
                # __init__ and incremented by new_host).
                n._nopenhosts -= 1
                if n._nopenhosts < 0:
                    panic("SubNetwork._nopenhosts < 0")
                if n._autoclose and n._nopenhosts == 0:
                    n._closeWithoutHosts()
        h._close_once.do(_)
    defer(autoclose)

    with errctx("virtnet %s: host %s: close" % (qq(h._subnet._network), qq(h._name))):
        h._shutdown()

# autoclose schedules close to be called after last host on this subnetwork is closed.
@func(VirtSubNetwork)
def autoclose(n):
    with n._hostmu:
        if n._nopenhosts == 0:
            panic("BUG: no opened hosts")
        n._autoclose = True


# listen starts new listener on the host.
#
# laddr "" means ":0" - listen on this host on a free port.
@func(Host)
def listen(h, laddr):
    if laddr == "":
        laddr = ":0"

    with errctx("listen %s %s" % (h.network(), laddr)):
        a = h._parseAddr(laddr)

        if a.host != h._name:
            raise ErrAddrNoListen

        # ports index ._socketv: a negative port would alias an existing slot
        # through python negative indexing.
        if a.port < 0:
            raise ErrAddrNoListen

        if ready(h._down):
            h._excDown()

        with h._sockmu:
            if a.port == 0:
                sk = h._allocFreeSocket()

            else:
                while a.port >= len(h._socketv):
                    h._socketv.append(None)

                if h._socketv[a.port] is not None:
                    raise ErrAddrAlreadyUsed

                sk = socket(h, a.port)
                h._socketv[a.port] = sk

            l = listener(sk)
            sk._listener = l

        return l


# _shutdown shutdowns the listener.
@func(listener)
def _shutdown(l):
    if set_once(l._down_once):
        l._down.close()

# close closes the listener.
@func(listener)
def close(l):
    l._shutdown()
    if not set_once(l._close_once):
        return

    sk = l._socket
    h = sk._host

    with h._sockmu:
        sk._listener = None
        if sk._empty():
            # socket stores its port as ._port
            h._socketv[sk._port] = None


# accept tries to connect to dial called with addr corresponding to our listener.
#
# It waits for a dial request, allocates a local socket for the accepted end,
# offers its address to the dialer and waits for the dialer's ack. Raises the
# appropriate Err*Down if the listener is shut down meanwhile.
@func(listener)
def accept(l):
    h = l._socket._host

    with errctx("accept %s %s" % (h.network(), l.addr())):
        while 1:
            _, _rx = select(
                l._down.recv,   # 0
                l._dialq.recv,  # 1
            )
            if _ == 0:
                l._excDown()
            if _ == 1:
                req = _rx

            with h._sockmu:
                sk = h._allocFreeSocket()

            ack = chan()
            req._resp.send(Accept(sk.addr(), ack))

            _, _rx = select(
                l._down.recv,   # 0
                ack.recv,       # 1
            )
            if _ == 0:
                # we are going down, but the acceptor side may still finish the
                # handshake: wait for its ack in background and clean up then.
                def purgesk():
                    err = ack.recv()
                    if err is None:
                        try:
                            req._netsk.close()
                        except Exception:
                            pass    # best effort: we are shutting down anyway
                    with h._sockmu:
                        h._socketv[sk._port] = None

                go(purgesk)
                l._excDown()

            if _ == 1:
                err = _rx

            if err is not None:
                with h._sockmu:
                    h._socketv[sk._port] = None
                continue

            c = conn(sk, req._from, req._netsk)
            with h._sockmu:
                # register as ._conn so that Host._shutdown / socket._empty see it
                sk._conn = c

            return c


# _vnet_accept accepts or rejects incoming connection.
@func(VirtSubNetwork)
def _vnet_accept(n, src, dst, netconn):
    with n._hostmu:
        host = n._hostmap.get(dst.host)
    if host is None:
        raise net.gaierror('%s: no such host' % dst.host)

    # `with` guarantees ._sockmu is released on every exit path, including
    # unexpected exceptions.
    with host._sockmu:
        # a negative port would alias another slot via python negative indexing
        if dst.port < 0 or dst.port >= len(host._socketv):
            raise ErrConnRefused

        sk = host._socketv[dst.port]
        if sk is None or sk._listener is None:
            raise ErrConnRefused

        # there is listener corresponding to dst - let's connect it
        l = sk._listener

    resp = chan()
    req  = dialReq(src, netconn, resp)

    _, _rx = select(
        l._down.recv,           # 0
        (l._dialq.send, req),   # 1
    )
    if _ == 0:
        raise ErrConnRefused
    if _ == 1:
        return resp.recv()


# dial dials address on the network.
#
# Any error is reraised wrapped into Error with "dial <network> <src>-><dst>" context.
@func(Host)
def dial(h, addr):
    # allocate socket early, so that we can see in the error who tries to dial
    with h._sockmu:
        sk = h._allocFreeSocket()

    # XXX py: default dst to addr to be able to render it in error if it happens before parse
    dst = addr

    try:
        dst = h._parseAddr(addr)
        n = h._subnet

        # XXX cancel on host shutdown

        dstdata = n._registry.query(dst.host)
        if dstdata is None:
            raise ErrNoHost

        netsk, acceptAddr = n._vnet_dial(sk.addr(), dst, dstdata)

        c = conn(sk, acceptAddr, netsk)
        with h._sockmu:
            sk._conn = c
        return c

    except Exception as err:
        # grab the traceback first, before anything else can replace it
        _, _, tb = sys.exc_info()
        with h._sockmu:
            h._socketv[sk._port] = None

        raise Error("dial %s %s->%s" % (h.network(), sk.addr(), dst), err, tb)


# ---- conn ----

# _shutdown closes underlying network connection.
@func(conn)
def _shutdown(c):
    if not set_once(c._down_once):
        return

    c._down.close()
    # XXX py: we don't remember .errClose
    c._netsk.close()


# close closes network endpoint and unregisters conn from Host.
@func(conn)
def close(c):
    c._shutdown()
    if not set_once(c._close_once):
        # already closed before - nothing more to unregister
        return

    sock = c._socket
    host = sock._host
    with host._sockmu:
        sock._conn = None
        if sock._empty():
            # neither conn nor listener is left - free the port
            host._socketv[sock._port] = None

    # XXX py: we don't reraise c.errClose

# XXX py: don't bother to override recv (Read)
# XXX py: don't bother to override send (Write)

# local_addr returns virtnet address of local end of connection.
@func(conn)
def local_addr(c):
    sock = c._socket
    return sock.addr()

# getsockname returns virtnet address of local end of connection as net.AF_INET (host, port) pair.
@func(conn)
def getsockname(c):
    local = c.local_addr()
    return local.netaddr()

# remote_addr returns virtnet address of remote end of connection.
@func(conn)
def remote_addr(c):
    peer = c._peerAddr
    return peer

# getpeername returns virtnet address of remote end of connection as net.AF_INET (host, port) pair.
@func(conn)
def getpeername(c):
    remote = c.remote_addr()
    return remote.netaddr()

# ----------------------------------------

# _allocFreeSocket finds first free port and allocates socket entry for it.
#
# Slot 0 is never handed out. When there is no hole in ._socketv, the table is
# grown so that the new port is just past its current end.
@func(Host)
def _allocFreeSocket(h):
    sockv = h._socketv
    fallback = max(len(sockv), 1)
    port = next((i for i in range(1, len(sockv)) if sockv[i] is None), fallback)

    # grow the table (in place) until sockv[port] is addressable
    sockv.extend([None] * (port + 1 - len(sockv)))

    sk = socket(h, port)
    sockv[port] = sk
    return sk


# empty checks whether socket's both conn and listener are all nil.
@func(socket)
def _empty(sk):
    return not (sk._conn is not None or sk._listener is not None)

# addr returns address corresponding to socket.
@func(socket)
def addr(sk):
    host = sk._host
    netname = host.network()
    hostname = host.name()
    return Addr(netname, hostname, sk._port)

# Addr.parse parses addr into virtnet address for named network.
@func(Addr)
@staticmethod
def parse(net, addr):
    # NOTE here `net` is the network name; it shadows the socket module alias
    # inside this function only.
    try:
        host, port = addr.split(':')    # ValueError unless exactly one ':'
        port = int(port)
    except Exception:
        raise ValueError('%s is not valid virtnet address' % addr)

    # ports are used as indices into Host._socketv: a negative port would
    # silently alias another slot through python negative indexing.
    if port < 0:
        raise ValueError('%s is not valid virtnet address' % addr)

    return Addr(net, host, port)

# _parseAddr parses addr into virtnet address from host point of view.
#
# An empty host part (e.g. ":0") means this host.
@func(Host)
def _parseAddr(h, addr):
    a = Addr.parse(h.network(), addr)
    if a.host == "":
        a.host = h._name
    return a

# addr returns address where listener is accepting incoming connections.
@func(listener)
def addr(l):
    return l._socket.addr()


# network returns full network name this subnetwork is part of.
@func(VirtSubNetwork)
def network(n):
    return n._network

# network returns full network name of underlying network.
@func(Host)
def network(h):
    return h._subnet.network()

# name returns host name.
@func(Host)
def name(h):
    return h._name

# ----------------------------------------

# _excDown raises appropriate exception cause when h.down is found ready.
@func(Host)
def _excDown(h):
    if ready(h._subnet._down):
        raise ErrNetDown
    else:
        raise ErrHostDown

# _excDown raises appropriate exception cause when l.down is found ready.
@func(listener)
def _excDown(l):
    h = l._socket._host
    n = h._subnet

    if ready(n._down):
        raise ErrNetDown
    elif ready(h._down):
        raise ErrHostDown
    else:
        raise ErrSockDown

# XXX py: conn.errOrDown is not implemented because conn.{Read,Write} are not wrapped.

# ready returns whether channel ch is ready (non-blocking check).
def ready(ch):
    _, _rx = select(
        ch.recv,    # 0
        default,    # 1
    )
    if _ == 0:
        return True
    if _ == 1:
        return False


# -------- lonet networking --------
#
# See lonet.go for details.

# protocolError represents logical error in lonet handshake exchange.
class protocolError(Exception):
    pass

xerr.register_wde_class(protocolError)


# _mkdir_p is `mkdir -p`; https://stackoverflow.com/a/273227
#
# An already existing path is accepted only if it is a directory. Anything
# else in the way (e.g. a regular file) is reported with the original error.
def _mkdir_p(path, mode):
    try:
        os.makedirs(path, mode)
    except OSError as e:
        if e.errno != errno.EEXIST or not os.path.isdir(path):
            raise

# join joins or creates new lonet network with given name.
#
# network="" creates new network with unique name in <tmpdir>/lonet/.
# NOTE(review): os.makedirs applies the process umask to mode, so the shared
# <tmpdir>/lonet may end up not world-writable - confirm this is intended.
def join(network):
    with errctx("lonet: join %s" % qq(network)):
        lonet = tempfile.gettempdir() + "/lonet"
        # 0o-prefixed octals: same values as 0777/0700, but valid on python2.6+ and python3
        _mkdir_p(lonet, 0o777 | stat.S_ISVTX)

        if network != "":
            netdir = lonet + "/" + network
            _mkdir_p(netdir, 0o700)
        else:
            netdir = tempfile.mkdtemp(dir=lonet)
            network = os.path.basename(netdir)

        registry = SQLiteRegistry(netdir + "/registry.db", network)
        return _SubNetwork("lonet" + network, registry)


# lonet handshake:
#   scanf("> lonet %q dial %q %q\n", network, src, dst)
#   scanf("< lonet %q %s %q\n",      network, reply, arg)
_lodial_re  = re.compile(r'> lonet "(?P<network>.*?[^\\])" dial "(?P<src>.*?[^\\])" "(?P<dst>.*?[^\\])"\n')
_loreply_re = re.compile(r'< lonet "(?P<network>.*?[^\\])" (?P<reply>[^\s]+) "(?P<arg>.*?[^\\])"\n')

# _SubNetwork represents one subnetwork of a lonet network.
class _SubNetwork(VirtSubNetwork):
    # ._oslistener  net.socket
    # ._tserve      Thread(._serve)

    def __init__(n, network, registry):
        super(_SubNetwork, n).__init__(network, registry)

        oslistener = None
        try:
            # start OS listener
            oslistener = net.socket(net.AF_INET, net.SOCK_STREAM)
            oslistener.bind(("127.0.0.1", 0))
            oslistener.listen(1024)

        except BaseException:
            # leak neither the OS socket nor the registry on failure
            if oslistener is not None:
                oslistener.close()
            registry.close()
            raise

        n._oslistener = oslistener

        # XXX -> go(n._serve, serveCtx) + cancel serveCtx in close
        n._tserve = threading.Thread(target=n._serve, name="%s/serve" % n._network)
        n._tserve.start()


    def _vnet_close(n):
        # XXX py: no errctx here - it is in _vnet_down
        # XXX cancel + join tloaccept*
        n._oslistener.close()
        # _vnet_close is also reached from the _serve thread itself (accept
        # error -> _vnet_down -> shutdown). A thread cannot join itself
        # (threading raises RuntimeError), so only join from other threads.
        if threading.current_thread() is not n._tserve:
            n._tserve.join()


    # _serve serves incoming OS-level connections to this subnetwork.
    #
    # Each accepted OS connection is handed to _loaccept in its own goroutine.
    def _serve(n):
        # XXX net.socket.close does not interrupt sk.accept
        # XXX we workaround it with accept timeout and polling for ._down
        n._oslistener.settimeout(1E-3) # 1ms
        # our own address does not change; compute it once while the listener
        # is known to be open - workers may run after it was closed.
        myaddr = addrstr4(*n._oslistener.getsockname())
        while 1:
            if ready(n._down):
                break

            try:
                osconn, peer = n._oslistener.accept()
            except net.timeout:
                continue

            except Exception as e:
                n._vnet_down(e)
                return

            # on BSD-like systems the accepted socket inherits O_NONBLOCK from
            # the (timeout-enabled) listener; the handshake expects blocking IO.
            osconn.settimeout(None)

            # XXX wg.Add(1)
            def _(osconn, peeraddr):
                # XXX defer wg.Done()
                try:
                    n._loaccept(osconn)
                except Exception as e:
                    if errcause(e) is not ErrConnRefused:
                        log.error("lonet %s: serve %s <- %s : %s" % (qq(n._network), myaddr, peeraddr, e))

            # peer address as returned by accept: asking osconn.getpeername()
            # later could fail if the peer already reset the connection.
            go(_, osconn, addrstr4(*peer))


    # --- acceptor vs dialer ---

    # _loaccept handles incoming OS-level connection.
    #
    # osconn is closed if the handshake fails.
    def _loaccept(n, osconn):
        # XXX does not support interruption
        with errctx("loaccept"):
            try:
                n.__loaccept(osconn)
            except Exception:
                # close osconn on error
                osconn.close()
                raise

    def __loaccept(n, osconn):
        line = skreadline(osconn, 1024)

        def reply(reply):
            line = "< lonet %s %s\n" % (qq(n._network), reply)
            osconn.sendall(line)

        # ereply reports err to the dialer and reraises it
        # (non well-defined errors are reraised as BUG).
        def ereply(err, tb):
            e = err
            if err is ErrConnRefused:
                e = "connection refused"    # str(ErrConnRefused) is "[Errno 111] connection refused"
            reply("E %s" % qq(e))
            if not xerr.well_defined(err):
                err = Error("BUG", err, cause_tb=tb)
            raise err

        # eproto reports protocol error to the dialer and raises it locally with detail.
        def eproto(ereason, detail):
            reply("E %s" % qq(protocolError(ereason)))
            raise protocolError(ereason + ": " + detail)


        m = _lodial_re.match(line)
        if m is None:
            eproto("invalid dial request", "%s" % qq(line))

        network = m.group('network').decode('string_escape')
        src     = m.group('src').decode('string_escape')
        dst     = m.group('dst').decode('string_escape')

        if network != n._network:
            eproto("network mismatch", "%s" % qq(network))

        try:
            asrc = Addr.parse(network, src)
        except ValueError:
            eproto("src address invalid", "%s" % qq(src))

        try:
            adst = Addr.parse(network, dst)
        except ValueError:
            eproto("dst address invalid", "%s" % qq(dst))

        with errctx("%s <- %s" % (dst, src)):
            try:
                accept = n._vnet_accept(asrc, adst, osconn)
            except Exception as e:
                _, _, tb = sys.exc_info()
                ereply(e, tb)

            # tell the listener (via ack) whether the dialer was informed
            try:
                reply('connected %s' % qq(accept.addr))
            except Exception as e:
                accept.ack.send(e)
                raise
            else:
                accept.ack.send(None)


    # _loconnect tries to establish lonet connection on top of OS-level connection.
    #
    # On error osconn is closed; ErrConnRefused is reraised as is, other
    # errors are wrapped with "loconnect <peer>" context.
    def _loconnect(n, osconn, src, dst):
        # XXX does not support interruption
        try:
            return n.__loconnect(osconn, src, dst)
        except Exception as err:
            # capture the traceback before handling any other exception below
            _, _, tb = sys.exc_info()

            # getpeername itself may fail (e.g. after the peer reset the
            # connection); don't let that mask the original error.
            try:
                peeraddr = addrstr4(*osconn.getpeername())
            except Exception:
                peeraddr = "?"

            # close osconn on error
            osconn.close()

            if err is not ErrConnRefused:
                err = Error("loconnect %s" % peeraddr, err, tb)
            raise err


    def __loconnect(n, osconn, src, dst):
        osconn.sendall("> lonet %s dial %s %s\n" % (qq(n._network), qq(src), qq(dst)))
        line = skreadline(osconn, 1024)
        m = _loreply_re.match(line)
        if m is None:
            raise protocolError("invalid dial reply: %s" % qq(line))

        network = m.group('network').decode('string_escape')
        reply   = m.group('reply')  # no unescape
        arg     = m.group('arg').decode('string_escape')

        if reply == "E":
            if arg == "connection refused":
                raise ErrConnRefused
            else:
                raise Error(arg)

        if reply == "connected":
            pass    # ok
        else:
            raise protocolError("invalid reply verb: %s" % qq(reply))

        if network != n._network:
            raise protocolError("connected, but network mismatch: %s" % qq(network))

        try:
            acceptAddr = Addr.parse(network, arg)
        except ValueError:
            # acceptAddr is unbound here - report the raw reply argument
            raise protocolError("connected, but accept address invalid: %s" % qq(arg))

        if acceptAddr.host != dst.host:
            raise protocolError("connected, but accept address is for different host: %s" % qq(acceptAddr.host))

        # everything is ok
        return acceptAddr


    # _vnet_dial connects to dst via OS-level TCP address dstosladdr and
    # performs lonet handshake. It returns (osconn, accept address).
    def _vnet_dial(n, src, dst, dstosladdr):
        try:
            # XXX abusing Addr.parse to parse TCP address
            a = Addr.parse("", dstosladdr)
        except ValueError:
            raise ValueError('%s is not valid TCP address' % dstosladdr)

        osconn = net.socket(net.AF_INET, net.SOCK_STREAM)
        try:
            osconn.connect((a.host, a.port))
        except BaseException:
            # _loconnect closes osconn on its errors; before that it is ours to close
            osconn.close()
            raise
        addrAccept = n._loconnect(osconn, src, dst)
        return osconn, addrAccept

    # _vnet_newhost announces hostname as served by our OS listener.
    def _vnet_newhost(n, hostname, registry):
        registry.announce(hostname, addrstr4(*n._oslistener.getsockname()))


# __str__ renders protocolError as "protocol error: <args>".
#
# Exception.__str__ renders 0, 1 or several args without failing. A plain
# `% e.args` would raise TypeError unless there is exactly one arg.
@func(protocolError)
def __str__(e):
    return "protocol error: %s" % Exception.__str__(e)


# skreadline reads 1 line from sk up to maxlen bytes.
#
# The returned line includes trailing "\n" unless maxlen was reached first.
# EOF before end of line raises Error('unexpected EOF').
def skreadline(sk, maxlen):
    line = ""
    while len(line) < maxlen:
        b = sk.recv(1)
        if len(b) == 0: # EOF
            raise Error('unexpected EOF')
        assert len(b) == 1
        line += b
        if b == "\n":
            break

    return line



# -------- registry --------
#
# See registry_sqlite.go for details.


# RegistryError is the error raised by registry operations.
class RegistryError(Exception):
    def __init__(self, err, registry, op, *argv):
        self.err, self.registry, self.op, self.argv = err, registry, op, argv

    def __str__(self):
        return "%s: %s %s: %s" % (self.registry, self.op, self.argv, self.err)

# _regerr wraps f to raise RegistryError exception.
#
# Errors that are not well-defined are first wrapped as BUG.
def _regerr(f):
    @functools.wraps(f)
    def f_regerr(self, *argv):
        try:
            return f(self, *argv)
        except Exception as err:
            if not xerr.well_defined(err):
                _, _, tb = sys.exc_info()
                err = Error("BUG", err, tb)
            raise RegistryError(err, self.uri, f.__name__, *argv)

    return f_regerr


# DBPool provides pool of SQLite connections.
class DBPool(object):

    def __init__(self, dburi):
        # factory to create new connection.
        #
        # ( !check_same_thread because it is safe from long ago to pass SQLite
        #   connections in between threads, and with using pool it can happen. )
        def factory():
            conn = sqlite3.connect(dburi, check_same_thread=False)
            conn.text_factory = str     # always return bytestrings - we keep text in UTF-8
            conn.isolation_level = None # autocommit
            return conn

        self._factory = factory         # None when pool closed
        self._lock = threading.Lock()
        self._connv = []                # of sqlite3.connection

    # get gets connection from the pool.
    #
    # once user is done with it, it has to put the connection back via put.
    def get(self):
        # try getting already available connection
        with self._lock:
            factory = self._factory
            if factory is None:
                raise RuntimeError("sqlite: pool: get on closed pool")
            if len(self._connv) > 0:
                conn = self._connv.pop()
                return conn

        # no connection available - open new one
        return factory()


    # put puts connection back into the pool.
    def put(self, conn):
        with self._lock:
            if self._factory is not None:
                self._connv.append(conn)
                return

        # conn is put back after pool was closed - close conn.
        conn.close()

    # close closes the pool.
    def close(self):
        with self._lock:
            self._factory = None
            connv = self._connv
            self._connv = []

        for conn in connv:
            conn.close()


    # with xget one can use DBPool as context manager to automatically get / put a connection.
    def xget(self):
        return _DBPoolContext(self)

class _DBPoolContext(object):

    def __init__(self, pool):
        self.pool = pool
        self.conn = pool.get()

    def __enter__(self):
        return self.conn

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.pool.put(self.conn)


# SQLiteRegistry implements network registry as shared SQLite file.
class SQLiteRegistry(object):

    schema_ver = "lonet.1"

    # __init__ opens registry at dburi and verifies/initializes its schema
    # version and network name.
    @_regerr
    def __init__(r, dburi, network):
        r.uri = dburi
        r._dbpool = DBPool(dburi)
        try:
            r._setup(network)
        except BaseException:
            # don't leak pooled SQLite connections of a registry that failed to open
            r._dbpool.close()
            raise

    def close(r):
        r._dbpool.close()

    def _setup(r, network):
        with errctx('setup %s' % qq(network)):
            with r._dbpool.xget() as conn:
                with conn:
                    # NOTE(review): `NON NULL` looks like a typo of `NOT NULL`
                    # (SQLite likely accepts it as part of the type name, not
                    # as a constraint). Kept unchanged; presumably the schema
                    # is shared with the Go registry - confirm before fixing.
                    conn.execute("""
                            CREATE TABLE IF NOT EXISTS hosts (
                                    hostname        TEXT NON NULL PRIMARY KEY,
                                    osladdr         TEXT NON NULL
                            )
                    """)

                    conn.execute("""
                            CREATE TABLE IF NOT EXISTS meta (
                                    name            TEXT NON NULL PRIMARY KEY,
                                    value           TEXT NON NULL
                            )
                    """)

                    ver = r._config(conn, "schemaver")
                    if ver == "":
                        ver = r.schema_ver
                        r._set_config(conn, "schemaver", ver)
                    if ver != r.schema_ver:
                        # class attribute is schema_ver (there is no ._schema_ver)
                        raise Error('schema version mismatch: want %s; have %s' % (qq(r.schema_ver), qq(ver)))

                    dbnetwork = r._config(conn, "network")
                    if dbnetwork == "":
                        dbnetwork = network
                        r._set_config(conn, "network", dbnetwork)
                    if dbnetwork != network:
                        raise Error('network name mismatch: want %s; have %s' % (qq(network), qq(dbnetwork)))


    # _config returns value of config entry name, or "" if it is not set.
    def _config(r, conn, name):
        with errctx('config: get %s' % qq(name)):
            rowv = query(conn, "SELECT value FROM meta WHERE name = ?", name)
            if len(rowv) == 0:
                return ""
            if len(rowv) > 1:
                raise Error("registry broken: duplicate config entries")
            return rowv[0][0]


    # _set_config sets config entry name to value.
    def _set_config(r, conn, name, value):
        with errctx('config: set %s = %s' % (qq(name), qq(value))):
            conn.execute(
                "INSERT OR REPLACE INTO meta (name, value) VALUES (?, ?)",
                (name, value))


    # announce registers hostname -> osladdr; raises ErrHostDup if hostname is already registered.
    @_regerr
    def announce(r, hostname, osladdr):
        with r._dbpool.xget() as conn:
            try:
                conn.execute(
                    "INSERT INTO hosts (hostname, osladdr) VALUES (?, ?)",
                    (hostname, osladdr))
            except sqlite3.IntegrityError as e:
                # str(e) instead of the deprecated (and py3-removed) e.message
                if str(e).startswith('UNIQUE constraint failed'):
                    raise ErrHostDup
                raise

    # query returns osladdr registered for hostname, or None.
    @_regerr
    def query(r, hostname):
        with r._dbpool.xget() as conn:
            rowv = query(conn, "SELECT osladdr FROM hosts WHERE hostname = ?", hostname)
            if len(rowv) == 0:
                return None
            if len(rowv) > 1:
                raise Error("registry broken: duplicate host entries")
            return rowv[0][0]


# query executes query on connection, fetches and returns all rows as [].
def query(conn, sql, *argv):
    rowi = conn.execute(sql, argv)
    return list(rowi)