github.com/decred/dcrlnd@v0.7.6/watchtower/wtserver/server.go (about)

     1  package wtserver
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/decred/dcrd/chaincfg/chainhash"
    12  	"github.com/decred/dcrd/connmgr"
    13  	"github.com/decred/dcrd/txscript/v4/stdaddr"
    14  	"github.com/decred/dcrlnd/keychain"
    15  	"github.com/decred/dcrlnd/lnwire"
    16  	"github.com/decred/dcrlnd/watchtower/wtdb"
    17  	"github.com/decred/dcrlnd/watchtower/wtwire"
    18  )
    19  
    20  var (
    21  	// ErrPeerAlreadyConnected signals that a peer with the same session id
    22  	// is already active within the server.
    23  	ErrPeerAlreadyConnected = errors.New("peer already connected")
    24  
    25  	// ErrServerExiting signals that a request could not be processed
    26  	// because the server has been requested to shut down.
    27  	ErrServerExiting = errors.New("server shutting down")
    28  )
    29  
    30  // Config abstracts the primary components and dependencies of the server.
    31  type Config struct {
    32  	// DB provides persistent access to the server's sessions and for
    33  	// storing state updates.
    34  	DB DB
    35  
    36  	// NodeKeyECDH is the the ECDH capable wrapper of the key to be used in
    37  	// accepting new brontide connections.
    38  	NodeKeyECDH keychain.SingleKeyECDH
    39  
    40  	// Listeners specifies which address to which clients may connect.
    41  	Listeners []net.Listener
    42  
    43  	// ReadTimeout specifies how long a client may go without sending a
    44  	// message.
    45  	ReadTimeout time.Duration
    46  
    47  	// WriteTimeout specifies how long a client may go without reading a
    48  	// message from the other end, if the connection has stopped buffering
    49  	// the server's replies.
    50  	WriteTimeout time.Duration
    51  
    52  	// NewAddress is used to generate reward addresses, where a cut of
    53  	// successfully sent funds can be received.
    54  	NewAddress func() (stdaddr.Address, error)
    55  
    56  	// ChainHash identifies the network that the server is watching.
    57  	ChainHash chainhash.Hash
    58  
    59  	// NoAckCreateSession causes the server to not reply to create session
    60  	// requests, this should only be used for testing.
    61  	NoAckCreateSession bool
    62  
    63  	// NoAckUpdates causes the server to not acknowledge state updates, this
    64  	// should only be used for testing.
    65  	NoAckUpdates bool
    66  
    67  	// DisableReward causes the server to reject any session creation
    68  	// attempts that request rewards.
    69  	DisableReward bool
    70  }
    71  
    72  // Server houses the state required to handle watchtower peers. It's primary job
    73  // is to accept incoming connections, and dispatch processing of the client
    74  // message streams.
    75  type Server struct {
    76  	started sync.Once
    77  	stopped sync.Once
    78  
    79  	cfg *Config
    80  
    81  	connMgr *connmgr.ConnManager
    82  
    83  	clientMtx sync.RWMutex
    84  	clients   map[wtdb.SessionID]Peer
    85  
    86  	newPeers chan Peer
    87  
    88  	localInit *wtwire.Init
    89  
    90  	wg   sync.WaitGroup
    91  	quit chan struct{}
    92  }
    93  
    94  // New creates a new server to handle watchtower clients. The server will accept
    95  // clients connecting to the listener addresses, and allows them to open
    96  // sessions and send state updates.
    97  func New(cfg *Config) (*Server, error) {
    98  	localInit := wtwire.NewInitMessage(
    99  		lnwire.NewRawFeatureVector(
   100  			wtwire.AltruistSessionsOptional,
   101  			wtwire.AnchorCommitOptional,
   102  		),
   103  		cfg.ChainHash,
   104  	)
   105  
   106  	s := &Server{
   107  		cfg:       cfg,
   108  		clients:   make(map[wtdb.SessionID]Peer),
   109  		newPeers:  make(chan Peer),
   110  		localInit: localInit,
   111  		quit:      make(chan struct{}),
   112  	}
   113  
   114  	connMgr, err := connmgr.New(&connmgr.Config{
   115  		Listeners: cfg.Listeners,
   116  		OnAccept:  s.inboundPeerConnected,
   117  		Dial:      noDial,
   118  	})
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  
   123  	s.connMgr = connMgr
   124  
   125  	return s, nil
   126  }
   127  
   128  // Start begins listening on the server's listeners.
   129  func (s *Server) Start() error {
   130  	s.started.Do(func() {
   131  		log.Infof("Starting watchtower server")
   132  
   133  		s.wg.Add(1)
   134  		go s.peerHandler()
   135  
   136  		s.connMgr.Start()
   137  
   138  		log.Infof("Watchtower server started successfully")
   139  	})
   140  	return nil
   141  }
   142  
   143  // Stop shutdowns down the server's listeners and any active requests.
   144  func (s *Server) Stop() error {
   145  	s.stopped.Do(func() {
   146  		log.Infof("Stopping watchtower server")
   147  
   148  		s.connMgr.Stop()
   149  
   150  		close(s.quit)
   151  		s.wg.Wait()
   152  
   153  		log.Infof("Watchtower server stopped successfully")
   154  	})
   155  	return nil
   156  }
   157  
   158  // inboundPeerConnected is the callback given to the connection manager, and is
   159  // called each time a new connection is made to the watchtower. This method
   160  // proxies the new peers by filtering out those that do not satisfy the
   161  // server.Peer interface, and closes their connection. Successful connections
   162  // will be passed on to the public InboundPeerConnected method.
   163  func (s *Server) inboundPeerConnected(c net.Conn) {
   164  	peer, ok := c.(Peer)
   165  	if !ok {
   166  		log.Warnf("incoming connection %T does not satisfy "+
   167  			"server.Peer interface", c)
   168  		c.Close()
   169  		return
   170  	}
   171  
   172  	s.InboundPeerConnected(peer)
   173  }
   174  
   175  // InboundPeerConnected accepts a server.Peer, and handles the request submitted
   176  // by the client. This method serves also as a public endpoint for locally
   177  // registering new clients with the server.
   178  func (s *Server) InboundPeerConnected(peer Peer) {
   179  	select {
   180  	case s.newPeers <- peer:
   181  	case <-s.quit:
   182  	}
   183  }
   184  
   185  // peerHandler processes newly accepted peers and spawns a client handler for
   186  // each. The peerHandler is used to ensure that waitgrouped client handlers are
   187  // spawned from a waitgrouped goroutine.
   188  func (s *Server) peerHandler() {
   189  	defer s.wg.Done()
   190  	defer s.removeAllPeers()
   191  
   192  	for {
   193  		select {
   194  		case peer := <-s.newPeers:
   195  			s.wg.Add(1)
   196  			go s.handleClient(peer)
   197  
   198  		case <-s.quit:
   199  			return
   200  		}
   201  	}
   202  }
   203  
   204  // handleClient processes a series watchtower messages sent by a client. The
   205  // client may either send:
   206  //   - a single CreateSession message.
   207  //   - a series of StateUpdate messages.
   208  //
   209  // This method uses the server's peer map to ensure at most one peer using the
   210  // same session id can enter the main event loop. The connection will be
   211  // dropped by the watchtower if no messages are sent or received by the
   212  // configured Read/WriteTimeouts.
   213  //
   214  // NOTE: This method MUST be run as a goroutine.
   215  func (s *Server) handleClient(peer Peer) {
   216  	defer s.wg.Done()
   217  
   218  	// Use the connection's remote pubkey as the client's session id.
   219  	id := wtdb.NewSessionIDFromPubKey(peer.RemotePub())
   220  
   221  	// Register this peer in the server's client map, and defer the
   222  	// connection's cleanup. If the peer already exists, we will close the
   223  	// connection and exit immediately.
   224  	err := s.addPeer(&id, peer)
   225  	if err != nil {
   226  		peer.Close()
   227  		return
   228  	}
   229  	defer s.removePeer(&id, peer.RemoteAddr())
   230  
   231  	msg, err := s.readMessage(peer)
   232  	if err != nil {
   233  		log.Errorf("Unable to read message from client %s@%s: %v",
   234  			id, peer.RemoteAddr(), err)
   235  		return
   236  	}
   237  
   238  	remoteInit, ok := msg.(*wtwire.Init)
   239  	if !ok {
   240  		log.Errorf("client %s@%s did not send Init msg as first "+
   241  			"message", id, peer.RemoteAddr())
   242  		return
   243  	}
   244  
   245  	err = s.sendMessage(peer, s.localInit)
   246  	if err != nil {
   247  		log.Errorf("unable to send Init msg to %s: %v", id, err)
   248  		return
   249  	}
   250  
   251  	err = s.localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
   252  	if err != nil {
   253  		log.Errorf("Cannot support client %s: %v", id, err)
   254  		return
   255  	}
   256  
   257  	nextMsg, err := s.readMessage(peer)
   258  	if err != nil {
   259  		log.Errorf("Unable to read watchtower msg from %s: %v",
   260  			id, err)
   261  		return
   262  	}
   263  
   264  	switch msg := nextMsg.(type) {
   265  	case *wtwire.CreateSession:
   266  		// Attempt to open a new session for this client.
   267  		err = s.handleCreateSession(peer, &id, msg)
   268  		if err != nil {
   269  			log.Errorf("Unable to handle CreateSession "+
   270  				"from %s: %v", id, err)
   271  		}
   272  
   273  	case *wtwire.DeleteSession:
   274  		err = s.handleDeleteSession(peer, &id)
   275  		if err != nil {
   276  			log.Errorf("Unable to handle DeleteSession "+
   277  				"from %s: %v", id, err)
   278  		}
   279  
   280  	case *wtwire.StateUpdate:
   281  		err = s.handleStateUpdates(peer, &id, msg)
   282  		if err != nil {
   283  			log.Errorf("Unable to handle StateUpdate "+
   284  				"from %s: %v", id, err)
   285  		}
   286  
   287  	default:
   288  		log.Errorf("Received unsupported message type: %T "+
   289  			"from %s", nextMsg, id)
   290  	}
   291  }
   292  
   293  // connFailure is a default error used when a request failed with a non-zero
   294  // error code.
   295  type connFailure struct {
   296  	ID   wtdb.SessionID
   297  	Code wtwire.ErrorCode
   298  }
   299  
   300  // Error displays the SessionID and Code that caused the connection failure.
   301  func (f *connFailure) Error() string {
   302  	return fmt.Sprintf("connection with %s failed with code=%s",
   303  		f.ID, f.Code,
   304  	)
   305  }
   306  
   307  // readMessage receives and parses the next message from the given Peer. An
   308  // error is returned if a message is not received before the server's read
   309  // timeout, the read off the wire failed, or the message could not be
   310  // deserialized.
   311  func (s *Server) readMessage(peer Peer) (wtwire.Message, error) {
   312  	// Set a read timeout to ensure we drop the client if not sent in a
   313  	// timely manner.
   314  	err := peer.SetReadDeadline(time.Now().Add(s.cfg.ReadTimeout))
   315  	if err != nil {
   316  		err = fmt.Errorf("unable to set read deadline: %v", err)
   317  		return nil, err
   318  	}
   319  
   320  	// Pull the next message off the wire, and parse it according to the
   321  	// watchtower wire specification.
   322  	rawMsg, err := peer.ReadNextMessage()
   323  	if err != nil {
   324  		err = fmt.Errorf("unable to read message: %v", err)
   325  		return nil, err
   326  	}
   327  
   328  	msgReader := bytes.NewReader(rawMsg)
   329  	msg, err := wtwire.ReadMessage(msgReader, 0)
   330  	if err != nil {
   331  		err = fmt.Errorf("unable to parse message: %v", err)
   332  		return nil, err
   333  	}
   334  
   335  	logMessage(peer, msg, true)
   336  
   337  	return msg, nil
   338  }
   339  
   340  // sendMessage sends a watchtower wire message to the target peer.
   341  func (s *Server) sendMessage(peer Peer, msg wtwire.Message) error {
   342  	// TODO(conner): use buffer pool?
   343  
   344  	var b bytes.Buffer
   345  	_, err := wtwire.WriteMessage(&b, msg, 0)
   346  	if err != nil {
   347  		err = fmt.Errorf("unable to encode msg: %v", err)
   348  		return err
   349  	}
   350  
   351  	err = peer.SetWriteDeadline(time.Now().Add(s.cfg.WriteTimeout))
   352  	if err != nil {
   353  		err = fmt.Errorf("unable to set write deadline: %v", err)
   354  		return err
   355  	}
   356  
   357  	logMessage(peer, msg, false)
   358  
   359  	_, err = peer.Write(b.Bytes())
   360  	return err
   361  }
   362  
   363  // addPeer stores a client in the server's client map. An error is returned if a
   364  // client with the same session id already exists.
   365  func (s *Server) addPeer(id *wtdb.SessionID, peer Peer) error {
   366  	s.clientMtx.Lock()
   367  	defer s.clientMtx.Unlock()
   368  
   369  	if existingPeer, ok := s.clients[*id]; ok {
   370  		log.Infof("Already connected to peer %s@%s, disconnecting %s",
   371  			id, existingPeer.RemoteAddr(), peer.RemoteAddr())
   372  		return ErrPeerAlreadyConnected
   373  	}
   374  	s.clients[*id] = peer
   375  
   376  	log.Infof("Accepted incoming peer %s@%s",
   377  		id, peer.RemoteAddr())
   378  
   379  	return nil
   380  }
   381  
   382  // removePeer deletes a client from the server's client map. If a peer is found,
   383  // this method will close the peer's connection.
   384  func (s *Server) removePeer(id *wtdb.SessionID, addr net.Addr) {
   385  	log.Infof("Releasing incoming peer %s@%s", id, addr)
   386  
   387  	s.clientMtx.Lock()
   388  	peer, ok := s.clients[*id]
   389  	delete(s.clients, *id)
   390  	s.clientMtx.Unlock()
   391  
   392  	if ok {
   393  		peer.Close()
   394  	}
   395  }
   396  
   397  // removeAllPeers iterates through the server's current set of peers and closes
   398  // all open connections.
   399  func (s *Server) removeAllPeers() {
   400  	s.clientMtx.Lock()
   401  	defer s.clientMtx.Unlock()
   402  
   403  	for id, peer := range s.clients {
   404  		log.Infof("Releasing incoming peer %s@%s", id,
   405  			peer.RemoteAddr())
   406  
   407  		delete(s.clients, id)
   408  		peer.Close()
   409  	}
   410  }
   411  
   412  // logMessage writes information about a message exchanged with a remote peer,
   413  // using directional prepositions to signal whether the message was sent or
   414  // received.
   415  func logMessage(peer Peer, msg wtwire.Message, read bool) {
   416  	var action = "Received"
   417  	var preposition = "from"
   418  	if !read {
   419  		action = "Sending"
   420  		preposition = "to"
   421  	}
   422  
   423  	summary := wtwire.MessageSummary(msg)
   424  	if len(summary) > 0 {
   425  		summary = "(" + summary + ")"
   426  	}
   427  
   428  	log.Debugf("%s %s%v %s %x@%s", action, msg.MsgType(), summary,
   429  		preposition, peer.RemotePub().SerializeCompressed(),
   430  		peer.RemoteAddr())
   431  }
   432  
   433  // noDial is a dummy dial method passed to the server's connmgr.
   434  func noDial(string, string) (net.Conn, error) {
   435  	return nil, fmt.Errorf("watchtower cannot make outgoing conns")
   436  }