lab.nexedi.com/kirr/go123@v0.0.0-20240207185015-8299741fa871/xnet/lonet/__init__.py (about)

     1  # -*- coding: utf-8 -*-
     2  # Copyright (C) 2018-2020  Nexedi SA and Contributors.
     3  #                          Kirill Smelkov <kirr@nexedi.com>
     4  #
     5  # This program is free software: you can Use, Study, Modify and Redistribute
     6  # it under the terms of the GNU General Public License version 3, or (at your
     7  # option) any later version, as published by the Free Software Foundation.
     8  #
     9  # You can also Link and Combine this program with other software covered by
    10  # the terms of any of the Free Software licenses or any of the Open Source
    11  # Initiative approved licenses and Convey the resulting work. Corresponding
    12  # source of such a combination shall include the source code for all other
    13  # software used.
    14  #
    15  # This program is distributed WITHOUT ANY WARRANTY; without even the implied
    16  # warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
    17  #
    18  # See COPYING file for full licensing terms.
    19  # See https://www.nexedi.com/licensing for rationale and options.
    20  """Package lonet provides TCP network simulated on top of localhost TCP loopback.
    21  
    22  See lonet.go for lonet description, organization and protocol.
    23  """
    24  
# NOTE this package is deliberately concise and follows the structure of
# lonet.go, which is better documented.
    27  
    28  import sys, os, stat, errno, tempfile, re
    29  import socket as net
    30  
    31  import sqlite3
    32  import functools
    33  import threading
    34  import logging as log
    35  
    36  from golang import func, defer, go, chan, select, default, panic, gimport
    37  from golang import sync
    38  from golang.gcompat import qq
    39  
# xerr is go123's error-handling package, imported through pygolang's gimport.
xerr    = gimport('lab.nexedi.com/kirr/go123/xerr')
Error   = xerr.Error    # error with message + cause (+ traceback), e.g. Error("msg", err, tb)
errctx  = xerr.context  # `with errctx("..."):` - presumably annotates errors raised inside with that context
errcause= xerr.cause    # errcause(e) -> underlying cause of a (possibly wrapped) error
    44  
    45  
    46  # set_once sets threading.Event, but only once.
    47  #
    48  # it returns whether event was set.
    49  #
    50  #   if set_once(down_once):
    51  #     ...
    52  #
    53  # is analog of
    54  #
    55  #   downOnce.Do(...)
    56  #
    57  # in Go.
    58  #
    59  # TODO just use sync.Once from pygolang.
    60  _oncemu = threading.Lock()
    61  def set_once(event):
    62      with _oncemu:
    63          if event.is_set():
    64              return False
    65          event.set()
    66          return True
    67  
    68  
    69  
    70  # -------- virtnet --------
    71  #
    72  # See ../virtnet/virtnet.go for details.
    73  
    74  # neterror creates net.error and registers it as WDE to xerr.
    75  def neterror(*argv):
    76      err = net.error(*argv)
    77      xerr.register_wde_object(err)
    78      return err
    79  
# Well-defined errors of virtnet operations. Conditions that have a system
# counterpart carry the corresponding errno.
ErrNetDown          = neterror(errno.EBADFD,         "network is down")
ErrHostDown         = neterror(errno.EBADFD,         "host is down")
ErrSockDown         = neterror(errno.EBADFD,         "socket is down")
ErrAddrAlreadyUsed  = neterror(errno.EADDRINUSE,     "address already in use")
ErrAddrNoListen     = neterror(errno.EADDRNOTAVAIL,  "cannot listen on requested address")
ErrConnRefused      = neterror(errno.ECONNREFUSED,   "connection refused")

