github.com/decred/dcrlnd@v0.7.6/watchtower/wtclient/session_negotiator.go (about)

     1  package wtclient
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  	"time"
     7  
     8  	"github.com/decred/dcrd/chaincfg/chainhash"
     9  	"github.com/decred/dcrlnd/keychain"
    10  	"github.com/decred/dcrlnd/lnwire"
    11  	"github.com/decred/dcrlnd/watchtower/blob"
    12  	"github.com/decred/dcrlnd/watchtower/wtdb"
    13  	"github.com/decred/dcrlnd/watchtower/wtpolicy"
    14  	"github.com/decred/dcrlnd/watchtower/wtserver"
    15  	"github.com/decred/dcrlnd/watchtower/wtwire"
    16  	"github.com/decred/slog"
    17  )
    18  
    19  // SessionNegotiator is an interface for asynchronously requesting new sessions.
    20  type SessionNegotiator interface {
    21  	// RequestSession signals to the session negotiator that the client
    22  	// needs another session. Once the session is negotiated, it should be
    23  	// returned via NewSessions.
    24  	RequestSession()
    25  
    26  	// NewSessions is a read-only channel where newly negotiated sessions
    27  	// will be delivered.
    28  	NewSessions() <-chan *wtdb.ClientSession
    29  
    30  	// Start safely initializes the session negotiator.
    31  	Start() error
    32  
    33  	// Stop safely shuts down the session negotiator.
    34  	Stop() error
    35  }
    36  
    37  // NegotiatorConfig provides access to the resources required by a
    38  // SessionNegotiator to faithfully carry out its duties. All nil-able field must
    39  // be initialized.
    40  type NegotiatorConfig struct {
    41  	// DB provides access to a persistent storage medium used by the tower
    42  	// to properly allocate session ephemeral keys and record successfully
    43  	// negotiated sessions.
    44  	DB DB
    45  
    46  	// SecretKeyRing allows the client to derive new session private keys
    47  	// when attempting to negotiate session with a tower.
    48  	SecretKeyRing ECDHKeyRing
    49  
    50  	// Candidates is an abstract set of tower candidates that the negotiator
    51  	// will traverse serially when attempting to negotiate a new session.
    52  	Candidates TowerCandidateIterator
    53  
    54  	// Policy defines the session policy that will be proposed to towers
    55  	// when attempting to negotiate a new session. This policy will be used
    56  	// across all negotiation proposals for the lifetime of the negotiator.
    57  	Policy wtpolicy.Policy
    58  
    59  	// Dial initiates an outbound brontide connection to the given address
    60  	// using a specified private key. The peer is returned in the event of a
    61  	// successful connection.
    62  	Dial func(keychain.SingleKeyECDH, *lnwire.NetAddress) (wtserver.Peer,
    63  		error)
    64  
    65  	// SendMessage writes a wtwire message to remote peer.
    66  	SendMessage func(wtserver.Peer, wtwire.Message) error
    67  
    68  	// ReadMessage reads a message from a remote peer and returns the
    69  	// decoded wtwire message.
    70  	ReadMessage func(wtserver.Peer) (wtwire.Message, error)
    71  
    72  	// ChainHash the genesis hash identifying the chain for any negotiated
    73  	// sessions. Any state updates sent to that session should also
    74  	// originate from this chain.
    75  	ChainHash chainhash.Hash
    76  
    77  	// MinBackoff defines the initial backoff applied by the session
    78  	// negotiator after all tower candidates have been exhausted and
    79  	// reattempting negotiation with the same set of candidates. Subsequent
    80  	// backoff durations will grow exponentially.
    81  	MinBackoff time.Duration
    82  
    83  	// MaxBackoff defines the maximum backoff applied by the session
    84  	// negotiator after all tower candidates have been exhausted and
    85  	// reattempting negotiation with the same set of candidates. If the
    86  	// exponential backoff produces a timeout greater than this value, the
    87  	// backoff duration will be clamped to MaxBackoff.
    88  	MaxBackoff time.Duration
    89  
    90  	// Log specifies the desired log output, which should be prefixed by the
    91  	// client type, e.g. anchor or legacy.
    92  	Log slog.Logger
    93  }
    94  
    95  // sessionNegotiator is concrete SessionNegotiator that is able to request new
    96  // sessions from a set of candidate towers asynchronously and return successful
    97  // sessions to the primary client.
    98  type sessionNegotiator struct {
    99  	started sync.Once
   100  	stopped sync.Once
   101  
   102  	localInit *wtwire.Init
   103  
   104  	cfg *NegotiatorConfig
   105  	log slog.Logger
   106  
   107  	dispatcher             chan struct{}
   108  	newSessions            chan *wtdb.ClientSession
   109  	successfulNegotiations chan *wtdb.ClientSession
   110  
   111  	wg   sync.WaitGroup
   112  	quit chan struct{}
   113  }
   114  
   115  // Compile-time constraint to ensure a *sessionNegotiator implements the
   116  // SessionNegotiator interface.
   117  var _ SessionNegotiator = (*sessionNegotiator)(nil)
   118  
   119  // newSessionNegotiator initializes a fresh sessionNegotiator instance.
   120  func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator {
   121  	// Generate the set of features the negitator will present to the tower
   122  	// upon connection. For anchor channels, we'll conditionally signal that
   123  	// we require support for anchor channels depdening on the requested
   124  	// policy.
   125  	features := []lnwire.FeatureBit{
   126  		wtwire.AltruistSessionsRequired,
   127  	}
   128  	if cfg.Policy.IsAnchorChannel() {
   129  		features = append(features, wtwire.AnchorCommitRequired)
   130  	}
   131  
   132  	localInit := wtwire.NewInitMessage(
   133  		lnwire.NewRawFeatureVector(features...),
   134  		cfg.ChainHash,
   135  	)
   136  
   137  	return &sessionNegotiator{
   138  		cfg:                    cfg,
   139  		log:                    cfg.Log,
   140  		localInit:              localInit,
   141  		dispatcher:             make(chan struct{}, 1),
   142  		newSessions:            make(chan *wtdb.ClientSession),
   143  		successfulNegotiations: make(chan *wtdb.ClientSession),
   144  		quit:                   make(chan struct{}),
   145  	}
   146  }
   147  
   148  // Start safely starts up the sessionNegotiator.
   149  func (n *sessionNegotiator) Start() error {
   150  	n.started.Do(func() {
   151  		n.log.Debugf("Starting session negotiator")
   152  
   153  		n.wg.Add(1)
   154  		go n.negotiationDispatcher()
   155  	})
   156  
   157  	return nil
   158  }
   159  
   160  // Stop safely shutsdown the sessionNegotiator.
   161  func (n *sessionNegotiator) Stop() error {
   162  	n.stopped.Do(func() {
   163  		n.log.Debugf("Stopping session negotiator")
   164  
   165  		close(n.quit)
   166  		n.wg.Wait()
   167  	})
   168  
   169  	return nil
   170  }
   171  
   172  // NewSessions returns a receive-only channel from which newly negotiated
   173  // sessions will be returned.
   174  func (n *sessionNegotiator) NewSessions() <-chan *wtdb.ClientSession {
   175  	return n.newSessions
   176  }
   177  
   178  // RequestSession sends a request to the sessionNegotiator to begin requesting a
   179  // new session. If one is already in the process of being negotiated, the
   180  // request will be ignored.
   181  func (n *sessionNegotiator) RequestSession() {
   182  	select {
   183  	case n.dispatcher <- struct{}{}:
   184  	default:
   185  	}
   186  }
   187  
   188  // negotiationDispatcher acts as the primary event loop for the
   189  // sessionNegotiator, coordinating requests for more sessions and dispatching
   190  // attempts to negotiate them from a list of candidates.
   191  func (n *sessionNegotiator) negotiationDispatcher() {
   192  	defer n.wg.Done()
   193  
   194  	var pendingNegotiations int
   195  	for {
   196  		select {
   197  		case <-n.dispatcher:
   198  			pendingNegotiations++
   199  
   200  			if pendingNegotiations > 1 {
   201  				n.log.Debugf("Already negotiating session, " +
   202  					"waiting for existing negotiation to " +
   203  					"complete")
   204  				continue
   205  			}
   206  
   207  			// TODO(conner): consider reusing good towers
   208  
   209  			n.log.Debugf("Dispatching session negotiation")
   210  
   211  			n.wg.Add(1)
   212  			go n.negotiate()
   213  
   214  		case session := <-n.successfulNegotiations:
   215  			select {
   216  			case n.newSessions <- session:
   217  				pendingNegotiations--
   218  			case <-n.quit:
   219  				return
   220  			}
   221  
   222  			if pendingNegotiations > 0 {
   223  				n.log.Debugf("Dispatching pending session " +
   224  					"negotiation")
   225  
   226  				n.wg.Add(1)
   227  				go n.negotiate()
   228  			}
   229  
   230  		case <-n.quit:
   231  			return
   232  		}
   233  	}
   234  }
   235  
   236  // negotiate handles the process of iterating through potential tower candidates
   237  // and attempting to negotiate a new session until a successful negotiation
   238  // occurs. If the candidate iterator becomes exhausted because none were
   239  // successful, this method will back off exponentially up to the configured max
   240  // backoff. This method will continue trying until a negotiation is successful
   241  // before returning the negotiated session to the dispatcher via the succeed
   242  // channel.
   243  //
   244  // NOTE: This method MUST be run as a goroutine.
   245  func (n *sessionNegotiator) negotiate() {
   246  	defer n.wg.Done()
   247  
   248  	// On the first pass, initialize the backoff to our configured min
   249  	// backoff.
   250  	var backoff time.Duration
   251  
   252  	// Create a closure to update the backoff upon failure such that it
   253  	// stays within our min and max backoff parameters.
   254  	updateBackoff := func() {
   255  		if backoff == 0 {
   256  			backoff = n.cfg.MinBackoff
   257  		} else {
   258  			backoff *= 2
   259  			if backoff > n.cfg.MaxBackoff {
   260  				backoff = n.cfg.MaxBackoff
   261  			}
   262  		}
   263  	}
   264  
   265  retryWithBackoff:
   266  	// If we are retrying, wait out the delay before continuing.
   267  	if backoff > 0 {
   268  		select {
   269  		case <-time.After(backoff):
   270  		case <-n.quit:
   271  			return
   272  		}
   273  	}
   274  
   275  	for {
   276  		select {
   277  		case <-n.quit:
   278  			return
   279  		default:
   280  		}
   281  
   282  		// Pull the next candidate from our list of addresses.
   283  		tower, err := n.cfg.Candidates.Next()
   284  		if err != nil {
   285  			// We've run out of addresses, update our backoff.
   286  			updateBackoff()
   287  
   288  			n.log.Debugf("Unable to get new tower candidate, "+
   289  				"retrying after %v -- reason: %v", backoff, err)
   290  
   291  			// Only reset the iterator once we've exhausted all
   292  			// candidates. Doing so allows us to load balance
   293  			// sessions better amongst all of the tower candidates.
   294  			if err == ErrTowerCandidatesExhausted {
   295  				n.cfg.Candidates.Reset()
   296  			}
   297  
   298  			goto retryWithBackoff
   299  		}
   300  
   301  		towerPub := tower.IdentityKey.SerializeCompressed()
   302  		n.log.Debugf("Attempting session negotiation with tower=%x",
   303  			towerPub)
   304  
   305  		// Before proceeding, we will reserve a session key index to use
   306  		// with this specific tower. If one is already reserved, the
   307  		// existing index will be returned.
   308  		keyIndex, err := n.cfg.DB.NextSessionKeyIndex(
   309  			tower.ID, n.cfg.Policy.BlobType,
   310  		)
   311  		if err != nil {
   312  			n.log.Debugf("Unable to reserve session key index "+
   313  				"for tower=%x: %v", towerPub, err)
   314  			continue
   315  		}
   316  
   317  		// We'll now attempt the CreateSession dance with the tower to
   318  		// get a new session, trying all addresses if necessary.
   319  		err = n.createSession(tower, keyIndex)
   320  		if err != nil {
   321  			// An unexpected error occurred, updpate our backoff.
   322  			updateBackoff()
   323  
   324  			n.log.Debugf("Session negotiation with tower=%x "+
   325  				"failed, trying again -- reason: %v",
   326  				tower.IdentityKey.SerializeCompressed(), err)
   327  
   328  			goto retryWithBackoff
   329  		}
   330  
   331  		// Success.
   332  		return
   333  	}
   334  }
   335  
   336  // createSession takes a tower an attempts to negotiate a session using any of
   337  // its stored addresses. This method returns after the first successful
   338  // negotiation, or after all addresses have failed with ErrFailedNegotiation. If
   339  // the tower has no addresses, ErrNoTowerAddrs is returned.
   340  func (n *sessionNegotiator) createSession(tower *wtdb.Tower,
   341  	keyIndex uint32) error {
   342  
   343  	// If the tower has no addresses, there's nothing we can do.
   344  	if len(tower.Addresses) == 0 {
   345  		return ErrNoTowerAddrs
   346  	}
   347  
   348  	sessionKeyDesc, err := n.cfg.SecretKeyRing.DeriveKey(
   349  		keychain.KeyLocator{
   350  			Family: keychain.KeyFamilyTowerSession,
   351  			Index:  keyIndex,
   352  		},
   353  	)
   354  	if err != nil {
   355  		return err
   356  	}
   357  	sessionKey := keychain.NewPubKeyECDH(
   358  		sessionKeyDesc, n.cfg.SecretKeyRing,
   359  	)
   360  
   361  	for _, lnAddr := range tower.LNAddrs() {
   362  		err := n.tryAddress(sessionKey, keyIndex, tower, lnAddr)
   363  		switch {
   364  		case err == ErrPermanentTowerFailure:
   365  			// TODO(conner): report to iterator? can then be reset
   366  			// with restart
   367  			fallthrough
   368  
   369  		case err != nil:
   370  			n.log.Debugf("Request for session negotiation with "+
   371  				"tower=%s failed, trying again -- reason: "+
   372  				"%v", lnAddr, err)
   373  			continue
   374  
   375  		default:
   376  			return nil
   377  		}
   378  	}
   379  
   380  	return ErrFailedNegotiation
   381  }
   382  
   383  // tryAddress executes a single create session dance using the given address.
   384  // The address should belong to the tower's set of addresses. This method only
   385  // returns true if all steps succeed and the new session has been persisted, and
   386  // fails otherwise.
   387  func (n *sessionNegotiator) tryAddress(sessionKey keychain.SingleKeyECDH,
   388  	keyIndex uint32, tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error {
   389  
   390  	// Connect to the tower address using our generated session key.
   391  	conn, err := n.cfg.Dial(sessionKey, lnAddr)
   392  	if err != nil {
   393  		return err
   394  	}
   395  
   396  	// Send local Init message.
   397  	err = n.cfg.SendMessage(conn, n.localInit)
   398  	if err != nil {
   399  		return fmt.Errorf("unable to send Init: %v", err)
   400  	}
   401  
   402  	// Receive remote Init message.
   403  	remoteMsg, err := n.cfg.ReadMessage(conn)
   404  	if err != nil {
   405  		return fmt.Errorf("unable to read Init: %v", err)
   406  	}
   407  
   408  	// Check that returned message is wtwire.Init.
   409  	remoteInit, ok := remoteMsg.(*wtwire.Init)
   410  	if !ok {
   411  		return fmt.Errorf("expected Init, got %T in reply", remoteMsg)
   412  	}
   413  
   414  	// Verify the watchtower's remote Init message against our own.
   415  	err = n.localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames)
   416  	if err != nil {
   417  		return err
   418  	}
   419  
   420  	policy := n.cfg.Policy
   421  	createSession := &wtwire.CreateSession{
   422  		BlobType:     policy.BlobType,
   423  		MaxUpdates:   policy.MaxUpdates,
   424  		RewardBase:   policy.RewardBase,
   425  		RewardRate:   policy.RewardRate,
   426  		SweepFeeRate: policy.SweepFeeRate,
   427  	}
   428  
   429  	// Send CreateSession message.
   430  	err = n.cfg.SendMessage(conn, createSession)
   431  	if err != nil {
   432  		return fmt.Errorf("unable to send CreateSession: %v", err)
   433  	}
   434  
   435  	// Receive CreateSessionReply message.
   436  	remoteMsg, err = n.cfg.ReadMessage(conn)
   437  	if err != nil {
   438  		return fmt.Errorf("unable to read CreateSessionReply: %v", err)
   439  	}
   440  
   441  	// Check that returned message is wtwire.CreateSessionReply.
   442  	createSessionReply, ok := remoteMsg.(*wtwire.CreateSessionReply)
   443  	if !ok {
   444  		return fmt.Errorf("expected CreateSessionReply, got %T in "+
   445  			"reply", remoteMsg)
   446  	}
   447  
   448  	switch createSessionReply.Code {
   449  	case wtwire.CodeOK, wtwire.CreateSessionCodeAlreadyExists:
   450  
   451  		// TODO(conner): add last-applied to create session reply to
   452  		// handle case where we lose state, session already exists, and
   453  		// we want to possibly resume using the session
   454  
   455  		// TODO(conner): validate reward address
   456  		rewardPkScript := createSessionReply.Data
   457  
   458  		sessionID := wtdb.NewSessionIDFromPubKey(sessionKey.PubKey())
   459  		clientSession := &wtdb.ClientSession{
   460  			ClientSessionBody: wtdb.ClientSessionBody{
   461  				TowerID:        tower.ID,
   462  				KeyIndex:       keyIndex,
   463  				Policy:         n.cfg.Policy,
   464  				RewardPkScript: rewardPkScript,
   465  			},
   466  			Tower:          tower,
   467  			SessionKeyECDH: sessionKey,
   468  			ID:             sessionID,
   469  		}
   470  
   471  		err = n.cfg.DB.CreateClientSession(clientSession)
   472  		if err != nil {
   473  			return fmt.Errorf("unable to persist ClientSession: %v",
   474  				err)
   475  		}
   476  
   477  		n.log.Debugf("New session negotiated with %s, policy: %s",
   478  			lnAddr, clientSession.Policy)
   479  
   480  		// We have a newly negotiated session, return it to the
   481  		// dispatcher so that it can update how many outstanding
   482  		// negotiation requests we have.
   483  		select {
   484  		case n.successfulNegotiations <- clientSession:
   485  			return nil
   486  		case <-n.quit:
   487  			return ErrNegotiatorExiting
   488  		}
   489  
   490  	// TODO(conner): handle error codes properly
   491  	case wtwire.CreateSessionCodeRejectBlobType:
   492  		return fmt.Errorf("tower rejected blob type: %v",
   493  			policy.BlobType)
   494  
   495  	case wtwire.CreateSessionCodeRejectMaxUpdates:
   496  		return fmt.Errorf("tower rejected max updates: %v",
   497  			policy.MaxUpdates)
   498  
   499  	case wtwire.CreateSessionCodeRejectRewardRate:
   500  		// The tower rejected the session because of the reward rate. If
   501  		// we didn't request a reward session, we'll treat this as a
   502  		// permanent tower failure.
   503  		if !policy.BlobType.Has(blob.FlagReward) {
   504  			return ErrPermanentTowerFailure
   505  		}
   506  
   507  		return fmt.Errorf("tower rejected reward rate: %v",
   508  			policy.RewardRate)
   509  
   510  	case wtwire.CreateSessionCodeRejectSweepFeeRate:
   511  		return fmt.Errorf("tower rejected sweep fee rate: %v",
   512  			policy.SweepFeeRate)
   513  
   514  	default:
   515  		return fmt.Errorf("received unhandled error code: %v",
   516  			createSessionReply.Code)
   517  	}
   518  }