github.com/mutagen-io/mutagen@v0.18.0-rc1/pkg/forwarding/controller.go (about)

     1  package forwarding
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"os"
     8  	"sync"
     9  	"time"
    10  
    11  	"google.golang.org/protobuf/proto"
    12  	"google.golang.org/protobuf/types/known/timestamppb"
    13  
    14  	"github.com/mutagen-io/mutagen/pkg/encoding"
    15  	"github.com/mutagen-io/mutagen/pkg/logging"
    16  	"github.com/mutagen-io/mutagen/pkg/mutagen"
    17  	"github.com/mutagen-io/mutagen/pkg/prompting"
    18  	"github.com/mutagen-io/mutagen/pkg/state"
    19  	"github.com/mutagen-io/mutagen/pkg/url"
    20  )
    21  
    22  const (
    23  	// autoReconnectInterval is the period of time to wait before attempting an
    24  	// automatic reconnect after disconnection or a failed reconnect.
    25  	autoReconnectInterval = 15 * time.Second
    26  )
    27  
    28  // controller manages and executes a single session.
    29  type controller struct {
    30  	// logger is the controller logger.
    31  	logger *logging.Logger
    32  	// sessionPath is the path to the serialized session.
    33  	sessionPath string
    34  	// stateLock guards and tracks changes to session's Paused field and state.
    35  	stateLock *state.TrackingLock
    36  	// session encodes the associated session metadata. It is considered static
    37  	// and safe for concurrent access except for its Paused field, for which
    38  	// stateLock should be held. It should be saved to disk any time it is
    39  	// modified.
    40  	session *Session
    41  	// mergedSourceConfiguration is the source-specific configuration object
    42  	// (computed from the core configuration and source-specific overrides). It
    43  	// is considered static and safe for concurrent access. It is a derived
    44  	// field and not saved to disk.
    45  	mergedSourceConfiguration *Configuration
    46  	// mergedDestinationConfiguration is the destination-specific configuration
    47  	// object (computed from the core configuration and destination-specific
    48  	// overrides). It is considered static and safe for concurrent access. It is
    49  	// a derived field and not saved to disk.
    50  	mergedDestinationConfiguration *Configuration
    51  	// state represents the current forwarding state.
    52  	state *State
    53  	// lifecycleLock guards access to disabled, cancel, and done. Only the
    54  	// current holder of the lifecycle lock may set any of these fields or
    55  	// invoke cancel. The forwarding loop may close done without holding the
    56  	// lifecycle lock. Moreover, previous lifecycle lock holders may poll on
    57  	// done after storing it in a separate variable and releasing the lifecycle
    58  	// lock. Any code wishing to set these fields must first acquire the lock,
    59  	// then cancel the forwarding loop and wait for it to complete before making
    60  	// any changes.
    61  	lifecycleLock sync.Mutex
    62  	// disabled indicates that no more changes to the forwarding loop lifecycle
    63  	// are allowed (i.e. no more forwarding loops can be started for this
    64  	// controller). This is used by terminate and shutdown. It should only be
    65  	// set to true once any existing forwarding loop has been stopped.
    66  	disabled bool
    67  	// cancel cancels the forwarding loop execution context. It should be nil if
    68  	// and only if there is no forwarding loop running.
    69  	cancel context.CancelFunc
    70  	// done will be closed by the current forwarding loop when it exits.
    71  	done chan struct{}
    72  }
    73  
    74  // newSession creates a new session and corresponding controller.
    75  func newSession(
    76  	ctx context.Context,
    77  	logger *logging.Logger,
    78  	tracker *state.Tracker,
    79  	identifier string,
    80  	source, destination *url.URL,
    81  	configuration, configurationSource, configurationDestination *Configuration,
    82  	name string,
    83  	labels map[string]string,
    84  	paused bool,
    85  	prompter string,
    86  ) (*controller, error) {
    87  	// Update status.
    88  	prompting.Message(prompter, "Creating session...")
    89  
    90  	// Set the session version.
    91  	version := DefaultVersion
    92  
    93  	// Compute the creation time and check that it's valid for Protocol Buffers.
    94  	creationTime := timestamppb.Now()
    95  	if err := creationTime.CheckValid(); err != nil {
    96  		return nil, fmt.Errorf("unable to record creation time: %w", err)
    97  	}
    98  
    99  	// Compute merged endpoint configurations.
   100  	mergedSourceConfiguration := MergeConfigurations(configuration, configurationSource)
   101  	mergedDestinationConfiguration := MergeConfigurations(configuration, configurationDestination)
   102  
   103  	// If the session isn't being created paused, then try to connect to the
   104  	// endpoints. Before doing so, set up a deferred handler that will shut down
   105  	// any endpoints that aren't handed off to the run loop due to errors.
   106  	var sourceEndpoint, destinationEndpoint Endpoint
   107  	var err error
   108  	defer func() {
   109  		if sourceEndpoint != nil {
   110  			sourceEndpoint.Shutdown()
   111  			sourceEndpoint = nil
   112  		}
   113  		if destinationEndpoint != nil {
   114  			destinationEndpoint.Shutdown()
   115  			destinationEndpoint = nil
   116  		}
   117  	}()
   118  	if !paused {
   119  		logger.Info("Connecting to source endpoint")
   120  		sourceEndpoint, err = connect(
   121  			ctx,
   122  			logger.Sublogger("source"),
   123  			source,
   124  			prompter,
   125  			identifier,
   126  			version,
   127  			mergedSourceConfiguration,
   128  			true,
   129  		)
   130  		if err != nil {
   131  			logger.Info("Source connection failure:", err)
   132  			return nil, fmt.Errorf("unable to connect to source: %w", err)
   133  		}
   134  		logger.Info("Connecting to destination endpoint")
   135  		destinationEndpoint, err = connect(
   136  			ctx,
   137  			logger.Sublogger("destination"),
   138  			destination,
   139  			prompter,
   140  			identifier,
   141  			version,
   142  			mergedDestinationConfiguration,
   143  			false,
   144  		)
   145  		if err != nil {
   146  			logger.Info("Destination connection failure:", err)
   147  			return nil, fmt.Errorf("unable to connect to destination: %w", err)
   148  		}
   149  	}
   150  
   151  	// Create the session.
   152  	session := &Session{
   153  		Identifier:               identifier,
   154  		Version:                  version,
   155  		CreationTime:             creationTime,
   156  		CreatingVersionMajor:     mutagen.VersionMajor,
   157  		CreatingVersionMinor:     mutagen.VersionMinor,
   158  		CreatingVersionPatch:     mutagen.VersionPatch,
   159  		Source:                   source,
   160  		Destination:              destination,
   161  		Configuration:            configuration,
   162  		ConfigurationSource:      configurationSource,
   163  		ConfigurationDestination: configurationDestination,
   164  		Name:                     name,
   165  		Labels:                   labels,
   166  		Paused:                   paused,
   167  	}
   168  
   169  	// Compute the session path.
   170  	sessionPath, err := pathForSession(session.Identifier)
   171  	if err != nil {
   172  		return nil, fmt.Errorf("unable to compute session path: %w", err)
   173  	}
   174  
   175  	// Save the session to disk.
   176  	if err := encoding.MarshalAndSaveProtobuf(sessionPath, session); err != nil {
   177  		return nil, fmt.Errorf("unable to save session: %w", err)
   178  	}
   179  
   180  	// Create the controller.
   181  	controller := &controller{
   182  		logger:                         logger,
   183  		sessionPath:                    sessionPath,
   184  		stateLock:                      state.NewTrackingLock(tracker),
   185  		session:                        session,
   186  		mergedSourceConfiguration:      mergedSourceConfiguration,
   187  		mergedDestinationConfiguration: mergedDestinationConfiguration,
   188  		state: &State{
   189  			Session:          session,
   190  			SourceState:      &EndpointState{},
   191  			DestinationState: &EndpointState{},
   192  		},
   193  	}
   194  
   195  	// If the session isn't being created paused, then start a forwarding loop
   196  	// and mark the endpoints as handed off to that loop so that we don't defer
   197  	// their shutdown.
   198  	if !paused {
   199  		ctx, cancel := context.WithCancel(context.Background())
   200  		controller.cancel = cancel
   201  		controller.done = make(chan struct{})
   202  		go controller.run(ctx, sourceEndpoint, destinationEndpoint)
   203  		sourceEndpoint = nil
   204  		destinationEndpoint = nil
   205  	}
   206  
   207  	// Success.
   208  	logger.Info("Session initialized")
   209  	return controller, nil
   210  }
   211  
   212  // loadSession loads an existing session and creates a corresponding controller.
   213  func loadSession(logger *logging.Logger, tracker *state.Tracker, identifier string) (*controller, error) {
   214  	// Compute the session path.
   215  	sessionPath, err := pathForSession(identifier)
   216  	if err != nil {
   217  		return nil, fmt.Errorf("unable to compute session path: %w", err)
   218  	}
   219  
   220  	// Load and validate the session.
   221  	session := &Session{}
   222  	if err := encoding.LoadAndUnmarshalProtobuf(sessionPath, session); err != nil {
   223  		return nil, fmt.Errorf("unable to load session configuration: %w", err)
   224  	}
   225  	if err := session.EnsureValid(); err != nil {
   226  		return nil, fmt.Errorf("invalid session found on disk: %w", err)
   227  	}
   228  
   229  	// Create the controller.
   230  	controller := &controller{
   231  		logger:      logger,
   232  		sessionPath: sessionPath,
   233  		stateLock:   state.NewTrackingLock(tracker),
   234  		session:     session,
   235  		mergedSourceConfiguration: MergeConfigurations(
   236  			session.Configuration,
   237  			session.ConfigurationSource,
   238  		),
   239  		mergedDestinationConfiguration: MergeConfigurations(
   240  			session.Configuration,
   241  			session.ConfigurationDestination,
   242  		),
   243  		state: &State{
   244  			Session:          session,
   245  			SourceState:      &EndpointState{},
   246  			DestinationState: &EndpointState{},
   247  		},
   248  	}
   249  
   250  	// If the session isn't marked as paused, start a forwarding loop.
   251  	if !session.Paused {
   252  		ctx, cancel := context.WithCancel(context.Background())
   253  		controller.cancel = cancel
   254  		controller.done = make(chan struct{})
   255  		go controller.run(ctx, nil, nil)
   256  	}
   257  
   258  	// Success.
   259  	logger.Info("Session loaded")
   260  	return controller, nil
   261  }
   262  
   263  // currentState creates a static snapshot of the current session state.
   264  func (c *controller) currentState() *State {
   265  	// Lock the session state and defer its release. It's very important that we
   266  	// unlock without a notification here, otherwise we'd trigger an infinite
   267  	// cycle of list/notify.
   268  	c.stateLock.Lock()
   269  	defer c.stateLock.UnlockWithoutNotify()
   270  
   271  	// Create a static copy of the state.
   272  	return proto.Clone(c.state).(*State)
   273  }
   274  
   275  // resume attempts to reconnect and resume the session if it isn't currently
   276  // connected and forwarding.
   277  func (c *controller) resume(ctx context.Context, prompter string) error {
   278  	// Update status.
   279  	prompting.Message(prompter, fmt.Sprintf("Resuming session %s...", c.session.Identifier))
   280  
   281  	// Lock the controller's lifecycle and defer its release.
   282  	c.lifecycleLock.Lock()
   283  	defer c.lifecycleLock.Unlock()
   284  
   285  	// Don't allow any resume operations if the controller is disabled.
   286  	if c.disabled {
   287  		return errors.New("controller disabled")
   288  	}
   289  
   290  	// Perform logging.
   291  	c.logger.Infof("Resuming")
   292  
   293  	// Check if there's an existing forwarding loop (i.e. if the session is
   294  	// unpaused).
   295  	if c.cancel != nil {
   296  		// If there is an existing forwarding loop, check if it's already in a
   297  		// state that's considered "forwarding".
   298  		c.stateLock.Lock()
   299  		forwarding := c.state.Status >= Status_ForwardingConnections
   300  		c.stateLock.UnlockWithoutNotify()
   301  
   302  		// If we're already forwarding, then there's nothing we need to do. We
   303  		// don't even need to mark the session as unpaused because it can't be
   304  		// marked as paused if an existing forwarding loop is running (we
   305  		// enforce this invariant as part of the controller's logic).
   306  		if forwarding {
   307  			return nil
   308  		}
   309  
   310  		// Otherwise, cancel the existing forwarding loop and wait for it to
   311  		// finish.
   312  		//
   313  		// There's something of an efficiency race condition here, because the
   314  		// existing loop might succeed in connecting between the time we check
   315  		// and the time we cancel it. That could happen if an auto-reconnect
   316  		// succeeds or even if the loop was already passed connections and it's
   317  		// just hasn't updated its status yet. But the only danger here is
   318  		// basically wasting those connections, and the window is very small.
   319  		c.cancel()
   320  		<-c.done
   321  
   322  		// Nil out any lifecycle state.
   323  		c.cancel = nil
   324  		c.done = nil
   325  	}
   326  
   327  	// Mark the session as unpaused and save it to disk.
   328  	c.stateLock.Lock()
   329  	c.session.Paused = false
   330  	saveErr := encoding.MarshalAndSaveProtobuf(c.sessionPath, c.session)
   331  	c.stateLock.Unlock()
   332  
   333  	// Attempt to connect to source.
   334  	c.stateLock.Lock()
   335  	c.state.Status = Status_ConnectingSource
   336  	c.stateLock.Unlock()
   337  	source, sourceConnectErr := connect(
   338  		ctx,
   339  		c.logger.Sublogger("source"),
   340  		c.session.Source,
   341  		prompter,
   342  		c.session.Identifier,
   343  		c.session.Version,
   344  		c.mergedSourceConfiguration,
   345  		true,
   346  	)
   347  	c.stateLock.Lock()
   348  	c.state.SourceState.Connected = (source != nil)
   349  	c.stateLock.Unlock()
   350  
   351  	// Attempt to connect to destination.
   352  	c.stateLock.Lock()
   353  	c.state.Status = Status_ConnectingDestination
   354  	c.stateLock.Unlock()
   355  	destination, destinationConnectErr := connect(
   356  		ctx,
   357  		c.logger.Sublogger("destination"),
   358  		c.session.Destination,
   359  		prompter,
   360  		c.session.Identifier,
   361  		c.session.Version,
   362  		c.mergedDestinationConfiguration,
   363  		false,
   364  	)
   365  	c.stateLock.Lock()
   366  	c.state.DestinationState.Connected = (destination != nil)
   367  	c.stateLock.Unlock()
   368  
   369  	// Start the forwarding loop with what we have. Source or destination may
   370  	// have failed to connect (and be nil), but in any case that'll just make
   371  	// the run loop keep trying to connect.
   372  	ctx, cancel := context.WithCancel(context.Background())
   373  	c.cancel = cancel
   374  	c.done = make(chan struct{})
   375  	go c.run(ctx, source, destination)
   376  
   377  	// Report any errors. Since we always want to start a forwarding loop, even
   378  	// on partial or complete failure (since it might be able to auto-reconnect
   379  	// on its own), we wait until the end to report errors.
   380  	if saveErr != nil {
   381  		return fmt.Errorf("unable to save session: %w", saveErr)
   382  	} else if sourceConnectErr != nil {
   383  		return fmt.Errorf("unable to connect to source: %w", sourceConnectErr)
   384  	} else if destinationConnectErr != nil {
   385  		return fmt.Errorf("unable to connect to destination: %w", destinationConnectErr)
   386  	}
   387  
   388  	// Success.
   389  	return nil
   390  }
   391  
   392  // controllerHaltMode represents the behavior to use when halting a session.
   393  type controllerHaltMode uint8
   394  
   395  const (
   396  	// controllerHaltModePause indicates that a session should be halted and
   397  	// marked as paused.
   398  	controllerHaltModePause controllerHaltMode = iota
   399  	// controllerHaltModeShutdown indicates that a session should be halted.
   400  	controllerHaltModeShutdown
   401  	// controllerHaltModeShutdown indicates that a session should be halted and
   402  	// then deleted.
   403  	controllerHaltModeTerminate
   404  )
   405  
   406  // description returns a human-readable description of a halt mode.
   407  func (m controllerHaltMode) description() string {
   408  	switch m {
   409  	case controllerHaltModePause:
   410  		return "Pausing"
   411  	case controllerHaltModeShutdown:
   412  		return "Shutting down"
   413  	case controllerHaltModeTerminate:
   414  		return "Terminating"
   415  	default:
   416  		panic("unhandled halt mode")
   417  	}
   418  }
   419  
   420  // halt halts the session with the specified behavior.
   421  func (c *controller) halt(_ context.Context, mode controllerHaltMode, prompter string) error {
   422  	// Update status.
   423  	prompting.Message(prompter, fmt.Sprintf("%s session %s...", mode.description(), c.session.Identifier))
   424  
   425  	// Lock the controller's lifecycle and defer its release.
   426  	c.lifecycleLock.Lock()
   427  	defer c.lifecycleLock.Unlock()
   428  
   429  	// Don't allow any additional halt operations if the controller is disabled,
   430  	// because either this session is being terminated or the service is
   431  	// shutting down, and in either case there is no point in halting.
   432  	if c.disabled {
   433  		return errors.New("controller disabled")
   434  	}
   435  
   436  	// Perform logging.
   437  	c.logger.Infof(mode.description())
   438  
   439  	// Kill any existing forwarding loop.
   440  	if c.cancel != nil {
   441  		// Cancel the forwarding loop and wait for it to finish.
   442  		c.cancel()
   443  		<-c.done
   444  
   445  		// Nil out any lifecycle state.
   446  		c.cancel = nil
   447  		c.done = nil
   448  	}
   449  
   450  	// Handle based on the halt mode.
   451  	if mode == controllerHaltModePause {
   452  		// Mark the session as paused and save it.
   453  		c.stateLock.Lock()
   454  		c.session.Paused = true
   455  		saveErr := encoding.MarshalAndSaveProtobuf(c.sessionPath, c.session)
   456  		c.stateLock.Unlock()
   457  		if saveErr != nil {
   458  			return fmt.Errorf("unable to save session: %w", saveErr)
   459  		}
   460  	} else if mode == controllerHaltModeShutdown {
   461  		// Disable the controller.
   462  		c.disabled = true
   463  	} else if mode == controllerHaltModeTerminate {
   464  		// Disable the controller.
   465  		c.disabled = true
   466  
   467  		// Wipe the session information from disk.
   468  		sessionRemoveErr := os.Remove(c.sessionPath)
   469  		if sessionRemoveErr != nil {
   470  			return fmt.Errorf("unable to remove session from disk: %w", sessionRemoveErr)
   471  		}
   472  	} else {
   473  		panic("invalid halt mode specified")
   474  	}
   475  
   476  	// Success.
   477  	return nil
   478  }
   479  
   480  // run is the main run loop for the controller, managing connectivity and
   481  // forwarding.
   482  func (c *controller) run(ctx context.Context, source, destination Endpoint) {
   483  	// Log run loop entry.
   484  	c.logger.Debug("Run loop commencing")
   485  
   486  	// Defer resource and state cleanup.
   487  	defer func() {
   488  		// Shutdown any endpoints. These might be non-nil if the run loop was
   489  		// cancelled while partially connected rather than after forwarding
   490  		// failure.
   491  		if source != nil {
   492  			source.Shutdown()
   493  		}
   494  		if destination != nil {
   495  			destination.Shutdown()
   496  		}
   497  
   498  		// Reset the state.
   499  		c.stateLock.Lock()
   500  		c.state = &State{
   501  			Session:          c.session,
   502  			SourceState:      &EndpointState{},
   503  			DestinationState: &EndpointState{},
   504  		}
   505  		c.stateLock.Unlock()
   506  
   507  		// Log run loop termination.
   508  		c.logger.Debug("Run loop terminated")
   509  
   510  		// Signal completion.
   511  		close(c.done)
   512  	}()
   513  
   514  	// Track the last time that forwarding failed.
   515  	var lastForwardingFailureTime time.Time
   516  
   517  	// Loop until cancelled.
   518  	for {
   519  		// Loop until we're connected to both endpoints. We do a non-blocking
   520  		// check for cancellation on each reconnect error so that we don't waste
   521  		// resources by trying another connect when the context has been
   522  		// cancelled (it'll be wasteful). This is better than sentinel errors.
   523  		for {
   524  			// Ensure that source is connected.
   525  			var sourceConnectErr error
   526  			if source == nil {
   527  				c.stateLock.Lock()
   528  				c.state.Status = Status_ConnectingSource
   529  				c.stateLock.Unlock()
   530  				source, sourceConnectErr = connect(
   531  					ctx,
   532  					c.logger.Sublogger("source"),
   533  					c.session.Source,
   534  					"",
   535  					c.session.Identifier,
   536  					c.session.Version,
   537  					c.mergedSourceConfiguration,
   538  					true,
   539  				)
   540  			}
   541  			c.stateLock.Lock()
   542  			c.state.SourceState.Connected = (source != nil)
   543  			if sourceConnectErr != nil {
   544  				c.state.LastError = fmt.Errorf("unable to connect to source: %w", sourceConnectErr).Error()
   545  			}
   546  			c.stateLock.Unlock()
   547  
   548  			// Check for cancellation to avoid a spurious connection to
   549  			// destination in case cancellation occurred while connecting to
   550  			// source.
   551  			select {
   552  			case <-ctx.Done():
   553  				return
   554  			default:
   555  			}
   556  
   557  			// Ensure that destination is connected.
   558  			var destinationConnectErr error
   559  			if destination == nil {
   560  				c.stateLock.Lock()
   561  				c.state.Status = Status_ConnectingDestination
   562  				c.stateLock.Unlock()
   563  				destination, destinationConnectErr = connect(
   564  					ctx,
   565  					c.logger.Sublogger("destination"),
   566  					c.session.Destination,
   567  					"",
   568  					c.session.Identifier,
   569  					c.session.Version,
   570  					c.mergedDestinationConfiguration,
   571  					false,
   572  				)
   573  			}
   574  			c.stateLock.Lock()
   575  			c.state.DestinationState.Connected = (destination != nil)
   576  			if destinationConnectErr != nil {
   577  				c.state.LastError = fmt.Errorf("unable to connect to destination: %w", destinationConnectErr).Error()
   578  			}
   579  			c.stateLock.Unlock()
   580  
   581  			// If both endpoints are connected, we're done. We perform this
   582  			// check here (rather than in the loop condition) because if we did
   583  			// it in the loop condition we'd still need a check here to avoid a
   584  			// sleep every time (even if already successfully connected).
   585  			if source != nil && destination != nil {
   586  				break
   587  			}
   588  
   589  			// If we failed to connect, wait and then retry. Watch for
   590  			// cancellation in the mean time.
   591  			select {
   592  			case <-ctx.Done():
   593  				return
   594  			case <-time.After(autoReconnectInterval):
   595  			}
   596  		}
   597  
   598  		// Grab transport error channels for each endpoint.
   599  		sourceTransportErrors := source.TransportErrors()
   600  		destinationTransportErrors := destination.TransportErrors()
   601  
   602  		// Create a cancellable subcontext that we can use to manage shutdown.
   603  		shutdownCtx, forceShutdown := context.WithCancel(ctx)
   604  
   605  		// Create a Goroutine that will shut down (and unblock) endpoints. This
   606  		// is the only way to unblock forwarding on cancellation.
   607  		shutdownComplete := make(chan struct{})
   608  		go func() {
   609  			<-shutdownCtx.Done()
   610  			source.Shutdown()
   611  			destination.Shutdown()
   612  			close(shutdownComplete)
   613  		}()
   614  
   615  		// Perform forwarding in a background Goroutine and monitor for errors.
   616  		forwardingErrors := make(chan error, 1)
   617  		go func() {
   618  			c.logger.Debug("Entering forwarding loop")
   619  			forwardingErrors <- c.forward(source, destination)
   620  		}()
   621  
   622  		// Wait for cancellation, an error from forwarding, or an error from
   623  		// either transport.
   624  		var cancelled bool
   625  		var sessionErr error
   626  		var forwardingErrorReceived bool
   627  		select {
   628  		case <-ctx.Done():
   629  			c.logger.Debug("Run loop cancelled")
   630  			sessionErr = errors.New("session cancelled")
   631  			cancelled = true
   632  		case sessionErr = <-forwardingErrors:
   633  			c.logger.Debug("Forwarding loop terminated with error:", sessionErr)
   634  			forwardingErrorReceived = true
   635  		case err := <-sourceTransportErrors:
   636  			c.logger.Debug("Source transport failure:", err)
   637  			sessionErr = fmt.Errorf("source transport failure: %w", err)
   638  		case err := <-destinationTransportErrors:
   639  			c.logger.Debug("Destination transport failure:", err)
   640  			sessionErr = fmt.Errorf("destination transport failure: %w", err)
   641  		}
   642  
   643  		// Force shutdown, which may have already occurred due to cancellation.
   644  		forceShutdown()
   645  
   646  		// Wait for shutdown to complete.
   647  		<-shutdownComplete
   648  
   649  		// If the forwarding loop wasn't what unblocked our wait, then wait for
   650  		// it to return a result so that we know it has exited. This isn't
   651  		// strictly necessary with our current design, but it's cleaner and more
   652  		// robust.
   653  		if !forwardingErrorReceived {
   654  			<-forwardingErrors
   655  			c.logger.Debug("Forwarding loop terminated")
   656  		}
   657  
   658  		// Nil out endpoints to update our state.
   659  		source = nil
   660  		destination = nil
   661  
   662  		// Reset the forwarding state, but propagate the error that caused
   663  		// failure.
   664  		c.stateLock.Lock()
   665  		c.state = &State{
   666  			Session:          c.session,
   667  			LastError:        sessionErr.Error(),
   668  			SourceState:      &EndpointState{},
   669  			DestinationState: &EndpointState{},
   670  		}
   671  		c.stateLock.Unlock()
   672  
   673  		// If we were cancelled, then return immediately.
   674  		if cancelled {
   675  			return
   676  		}
   677  
   678  		// If less than one auto-reconnect interval has elapsed since the last
   679  		// forwarding failure, then wait before attempting reconnection.
   680  		now := time.Now()
   681  		if now.Sub(lastForwardingFailureTime) < autoReconnectInterval {
   682  			select {
   683  			case <-ctx.Done():
   684  				return
   685  			case <-time.After(autoReconnectInterval):
   686  			}
   687  		}
   688  		lastForwardingFailureTime = now
   689  	}
   690  }
   691  
   692  // forward is the main forwarding loop for the controller.
   693  func (c *controller) forward(source, destination Endpoint) error {
   694  	// Create a context that we can use to regulate the lifecycle of forwarding
   695  	// Goroutines and defer its cancellation.
   696  	ctx, cancel := context.WithCancel(context.Background())
   697  	defer cancel()
   698  
   699  	// Clear any error state and update the status to forwarding. While we're at
   700  	// it, capture a pointer to the state instance that all forwarding
   701  	// Goroutines spawned by this loop will update. This state instance will be
   702  	// replaced once this loop returns, so those background Goroutines can
   703  	// continue to safely update it without any risk of updating a future loop's
   704  	// state object. The only penalty is that both state objects will share the
   705  	// same lock, but that's a negligible overhead.
   706  	var state *State
   707  	c.stateLock.Lock()
   708  	c.state.LastError = ""
   709  	c.state.Status = Status_ForwardingConnections
   710  	state = c.state
   711  	c.stateLock.Unlock()
   712  
   713  	// Create auditor functions to track data transfer.
   714  	incomingAuditor := func(amount uint64) {
   715  		c.stateLock.Lock()
   716  		state.TotalInboundData += amount
   717  		c.stateLock.Unlock()
   718  	}
   719  	outgoingAuditor := func(amount uint64) {
   720  		c.stateLock.Lock()
   721  		state.TotalOutboundData += amount
   722  		c.stateLock.Unlock()
   723  	}
   724  
   725  	// Accept and forward connections until there's an error.
   726  	for {
   727  		// Accept a connection from the source.
   728  		incoming, err := source.Open()
   729  		if err != nil {
   730  			return fmt.Errorf("unable to accept connection: %w", err)
   731  		}
   732  
   733  		// Open the outgoing connection to which we should forward.
   734  		outgoing, err := destination.Open()
   735  		if err != nil {
   736  			incoming.Close()
   737  			return fmt.Errorf("unable to open forwarding connection: %w", err)
   738  		}
   739  
   740  		// Increment the open and total connection counts.
   741  		c.stateLock.Lock()
   742  		state.OpenConnections++
   743  		state.TotalConnections++
   744  		c.stateLock.Unlock()
   745  
   746  		// Perform forwarding and update state in a background Goroutine.
   747  		go func() {
   748  			// Perform forwarding.
   749  			ForwardAndClose(ctx, incoming, outgoing, incomingAuditor, outgoingAuditor)
   750  
   751  			// Decrement open connection counts.
   752  			c.stateLock.Lock()
   753  			state.OpenConnections--
   754  			c.stateLock.Unlock()
   755  		}()
   756  	}
   757  }