# errors without errno - only the message
ErrNoHost           = neterror("no such host")
ErrHostDup          = neterror("host already registered")
    89  
    90  
    91  # addrstr4 formats host:port as if for TCP4 network.
    92  def addrstr4(host, port):
    93      return "%s:%d" % (host, port)
    94  
    95  # Addr represent address of a virtnet endpoint.
    96  class Addr(object):
    97      # .net  str
    98      # .host str
    99      # .port int
   100  
   101      def __init__(self, net, host, port):
   102          self.net, self.host, self.port = net, host, port
   103  
   104      # netaddr returns address as net.AF_INET (host, port) pair.
   105      def netaddr(self):
   106          return (self.host, self.port)
   107  
   108      def __str__(self):
   109          return addrstr4(*self.netaddr())
   110  
   111      def __eq__(a, b):
   112          return isinstance(b, Addr) and a.net == b.net and a.host == b.host and a.port == b.port
   113  
   114  
   115  # VirtSubNetwork represents one subnetwork of a virtnet network.
   116  class VirtSubNetwork(object):
   117      # ._network     str
   118      # ._registry    Registry
   119      # ._hostmu      μ
   120      # ._hostmap     {} hostname -> Host
   121      # ._nopenhosts  int
   122      # ._autoclose   bool
   123      # ._down        chan ø
   124      # ._down_once   threading.Event
   125  
   126      def __init__(self, network, registry):
   127          self._network    = network
   128          self._registry   = registry
   129          self._hostmu     = threading.Lock()
   130          self._hostmap    = {}
   131          self._nopenhosts = 0
   132          self._autoclose  = False
   133          self._down       = chan()
   134          self._down_once  = threading.Event()
   135  
   136      # must be implemented in particular virtnet implementation
   137      def _vnet_newhost(self, hostname, registry):    raise NotImplementedError()
   138      def _vnet_dial(self, src, dst, dstosladdr):     raise NotImplementedError()
   139      def _vnet_close(self):                          raise NotImplementedError()
   140  
   141  
   142  # Host represents named access point on a virtnet network.
   143  class Host(object):
   144      # ._subnet      VirtSubNetwork
   145      # ._name        str
   146      # ._sockmu      μ
   147      # ._socketv     []socket ; port -> listener | conn ; [0] is always None
   148      # ._down        chan ø
   149      # ._down_once   threading.Event
   150      # ._close_once  sync.Once
   151  
   152      def __init__(self, subnet, name):
   153          self._subnet     = subnet
   154          self._name       = name
   155          self._sockmu     = threading.Lock()
   156          self._socketv    = []
   157          self._down       = chan()
   158          self._down_once  = threading.Event()
   159          self._close_once = sync.Once()
   160  
   161  
   162  # socket represents one endpoint entry on Host.
   163  class socket(object):
   164      # ._host    Host
   165      # ._port    int
   166  
   167      # ._conn      conn | None
   168      # ._listener  listener | None
   169  
   170      def __init__(self, host, port):
   171          self._host, self._port = host, port
   172          self._conn = self._listener = None
   173  
   174  
   175  # conn represents one endpoint of a virtnet connection.
   176  class conn(object):
   177      # ._socket      socket
   178      # ._peerAddr    Addr
   179      # ._netsk       net.socket (embedded)
   180      # ._down        chan()
   181      # ._down_once   threading.Event
   182      # ._close_once  threading.Event
   183  
   184      def __init__(self, sk, peerAddr, netsk):
   185          self._socket, self._peerAddr, self._netsk = sk, peerAddr, netsk
   186          self._down       = chan()
   187          self._down_once  = threading.Event()
   188          self._close_once = threading.Event()
   189  
   190      # ._netsk embedded:
   191      def __getattr__(self, name):
   192          return getattr(self._netsk, name)
   193  
   194  
   195  # listener implements net.Listener for Host.
   196  class listener(object):
   197      # ._socket      socket
   198      # ._dialq       chan dialReq
   199      # ._down        chan ø
   200      # ._down_once   threading.Event
   201      # ._close_once  threading.Event
   202  
   203      def __init__(self, sk):
   204          self._socket     = sk
   205          self._dialq      = chan()
   206          self._down       = chan()
   207          self._down_once  = threading.Event()
   208          self._close_once = threading.Event()
   209  
   210  
   211  # dialReq represents one dial request to listener from acceptor.
   212  class dialReq(object):
   213      # ._from    Addr
   214      # ._netsk   net.socket
   215      # ._resp    chan Accept
   216  
   217      def __init__(self, from_, netsk, resp):
   218          self._from, self._netsk, self._resp = from_, netsk, resp
   219  
   220  
   221  # Accept represents successful acceptance decision from VirtSubNetwork._vnet_accept .
   222  class Accept(object):
   223      # .addr     Addr
   224      # .ack      chan error
   225      def __init__(self, addr, ack):
   226          self.addr, self.ack = addr, ack
   227  
   228  
   229  # ----------------------------------------
   230  
   231  # _shutdown is worker for close and _vnet_down.
   232  @func(VirtSubNetwork)
   233  def _shutdown(n, exc):
   234      n.__shutdown(exc, True)
   235  @func(VirtSubNetwork)
   236  def __shutdown(n, exc, withHosts):
   237      if not set_once(n._down_once):
   238          return
   239  
   240      n._down.close()
   241  
   242      if withHosts:
   243          with n._hostmu:
   244              for host in n._hostmap.values():
   245                  host._shutdown()
   246  
   247      # XXX py: we don't collect / remember .downErr
   248      if exc is not None:
   249          log.error(exc)
   250      n._vnet_close()
   251      n._registry.close()
   252  
   253  
   254  # close shutdowns subnetwork.
   255  @func(VirtSubNetwork)
   256  def close(n):
   257      n.__close(True)
   258  @func(VirtSubNetwork)
   259  def _closeWithoutHosts(n):
   260      n.__close(False)
   261  @func(VirtSubNetwork)
   262  def __close(n, withHosts):
   263      with errctx("virtnet %s: close" % qq(n._network)):
   264          n.__shutdown(None, withHosts)
   265  
   266  # _vnet_down shutdowns subnetwork upon engine error.
   267  @func(VirtSubNetwork)
   268  def _vnet_down(n, exc):
   269      # XXX py: errctx here (go does not have) because we do not reraise .downErr in close
   270      with errctx("virtnet %s: shutdown" % qq(n._network)):
   271          n._shutdown(exc)
   272  
   273  
   274  # new_host creates new Host with given name.
   275  @func(VirtSubNetwork)
   276  def new_host(n, name):
   277      with errctx("virtnet %s: new host %s" % (qq(n._network), qq(name))):
   278          n._vnet_newhost(name, n._registry)
   279          # XXX check err due to subnet down
   280  
   281          with n._hostmu:
   282              if name in n._hostmap:
   283                  panic("announced ok but .hostMap already !empty" % (qq(n._network), qq(name)))
   284  
   285              host = Host(n, name)
   286              n._hostmap[name] = host
   287              n._nopenhosts += 1
   288              return host
   289  
   290  
   291  # host returns host on the subnetwork by name.
   292  @func(VirtSubNetwork)
   293  def host(n, name):
   294      with n._hostmu:
   295          return n._hostmap.get(name)
   296  
   297  
   298  # _shutdown is underlying worker for close.
   299  @func(Host)
   300  def _shutdown(h):
   301      if not set_once(h._down_once):
   302          return
   303  
   304      h._down.close()
   305  
   306      with h._sockmu:
   307          for sk in h._socketv:
   308              if sk is None:
   309                  continue
   310              if sk._conn is not None:
   311                  sk._conn._shutdown()
   312              if sk._listener is not None:
   313                  sk._listener._shutdown()
   314  
   315  # close shutdowns host.
   316  @func(Host)
   317  def close(h):
   318      def autoclose():
   319          def _():
   320              n = h._subnet
   321              with n._hostmu:
   322                  n._nopenHosts -= 1
   323                  if n._nopenHosts < 0:
   324                      panic("SubNetwork._nopenHosts < 0")
   325                  if n._autoclose and n._nopenHosts == 0:
   326                      n._closeWithoutHosts()
   327          h._close_once.do(_)
   328      defer(autoclose)
   329  
   330      with errctx("virtnet %s: host %s: close" % (qq(h._subnet._network), qq(h._name))):
   331          h._shutdown()
   332  
   333  # autoclose schedules close to be called after last host on this subnetwork is closed.
   334  @func(VirtSubNetwork)
   335  def autoclose(n):
   336      with n._hostmu:
   337          if n._nopenHosts == 0:
   338              panic("BUG: no opened hosts")
   339          n._autoclose = True
   340  
   341  
   342  # listen starts new listener on the host.
   343  @func(Host)
   344  def listen(h, laddr):
   345      if laddr == "":
   346          laddr = ":0"
   347  
   348      with errctx("listen %s %s" % (h.network(), laddr)):
   349          a = h._parseAddr(laddr)
   350  
   351          if a.host != h._name:
   352              raise ErrAddrNoListen
   353  
   354          if ready(h._down):
   355              h._excDown()
   356  
   357          with h._sockmu:
   358              if a.port == 0:
   359                  sk = h._allocFreeSocket()
   360  
   361              else:
   362                  while a.port >= len(h._socketv):
   363                      h._socketv.append(None)
   364  
   365                  if h._socketv[a.port] is not None:
   366                      raise ErrAddrAlreadyUsed
   367  
   368                  sk = socket(h, a.port)
   369                  h._socketv[a.port] = sk
   370  
   371              l = listener(sk)
   372              sk._listener = l
   373  
   374              return l
   375  
   376  
   377  # _shutdown shutdowns the listener.
   378  @func(listener)
   379  def _shutdown(l):
   380      if set_once(l._down_once):
   381          l._down.close()
   382  
   383  # close closes the listener.
   384  @func(listener)
   385  def close(l):
   386      l._shutdown()
   387      if not set_once(l._close_once):
   388          return
   389  
   390      sk = l._socket
   391      h  = sk._host
   392  
   393      with h._sockmu:
   394          sk._listener = None
   395          if sk._empty():
   396              h._socketv[sk.port] = None
   397  
   398  
   399  # accept tries to connect to dial called with addr corresponding to our listener.
   400  @func(listener)
   401  def accept(l):
   402      h = l._socket._host
   403  
   404      with errctx("accept %s %s" % (h.network(), l.addr())):
   405          while 1:
   406              _, _rx = select(
   407                  l._down.recv,   # 0
   408                  l._dialq.recv,  # 1
   409              )
   410              if _ == 0:
   411                  l._excDown()
   412              if _ == 1:
   413                  req = _rx
   414  
   415              with h._sockmu:
   416                  sk = h._allocFreeSocket()
   417  
   418              ack = chan()
   419              req._resp.send(Accept(sk.addr(), ack))
   420  
   421              _, _rx = select(
   422                  l._down.recv,   # 0
   423                  ack.recv,       # 1
   424              )
   425              if _ == 0:
   426                  def purgesk():
   427                      err = ack.recv()
   428                      if err is None:
   429                          try:
   430                              req._netsk.close()
   431                          except:
   432                              pass
   433                      with h._sockmu:
   434                          h._socketv[sk._port] = None
   435  
   436                  go(purgesk)
   437                  l._excDown()
   438  
   439              if _ == 1:
   440                  err = _rx
   441  
   442              if err is not None:
   443                  with h._sockmu:
   444                      h._socketv[sk._port] = None
   445                  continue
   446  
   447              c = conn(sk, req._from, req._netsk)
   448              with h._sockmu:
   449                  sk.conn = c
   450  
   451              return c
   452  
   453  
   454  # _vnet_accept accepts or rejects incoming connection.
   455  @func(VirtSubNetwork)
   456  def _vnet_accept(n, src, dst, netconn):
   457      with n._hostmu:
   458          host = n._hostmap.get(dst.host)
   459      if host is None:
   460          raise net.gaierror('%s: no such host' % dst.host)
   461  
   462      host._sockmu.acquire()
   463  
   464      if dst.port >= len(host._socketv):
   465          host._sockmu.release()
   466          raise ErrConnRefused
   467  
   468      sk = host._socketv[dst.port]
   469      if sk is None or sk._listener is None:
   470          host._sockmu.release()
   471          raise ErrConnRefused
   472  
   473      l = sk._listener
   474      host._sockmu.release()
   475  
   476      resp = chan()
   477      req  = dialReq(src, netconn, resp)
   478  
   479      _, _rx = select(
   480          l._down.recv,           # 0
   481          (l._dialq.send, req),   # 1
   482      )
   483      if _ == 0:
   484          raise ErrConnRefused
   485      if _ == 1:
   486          return resp.recv()
   487  
   488  
   489  # dial dials address on the network.
   490  @func(Host)
   491  def dial(h, addr):
   492      with h._sockmu:
   493          sk = h._allocFreeSocket()
   494  
   495      # XXX py: default dst to addr to be able to render it in error if it happens before parse
   496      dst = addr
   497  
   498      try:
   499          dst = h._parseAddr(addr)
   500          n   = h._subnet
   501  
   502          # XXX cancel on host shutdown
   503  
   504          dstdata = n._registry.query(dst.host)
   505          if dstdata is None:
   506              raise ErrNoHost
   507  
   508          netsk, acceptAddr = n._vnet_dial(sk.addr(), dst, dstdata)
   509  
   510          c = conn(sk, acceptAddr, netsk)
   511          with h._sockmu:
   512              sk._conn = c
   513          return c
   514  
   515      except Exception as err:
   516          with h._sockmu:
   517              h._socketv[sk._port] = None
   518  
   519          _, _, tb = sys.exc_info()
   520          raise Error("dial %s %s->%s" % (h.network(), sk.addr(), dst), err, tb)
   521  
   522  
   523  # ---- conn ----
   524  
   525  # _shutdown closes underlying network connection.
   526  @func(conn)
   527  def _shutdown(c):
   528      if not set_once(c._down_once):
   529          return
   530  
   531      c._down.close()
   532      # XXX py: we don't remember .errClose
   533      c._netsk.close()
   534  
   535  
   536  # close closes network endpoint and unregisters conn from Host.
   537  @func(conn)
   538  def close(c):
   539      c._shutdown()
   540      if set_once(c._close_once):
   541          sk = c._socket
   542          h  = sk._host
   543  
   544          with h._sockmu:
   545              sk._conn = None
   546              if sk._empty():
   547                  h._socketv[sk._port] = None
   548  
   549      # XXX py: we don't reraise c.errClose
   550  
   551  # XXX py: don't bother to override recv (Read)
   552  # XXX py: don't bother to override send (Write)
   553  
   554  # local_addr returns virtnet address of local end of connection.
   555  @func(conn)
   556  def local_addr(c):
   557      return c._socket.addr()
   558  
   559  # getsockname returns virtnet address of local end of connection as net.AF_INET (host, port) pair.
   560  @func(conn)
   561  def getsockname(c):
   562      return c.local_addr().netaddr()
   563  
   564  # remote_addr returns virtnet address of remote end of connection.
   565  @func(conn)
   566  def remote_addr(c):
   567      return c._peerAddr
   568  
   569  # getpeername returns virtnet address of remote end of connection as net.AF_INET (host, port) pair.
   570  @func(conn)
   571  def getpeername(c):
   572      return c.remote_addr().netaddr()
   573  
   574  # ----------------------------------------
   575  
   576  # _allocFreeSocket finds first free port and allocates socket entry for it.
   577  @func(Host)
   578  def _allocFreeSocket(h):
   579      port = 1
   580      while port < len(h._socketv):
   581          if h._socketv[port] is None:
   582              break
   583          port += 1
   584  
   585      while port >= len(h._socketv):
   586          h._socketv.append(None)
   587  
   588      sk = socket(h, port)
   589      h._socketv[port] = sk
   590      return sk
   591  
   592  
   593  # empty checks whether socket's both conn and listener are all nil.
   594  @func(socket)
   595  def _empty(sk):
   596      return (sk._conn is None and sk._listener is None)
   597  
   598  # addr returns address corresponding to socket.
   599  @func(socket)
   600  def addr(sk):
   601      h = sk._host
   602      return Addr(h.network(), h.name(), sk._port)
   603  
   604  # Addr.parse parses addr into virtnet address for named network.
   605  @func(Addr)
   606  @staticmethod
   607  def parse(net, addr):
   608      try:
   609          addrv = addr.split(':')
   610          if len(addrv) != 2:
   611              raise ValueError()
   612          return Addr(net, addrv[0], int(addrv[1]))
   613      except:
   614          raise ValueError('%s is not valid virtnet address' % addr)
   615  
   616  # _parseAddr parses addr into virtnet address from host point of view.
   617  @func(Host)
   618  def _parseAddr(h, addr):
   619      a = Addr.parse(h.network(), addr)
   620      if a.host == "":
   621          a.host = h._name
   622      return a
   623  
   624  # addr returns address where listener is accepting incoming connections.
   625  @func(listener)
   626  def addr(l):
   627      return l._socket.addr()
   628  
   629  
   630  # network returns full network name this subnetwork is part of.
   631  @func(VirtSubNetwork)
   632  def network(n):
   633      return n._network
   634  
   635  # network returns full network name of underlying network.
   636  @func(Host)
   637  def network(h):
   638      return h._subnet.network()
   639  
   640  # name returns host name.
   641  @func(Host)
   642  def name(h):
   643      return h._name
   644  
   645  # ----------------------------------------
   646  
   647  # _excDown raises appropriate exception cause when h.down is found ready.
   648  @func(Host)
   649  def _excDown(h):
   650      if ready(h._subnet._down):
   651          raise ErrNetDown
   652      else:
   653          raise ErrHostDown
   654  
   655  # _excDown raises appropriate exception cause when l.down is found ready.
   656  @func(listener)
   657  def _excDown(l):
   658      h = l._socket._host
   659      n = h._subnet
   660  
   661      if ready(n._down):
   662          raise ErrNetDown
   663      elif ready(h._down):
   664          raise ErrHostDown
   665      else:
   666          raise ErrSockDown
   667  
   668  # XXX py: conn.errOrDown is not implemented because conn.{Read,Write} are not wrapped.
   669  
   670  # ready returns whether channel ch is ready.
   671  def ready(ch):
   672      _, _rx = select(
   673          ch.recv,    # 0
   674          default,    # 1
   675      )
   676      if _ == 0:
   677          return True
   678      if _ == 1:
   679          return False
   680  
   681  
   682  # -------- lonet networking --------
   683  #
   684  # See lonet.go for details.
   685  
   686  # protocolError represents logical error in lonet handshake exchange.
   687  class protocolError(Exception):
   688      pass
   689  
   690  xerr.register_wde_class(protocolError)
   691  
   692  
   693  # `mkdir -p`; https://stackoverflow.com/a/273227
   694  def _mkdir_p(path, mode):
   695      try:
   696          os.makedirs(path, mode)
   697      except OSError as e:
   698          if e.errno != errno.EEXIST:
   699              raise
   700  
   701  # join joins or creates new lonet network with given name.
   702  def join(network):
   703      with errctx("lonet: join %s" % qq(network)):
   704          lonet = tempfile.gettempdir() + "/lonet"
   705          _mkdir_p(lonet, 0777 | stat.S_ISVTX)
   706  
   707          if network != "":
   708              netdir = lonet + "/" + network
   709              _mkdir_p(netdir, 0700)
   710          else:
   711              netdir = tempfile.mkdtemp(dir=lonet)
   712              network = os.path.basename(netdir)
   713  
   714          registry = SQLiteRegistry(netdir + "/registry.db", network)
   715          return _SubNetwork("lonet" + network, registry)
   716  
   717  
   718  # lonet handshake:
   719  # scanf("> lonet %q dial %q %q\n", network, src, dst)
   720  # scanf("< lonet %q %s %q\n", network, reply, arg)
   721  _lodial_re  = re.compile(r'> lonet "(?P<network>.*?[^\\])" dial "(?P<src>.*?[^\\])" "(?P<dst>.*?[^\\])"\n')
   722  _loreply_re = re.compile(r'< lonet "(?P<network>.*?[^\\])" (?P<reply>[^\s]+) "(?P<arg>.*?[^\\])"\n')
   723  
   724  # _SubNetwork represents one subnetwork of a lonet network.
   725  class _SubNetwork(VirtSubNetwork):
   726      # ._oslistener  net.socket
   727      # ._tserve      Thread(._serve)
   728  
   729      def __init__(n, network, registry):
   730          super(_SubNetwork, n).__init__(network, registry)
   731  
   732          try:
   733              # start OS listener
   734              oslistener = net.socket(net.AF_INET, net.SOCK_STREAM)
   735              oslistener.bind(("127.0.0.1", 0))
   736              oslistener.listen(1024)
   737  
   738          except:
   739              registry.close()
   740              raise
   741  
   742          n._oslistener = oslistener
   743  
   744          # XXX -> go(n._serve, serveCtx) + cancel serveCtx in close
   745          n._tserve = threading.Thread(target=n._serve, name="%s/serve" % n._network)
   746          n._tserve.start()
   747  
   748  
   749      def _vnet_close(n):
   750          # XXX py: no errctx here - it is in _vnet_down
   751          # XXX cancel + join tloaccept*
   752          n._oslistener.close()
   753          n._tserve.join()
   754  
   755  
   756      # _serve serves incoming OS-level connections to this subnetwork.
   757      def _serve(n):
   758          # XXX net.socket.close does not interrupt sk.accept
   759          # XXX we workaround it with accept timeout and polling for ._down
   760          n._oslistener.settimeout(1E-3)   # 1ms
   761          while 1:
   762              if ready(n._down):
   763                  break
   764  
   765              try:
   766                  osconn, _ = n._oslistener.accept()
   767              except net.timeout:
   768                  continue
   769  
   770              except Exception as e:
   771                  n._vnet_down(e)
   772                  return
   773  
   774              # XXX wg.Add(1)
   775              def _(osconn):
   776                  # XXX defer wg.Done()
   777  
   778                  myaddr   = addrstr4(*n._oslistener.getsockname())
   779                  peeraddr = addrstr4(*osconn.getpeername())
   780  
   781                  try:
   782                      n._loaccept(osconn)
   783                  except Exception as e:
   784                      if errcause(e) is not ErrConnRefused:
   785                          log.error("lonet %s: serve %s <- %s : %s" % (qq(n._network), myaddr, peeraddr, e))
   786  
   787              go(_, osconn)
   788  
   789  
   790      # --- acceptor vs dialer ---
   791  
   792      # _loaccept handles incoming OS-level connection.
   793      def _loaccept(n, osconn):
   794          # XXX does not support interruption
   795          with errctx("loaccept"):
   796              try:
   797                  n.__loaccept(osconn)
   798              except Exception:
   799                  # close osconn on error
   800                  osconn.close()
   801                  raise
   802  
   803      def __loaccept(n, osconn):
   804          line = skreadline(osconn, 1024)
   805  
   806          def reply(reply):
   807              line = "< lonet %s %s\n" % (qq(n._network), reply)
   808              osconn.sendall(line)
   809  
   810          def ereply(err, tb):
   811              e = err
   812              if err is ErrConnRefused:
   813                  e = "connection refused"  # str(ErrConnRefused) is "[Errno 111] connection refused"
   814              reply("E %s" % qq(e))
   815              if not xerr.well_defined(err):
   816                  err = Error("BUG", err, cause_tb=tb)
   817              raise err
   818  
   819          def eproto(ereason, detail):
   820              reply("E %s" % qq(protocolError(ereason)))
   821              raise protocolError(ereason + ": " + detail)
   822  
   823  
   824          m = _lodial_re.match(line)
   825          if m is None:
   826              eproto("invalid dial request", "%s" % qq(line))
   827  
   828          network = m.group('network').decode('string_escape')
   829          src     = m.group('src').decode('string_escape')
   830          dst     = m.group('dst').decode('string_escape')
   831  
   832          if network != n._network:
   833              eproto("network mismatch", "%s" % qq(network))
   834  
   835          try:
   836              asrc = Addr.parse(network, src)
   837          except ValueError:
   838              eproto("src address invalid", "%s" % qq(src))
   839  
   840          try:
   841              adst = Addr.parse(network, dst)
   842          except ValueError:
   843              eproto("dst address invalid", "%s" % qq(dst))
   844  
   845          with errctx("%s <- %s" % (dst, src)):
   846              try:
   847                  accept = n._vnet_accept(asrc, adst, osconn)
   848              except Exception as e:
   849                  _, _, tb = sys.exc_info()
   850                  ereply(e, tb)
   851  
   852              try:
   853                  reply('connected %s' % qq(accept.addr))
   854              except Exception as e:
   855                  accept.ack.send(e)
   856                  raise
   857              else:
   858                  accept.ack.send(None)
   859  
   860  
   861      # _loconnect tries to establish lonet connection on top of OS-level connection.
   862      def _loconnect(n, osconn, src, dst):
   863          # XXX does not support interruption
   864          try:
   865              return n.__loconnect(osconn, src, dst)
   866          except Exception as err:
   867              peeraddr = addrstr4(*osconn.getpeername())
   868  
   869              # close osconn on error
   870              osconn.close()
   871  
   872              _, _, tb = sys.exc_info()
   873              if err is not ErrConnRefused:
   874                  err = Error("loconnect %s" % peeraddr, err, tb)
   875              raise err
   876  
   877  
   878      def __loconnect(n, osconn, src, dst):
   879          osconn.sendall("> lonet %s dial %s %s\n" % (qq(n._network), qq(src), qq(dst)))
   880          line = skreadline(osconn, 1024)
   881          m = _loreply_re.match(line)
   882          if m is None:
   883              raise protocolError("invalid dial reply: %s" % qq(line))
   884  
   885          network = m.group('network').decode('string_escape')
   886          reply   = m.group('reply') # no unescape
   887          arg     = m.group('arg').decode('string_escape')
   888  
   889          if reply == "E":
   890              if arg == "connection refused":
   891                  raise ErrConnRefused
   892              else:
   893                  raise Error(arg)
   894  
   895          if reply == "connected":
   896              pass    # ok
   897          else:
   898              raise protocolError("invalid reply verb: %s" % qq(reply))
   899  
   900          if network != n._network:
   901              raise protocolError("connected, but network mismatch: %s" % qq(network))
   902  
   903          try:
   904              acceptAddr = Addr.parse(network, arg)
   905          except ValueError:
   906              raise protocolError("connected, but accept address invalid: %s" % qq(acceptAddr))
   907  
   908          if acceptAddr.host != dst.host:
   909              raise protocolError("connected, but accept address is for different host: %s" % qq(acceptAddr.host))
   910  
   911          # everything is ok
   912          return acceptAddr
   913  
   914  
   915      def _vnet_dial(n, src, dst, dstosladdr):
   916          try:
   917              # XXX abusing Addr.parse to parse TCP address
   918              a = Addr.parse("", dstosladdr)
   919          except ValueError:
   920              raise ValueError('%s is not valid TCP address' % dstosladdr)
   921  
   922          osconn = net.socket(net.AF_INET, net.SOCK_STREAM)
   923          osconn.connect((a.host, a.port))
   924          addrAccept = n._loconnect(osconn, src, dst)
   925          return osconn, addrAccept
   926  
   927      def _vnet_newhost(n, hostname, registry):
   928          registry.announce(hostname, '%s:%d' % n._oslistener.getsockname())
   929  
   930  
   931  @func(protocolError)
   932  def __str__(e):
   933      return "protocol error: %s" % e.args
   934  
   935  
   936  # skreadline reads 1 line from sk up to maxlen bytes.
   937  def skreadline(sk, maxlen):
   938      line = ""
   939      while len(line) < maxlen:
   940          b = sk.recv(1)
   941          if len(b) == 0: # EOF
   942              raise Error('unexpected EOF')
   943          assert len(b) == 1
   944          line += b
   945          if b == "\n":
   946              break
   947  
   948      return line
   949  
   950  
   951  
   952  # -------- registry --------
   953  #
   954  # See registry_sqlite.go for details.
   955  
   956  
   957  # RegistryError is the error raised by registry operations.
   958  class RegistryError(Exception):
   959      def __init__(self, err, registry, op, *argv):
   960          self.err, self.registry, self.op, self.argv = err, registry, op, argv
   961  
   962      def __str__(self):
   963          return "%s: %s %s: %s" % (self.registry, self.op, self.argv, self.err)
   964  
   965  # _regerr wraps f to raise RegistryError exception.
   966  def _regerr(f):
   967      @functools.wraps(f)
   968      def f_regerr(self, *argv):
   969          try:
   970              return f(self, *argv)
   971          except Exception as err:
   972              if not xerr.well_defined(err):
   973                  _, _, tb = sys.exc_info()
   974                  err = Error("BUG", err, tb)
   975              raise RegistryError(err, self.uri, f.__name__, *argv)
   976  
   977      return f_regerr
   978  
   979  
   980  # DBPool provides pool of SQLite connections.
   981  class DBPool(object):
   982  
   983      def __init__(self, dburi):
   984          # factory to create new connection.
   985          #
   986          # ( !check_same_thread because it is safe from long ago to pass SQLite
   987          #   connections in between threads, and with using pool it can happen. )
   988          def factory():
   989              conn = sqlite3.connect(dburi, check_same_thread=False)
   990              conn.text_factory    = str  # always return bytestrings - we keep text in UTF-8
   991              conn.isolation_level = None # autocommit
   992              return conn
   993  
   994          self._factory = factory             # None when pool closed
   995          self._lock    = threading.Lock()
   996          self._connv   = []                  # of sqlite3.connection
   997  
   998      # get gets connection from the pool.
   999      #
  1000      # once user is done with it, it has to put the connection back via put.
  1001      def get(self):
  1002          # try getting already available connection
  1003          with self._lock:
  1004              factory = self._factory
  1005              if factory is None:
  1006                  raise RuntimeError("sqlite: pool: get on closed pool")
  1007              if len(self._connv) > 0:
  1008                  conn = self._connv.pop()
  1009                  return conn
  1010  
  1011          # no connection available - open new one
  1012          return factory()
  1013  
  1014  
  1015      # put puts connection back into the pool.
  1016      def put(self, conn):
  1017          with self._lock:
  1018              if self._factory is not None:
  1019                  self._connv.append(conn)
  1020                  return
  1021  
  1022          # conn is put back after pool was closed - close conn.
  1023          conn.close()
  1024  
  1025      # close closes the pool.
  1026      def close(self):
  1027          with self._lock:
  1028              self._factory = None
  1029              connv = self._connv
  1030              self._connv = []
  1031  
  1032          for conn in connv:
  1033              conn.close()
  1034  
  1035  
  1036      # with xget one can use DBPool as context manager to automatically get / put a connection.
  1037      def xget(self):
  1038          return _DBPoolContext(self)
  1039  
  1040  class _DBPoolContext(object):
  1041  
  1042      def __init__(self, pool):
  1043          self.pool = pool
  1044          self.conn = pool.get()
  1045  
  1046      def __enter__(self):
  1047          return self.conn
  1048  
  1049      def __exit__(self, exc_type, exc_val, exc_tb):
  1050          self.pool.put(self.conn)
  1051  
  1052  
  1053  # SQLiteRegistry implements network registry as shared SQLite file.
  1054  class SQLiteRegistry(object):
  1055  
  1056      schema_ver = "lonet.1"
  1057  
  1058      @_regerr
  1059      def __init__(r, dburi, network):
  1060          r.uri = dburi
  1061          r._dbpool = DBPool(dburi)
  1062          r._setup(network)
  1063  
  1064      def close(r):
  1065          r._dbpool.close()
  1066  
  1067      def _setup(r, network):
  1068          with errctx('setup %s' % qq(network)):
  1069              with r._dbpool.xget() as conn:
  1070                  with conn:
  1071                      conn.execute("""
  1072                          CREATE TABLE IF NOT EXISTS hosts (
  1073                                  hostname	TEXT NON NULL PRIMARY KEY,
  1074                                  osladdr		TEXT NON NULL
  1075                          )
  1076                      """)
  1077  
  1078                      conn.execute("""
  1079                      CREATE TABLE IF NOT EXISTS meta (
  1080                                  name		TEXT NON NULL PRIMARY KEY,
  1081                                  value		TEXT NON NULL
  1082                          )
  1083                      """)
  1084  
  1085                      ver = r._config(conn, "schemaver")
  1086                      if ver == "":
  1087                          ver = r.schema_ver
  1088                          r._set_config(conn, "schemaver", ver)
  1089                      if ver != r.schema_ver:
  1090                          raise Error('schema version mismatch: want %s; have %s' % (qq(r._schema_ver), qq(ver)))
  1091  
  1092                      dbnetwork = r._config(conn, "network")
  1093                      if dbnetwork == "":
  1094                          dbnetwork = network
  1095                          r._set_config(conn, "network", dbnetwork)
  1096                      if dbnetwork != network:
  1097                          raise Error('network name mismatch: want %s; have %s' % (qq(network), qq(dbnetwork)))
  1098  
  1099  
  1100      def _config(r, conn, name):
  1101          with errctx('config: get %s' % qq(name)):
  1102              rowv = query(conn, "SELECT value FROM meta WHERE name = ?", name)
  1103              if len(rowv) == 0:
  1104                  return ""
  1105              if len(rowv) > 1:
  1106                  raise Error("registry broken: duplicate config entries")
  1107              return rowv[0][0]
  1108  
  1109  
  1110      def _set_config(r, conn, name, value):
  1111          with errctx('config: set %s = %s' % (qq(name), qq(value))):
  1112              conn.execute(
  1113                      "INSERT OR REPLACE INTO meta (name, value) VALUES (?, ?)",
  1114                      (name, value))
  1115  
  1116  
  1117      @_regerr
  1118      def announce(r, hostname, osladdr):
  1119          with r._dbpool.xget() as conn:
  1120              try:
  1121                  conn.execute(
  1122                      "INSERT INTO hosts (hostname, osladdr) VALUES (?, ?)",
  1123                      (hostname, osladdr))
  1124              except sqlite3.IntegrityError as e:
  1125                  if e.message.startswith('UNIQUE constraint failed'):
  1126                      raise ErrHostDup
  1127                  raise
  1128  
  1129      @_regerr
  1130      def query(r, hostname):
  1131          with r._dbpool.xget() as conn:
  1132              rowv = query(conn, "SELECT osladdr FROM hosts WHERE hostname = ?", hostname)
  1133              if len(rowv) == 0:
  1134                  return None
  1135              if len(rowv) > 1:
  1136                  raise Error("registry broken: duplicate host entries")
  1137              return rowv[0][0]
  1138  
  1139  
  1140  # query executes query on connection, fetches and returns all rows as [].
  1141  def query(conn, sql, *argv):
  1142      rowi = conn.execute(sql, argv)
  1143      return list(rowi)