github.com/Psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/server/tunnelServer.go (about)

     1  /*
     2   * Copyright (c) 2016, Psiphon Inc.
     3   * All rights reserved.
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License as published by
     7   * the Free Software Foundation, either version 3 of the License, or
     8   * (at your option) any later version.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package server
    21  
    22  import (
    23  	"bytes"
    24  	"context"
    25  	"crypto/rand"
    26  	"crypto/subtle"
    27  	"encoding/base64"
    28  	"encoding/json"
    29  	std_errors "errors"
    30  	"fmt"
    31  	"io"
    32  	"io/ioutil"
    33  	"net"
    34  	"strconv"
    35  	"sync"
    36  	"sync/atomic"
    37  	"syscall"
    38  	"time"
    39  
    40  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
    41  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/accesscontrol"
    42  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
    43  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
    44  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/monotime"
    45  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/obfuscator"
    46  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
    47  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
    48  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
    49  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
    50  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
    51  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/refraction"
    52  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
    53  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tun"
    54  	"github.com/marusama/semaphore"
    55  	cache "github.com/patrickmn/go-cache"
    56  )
    57  
    58  const (
    59  	SSH_AUTH_LOG_PERIOD                   = 30 * time.Minute
    60  	SSH_HANDSHAKE_TIMEOUT                 = 30 * time.Second
    61  	SSH_BEGIN_HANDSHAKE_TIMEOUT           = 1 * time.Second
    62  	SSH_CONNECTION_READ_DEADLINE          = 5 * time.Minute
    63  	SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE = 8192
    64  	SSH_TCP_PORT_FORWARD_QUEUE_SIZE       = 1024
    65  	SSH_KEEP_ALIVE_PAYLOAD_MIN_BYTES      = 0
    66  	SSH_KEEP_ALIVE_PAYLOAD_MAX_BYTES      = 256
    67  	SSH_SEND_OSL_INITIAL_RETRY_DELAY      = 30 * time.Second
    68  	SSH_SEND_OSL_RETRY_FACTOR             = 2
    69  	OSL_SESSION_CACHE_TTL                 = 5 * time.Minute
    70  	MAX_AUTHORIZATIONS                    = 16
    71  	PRE_HANDSHAKE_RANDOM_STREAM_MAX_COUNT = 1
    72  	RANDOM_STREAM_MAX_BYTES               = 10485760
    73  	ALERT_REQUEST_QUEUE_BUFFER_SIZE       = 16
    74  )
    75  
// TunnelServer is the main server that accepts Psiphon client
// connections, via various obfuscation protocols, and provides
// port forwarding (TCP and UDP) services to the Psiphon client.
// At its core, TunnelServer is an SSH server. SSH is the base
// protocol that provides port forward multiplexing, and transport
// security. Layered on top of SSH, optionally, are Obfuscated SSH
// and meek protocols, which provide further circumvention
// capabilities.
type TunnelServer struct {
	// runWaitGroup tracks the per-listener goroutines started by Run;
	// Run waits on it before returning.
	runWaitGroup *sync.WaitGroup
	// listenerError carries an unrecoverable listener error from a
	// listener goroutine to Run, which then shuts down.
	listenerError chan error
	// shutdownBroadcast is the signal that causes Run to halt.
	shutdownBroadcast <-chan struct{}
	// sshServer is the underlying SSH server; most TunnelServer
	// methods delegate to it.
	sshServer *sshServer
}
    90  
// sshListener is a bound listener plus the metadata that runListener and
// the client handlers need about it.
type sshListener struct {
	// In Run, the embedded Listener is the TacticsListener that wraps the
	// protocol-specific listener.
	net.Listener
	// localAddress is the host:port the listener is bound to.
	localAddress string
	// tunnelProtocol is the tunnel protocol run by this listener.
	tunnelProtocol string
	// port is the listening port number.
	port int
	// BPFProgramName names the TCP BPF program attached to the listener,
	// if any; Run sets it only for direct, unfronted protocols.
	BPFProgramName string
}
    98  
    99  // NewTunnelServer initializes a new tunnel server.
   100  func NewTunnelServer(
   101  	support *SupportServices,
   102  	shutdownBroadcast <-chan struct{}) (*TunnelServer, error) {
   103  
   104  	sshServer, err := newSSHServer(support, shutdownBroadcast)
   105  	if err != nil {
   106  		return nil, errors.Trace(err)
   107  	}
   108  
   109  	return &TunnelServer{
   110  		runWaitGroup:      new(sync.WaitGroup),
   111  		listenerError:     make(chan error),
   112  		shutdownBroadcast: shutdownBroadcast,
   113  		sshServer:         sshServer,
   114  	}, nil
   115  }
   116  
   117  // Run runs the tunnel server; this function blocks while running a selection of
   118  // listeners that handle connection using various obfuscation protocols.
   119  //
   120  // Run listens on each designated tunnel port and spawns new goroutines to handle
   121  // each client connection. It halts when shutdownBroadcast is signaled. A list of active
   122  // clients is maintained, and when halting all clients are cleanly shutdown.
   123  //
   124  // Each client goroutine handles its own obfuscation (optional), SSH handshake, SSH
   125  // authentication, and then looping on client new channel requests. "direct-tcpip"
   126  // channels, dynamic port fowards, are supported. When the UDPInterceptUdpgwServerAddress
   127  // config parameter is configured, UDP port forwards over a TCP stream, following
   128  // the udpgw protocol, are handled.
   129  //
   130  // A new goroutine is spawned to handle each port forward for each client. Each port
   131  // forward tracks its bytes transferred. Overall per-client stats for connection duration,
   132  // GeoIP, number of port forwards, and bytes transferred are tracked and logged when the
   133  // client shuts down.
   134  //
   135  // Note: client handler goroutines may still be shutting down after Run() returns. See
   136  // comment in sshClient.stop(). TODO: fully synchronized shutdown.
   137  func (server *TunnelServer) Run() error {
   138  
   139  	// TODO: should TunnelServer hold its own support pointer?
   140  	support := server.sshServer.support
   141  
   142  	// First bind all listeners; once all are successful,
   143  	// start accepting connections on each.
   144  
   145  	var listeners []*sshListener
   146  
   147  	for tunnelProtocol, listenPort := range support.Config.TunnelProtocolPorts {
   148  
   149  		localAddress := net.JoinHostPort(
   150  			support.Config.ServerIPAddress, strconv.Itoa(listenPort))
   151  
   152  		var listener net.Listener
   153  		var BPFProgramName string
   154  		var err error
   155  
   156  		if protocol.TunnelProtocolUsesFrontedMeekQUIC(tunnelProtocol) {
   157  
   158  			// For FRONTED-MEEK-QUIC-OSSH, no listener implemented. The edge-to-server
   159  			// hop uses HTTPS and the client tunnel protocol is distinguished using
   160  			// protocol.MeekCookieData.ClientTunnelProtocol.
   161  			continue
   162  
   163  		} else if protocol.TunnelProtocolUsesQUIC(tunnelProtocol) {
   164  
   165  			logTunnelProtocol := tunnelProtocol
   166  			listener, err = quic.Listen(
   167  				CommonLogger(log),
   168  				func(clientAddress string, err error, logFields common.LogFields) {
   169  					logIrregularTunnel(
   170  						support, logTunnelProtocol, listenPort, clientAddress,
   171  						errors.Trace(err), LogFields(logFields))
   172  				},
   173  				localAddress,
   174  				support.Config.ObfuscatedSSHKey,
   175  				support.Config.EnableGQUIC)
   176  
   177  		} else if protocol.TunnelProtocolUsesRefractionNetworking(tunnelProtocol) {
   178  
   179  			listener, err = refraction.Listen(localAddress)
   180  
   181  		} else if protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol) {
   182  
   183  			listener, err = net.Listen("tcp", localAddress)
   184  
   185  		} else {
   186  
   187  			// Only direct, unfronted protocol listeners use TCP BPF circumvention
   188  			// programs.
   189  			listener, BPFProgramName, err = newTCPListenerWithBPF(support, localAddress)
   190  		}
   191  
   192  		if err != nil {
   193  			for _, existingListener := range listeners {
   194  				existingListener.Listener.Close()
   195  			}
   196  			return errors.Trace(err)
   197  		}
   198  
   199  		tacticsListener := NewTacticsListener(
   200  			support,
   201  			listener,
   202  			tunnelProtocol,
   203  			func(IP string) GeoIPData { return support.GeoIPService.Lookup(IP) })
   204  
   205  		log.WithTraceFields(
   206  			LogFields{
   207  				"localAddress":   localAddress,
   208  				"tunnelProtocol": tunnelProtocol,
   209  				"BPFProgramName": BPFProgramName,
   210  			}).Info("listening")
   211  
   212  		listeners = append(
   213  			listeners,
   214  			&sshListener{
   215  				Listener:       tacticsListener,
   216  				localAddress:   localAddress,
   217  				port:           listenPort,
   218  				tunnelProtocol: tunnelProtocol,
   219  				BPFProgramName: BPFProgramName,
   220  			})
   221  	}
   222  
   223  	for _, listener := range listeners {
   224  		server.runWaitGroup.Add(1)
   225  		go func(listener *sshListener) {
   226  			defer server.runWaitGroup.Done()
   227  
   228  			log.WithTraceFields(
   229  				LogFields{
   230  					"localAddress":   listener.localAddress,
   231  					"tunnelProtocol": listener.tunnelProtocol,
   232  				}).Info("running")
   233  
   234  			server.sshServer.runListener(
   235  				listener,
   236  				server.listenerError)
   237  
   238  			log.WithTraceFields(
   239  				LogFields{
   240  					"localAddress":   listener.localAddress,
   241  					"tunnelProtocol": listener.tunnelProtocol,
   242  				}).Info("stopped")
   243  
   244  		}(listener)
   245  	}
   246  
   247  	var err error
   248  	select {
   249  	case <-server.shutdownBroadcast:
   250  	case err = <-server.listenerError:
   251  	}
   252  
   253  	for _, listener := range listeners {
   254  		listener.Close()
   255  	}
   256  	server.sshServer.stopClients()
   257  	server.runWaitGroup.Wait()
   258  
   259  	log.WithTrace().Info("stopped")
   260  
   261  	return err
   262  }
   263  
   264  // GetLoadStats returns load stats for the tunnel server. The stats are
   265  // broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats
   266  // include current connected client count, total number of current port
   267  // forwards.
   268  func (server *TunnelServer) GetLoadStats() (
   269  	UpstreamStats, ProtocolStats, RegionStats) {
   270  
   271  	return server.sshServer.getLoadStats()
   272  }
   273  
   274  // GetEstablishedClientCount returns the number of currently established
   275  // clients.
   276  func (server *TunnelServer) GetEstablishedClientCount() int {
   277  	return server.sshServer.getEstablishedClientCount()
   278  }
   279  
   280  // ResetAllClientTrafficRules resets all established client traffic rules
   281  // to use the latest config and client properties. Any existing traffic
   282  // rule state is lost, including throttling state.
   283  func (server *TunnelServer) ResetAllClientTrafficRules() {
   284  	server.sshServer.resetAllClientTrafficRules()
   285  }
   286  
   287  // ResetAllClientOSLConfigs resets all established client OSL state to use
   288  // the latest OSL config. Any existing OSL state is lost, including partial
   289  // progress towards SLOKs.
   290  func (server *TunnelServer) ResetAllClientOSLConfigs() {
   291  	server.sshServer.resetAllClientOSLConfigs()
   292  }
   293  
   294  // SetClientHandshakeState sets the handshake state -- that it completed and
   295  // what parameters were passed -- in sshClient. This state is used for allowing
   296  // port forwards and for future traffic rule selection. SetClientHandshakeState
   297  // also triggers an immediate traffic rule re-selection, as the rules selected
   298  // upon tunnel establishment may no longer apply now that handshake values are
   299  // set.
   300  //
   301  // The authorizations received from the client handshake are verified and the
   302  // resulting list of authorized access types are applied to the client's tunnel
   303  // and traffic rules.
   304  //
   305  // A list of active authorization IDs, authorized access types, and traffic
   306  // rate limits are returned for responding to the client and logging.
   307  func (server *TunnelServer) SetClientHandshakeState(
   308  	sessionID string,
   309  	state handshakeState,
   310  	authorizations []string) (*handshakeStateInfo, error) {
   311  
   312  	return server.sshServer.setClientHandshakeState(sessionID, state, authorizations)
   313  }
   314  
   315  // GetClientHandshaked indicates whether the client has completed a handshake
   316  // and whether its traffic rules are immediately exhausted.
   317  func (server *TunnelServer) GetClientHandshaked(
   318  	sessionID string) (bool, bool, error) {
   319  
   320  	return server.sshServer.getClientHandshaked(sessionID)
   321  }
   322  
   323  // GetClientDisableDiscovery indicates whether discovery is disabled for the
   324  // client corresponding to sessionID.
   325  func (server *TunnelServer) GetClientDisableDiscovery(
   326  	sessionID string) (bool, error) {
   327  
   328  	return server.sshServer.getClientDisableDiscovery(sessionID)
   329  }
   330  
   331  // UpdateClientAPIParameters updates the recorded handshake API parameters for
   332  // the client corresponding to sessionID.
   333  func (server *TunnelServer) UpdateClientAPIParameters(
   334  	sessionID string,
   335  	apiParams common.APIParameters) error {
   336  
   337  	return server.sshServer.updateClientAPIParameters(sessionID, apiParams)
   338  }
   339  
   340  // AcceptClientDomainBytes indicates whether to accept domain bytes reported
   341  // by the client.
   342  func (server *TunnelServer) AcceptClientDomainBytes(
   343  	sessionID string) (bool, error) {
   344  
   345  	return server.sshServer.acceptClientDomainBytes(sessionID)
   346  }
   347  
   348  // SetEstablishTunnels sets whether new tunnels may be established or not.
   349  // When not establishing, incoming connections are immediately closed.
   350  func (server *TunnelServer) SetEstablishTunnels(establish bool) {
   351  	server.sshServer.setEstablishTunnels(establish)
   352  }
   353  
   354  // CheckEstablishTunnels returns whether new tunnels may be established or
   355  // not, and increments a metrics counter when establishment is disallowed.
   356  func (server *TunnelServer) CheckEstablishTunnels() bool {
   357  	return server.sshServer.checkEstablishTunnels()
   358  }
   359  
   360  // GetEstablishTunnelsMetrics returns whether tunnel establishment is
   361  // currently allowed and the number of tunnels rejected since due to not
   362  // establishing since the last GetEstablishTunnelsMetrics call.
   363  func (server *TunnelServer) GetEstablishTunnelsMetrics() (bool, int64) {
   364  	return server.sshServer.getEstablishTunnelsMetrics()
   365  }
   366  
// sshServer is the SSH server shared by all of a TunnelServer's listeners.
// It holds the host key, the registries of accepted and established
// clients, the tunnel establishment flag, and supporting caches.
type sshServer struct {
	// Note: 64-bit ints used with atomic operations are placed
	// at the start of struct to ensure 64-bit alignment.
	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
	// Keep these fields first when editing this struct.
	//
	// establishLimitedCount counts establishment attempts rejected by
	// checkEstablishTunnels; getEstablishTunnelsMetrics resets it to 0.
	// NOTE(review): lastAuthLog and authFailedCount are not used in this
	// part of the file; presumably auth-failure log throttling state.
	lastAuthLog           int64
	authFailedCount       int64
	establishLimitedCount int64

	support *SupportServices

	// establishTunnels is 1 when new tunnels may be established and 0 when
	// not; it is read and written with sync/atomic.
	establishTunnels int32

	// concurrentSSHHandshakes is nil unless the config's
	// MaxConcurrentSSHHandshakes is positive.
	// NOTE(review): presumably bounds in-flight SSH handshakes.
	concurrentSSHHandshakes semaphore.Semaphore
	shutdownBroadcast       <-chan struct{}
	sshHostKey              ssh.Signer

	// clientsMutex guards stoppingClients, acceptedClientCounts, and clients.
	// acceptedClientCounts maps tunnel protocol -> region -> number of
	// currently accepted clients. clients maps session ID -> the registered
	// established client.
	clientsMutex         sync.Mutex
	stoppingClients      bool
	acceptedClientCounts map[string]map[string]int64
	clients              map[string]*sshClient

	// oslSessionCache retains OSL seed state progress of disconnected
	// clients for OSL_SESSION_CACHE_TTL, keyed by session ID.
	// NOTE(review): presumably guarded by oslSessionCacheMutex.
	oslSessionCacheMutex sync.Mutex
	oslSessionCache      *cache.Cache

	// NOTE(review): presumably authorizationSessionIDs is guarded by
	// authorizationSessionIDsMutex.
	authorizationSessionIDsMutex sync.Mutex
	authorizationSessionIDs      map[string]string

	obfuscatorSeedHistory *obfuscator.SeedHistory
}
   389  
   390  func newSSHServer(
   391  	support *SupportServices,
   392  	shutdownBroadcast <-chan struct{}) (*sshServer, error) {
   393  
   394  	privateKey, err := ssh.ParseRawPrivateKey([]byte(support.Config.SSHPrivateKey))
   395  	if err != nil {
   396  		return nil, errors.Trace(err)
   397  	}
   398  
   399  	// TODO: use cert (ssh.NewCertSigner) for anti-fingerprint?
   400  	signer, err := ssh.NewSignerFromKey(privateKey)
   401  	if err != nil {
   402  		return nil, errors.Trace(err)
   403  	}
   404  
   405  	var concurrentSSHHandshakes semaphore.Semaphore
   406  	if support.Config.MaxConcurrentSSHHandshakes > 0 {
   407  		concurrentSSHHandshakes = semaphore.New(support.Config.MaxConcurrentSSHHandshakes)
   408  	}
   409  
   410  	// The OSL session cache temporarily retains OSL seed state
   411  	// progress for disconnected clients. This enables clients
   412  	// that disconnect and immediately reconnect to the same
   413  	// server to resume their OSL progress. Cached progress
   414  	// is referenced by session ID and is retained for
   415  	// OSL_SESSION_CACHE_TTL after disconnect.
   416  	//
   417  	// Note: session IDs are assumed to be unpredictable. If a
   418  	// rogue client could guess the session ID of another client,
   419  	// it could resume its OSL progress and, if the OSL config
   420  	// were known, infer some activity.
   421  	oslSessionCache := cache.New(OSL_SESSION_CACHE_TTL, 1*time.Minute)
   422  
   423  	return &sshServer{
   424  		support:                 support,
   425  		establishTunnels:        1,
   426  		concurrentSSHHandshakes: concurrentSSHHandshakes,
   427  		shutdownBroadcast:       shutdownBroadcast,
   428  		sshHostKey:              signer,
   429  		acceptedClientCounts:    make(map[string]map[string]int64),
   430  		clients:                 make(map[string]*sshClient),
   431  		oslSessionCache:         oslSessionCache,
   432  		authorizationSessionIDs: make(map[string]string),
   433  		obfuscatorSeedHistory:   obfuscator.NewSeedHistory(nil),
   434  	}, nil
   435  }
   436  
   437  func (sshServer *sshServer) setEstablishTunnels(establish bool) {
   438  
   439  	// Do nothing when the setting is already correct. This avoids
   440  	// spurious log messages when setEstablishTunnels is called
   441  	// periodically with the same setting.
   442  	if establish == (atomic.LoadInt32(&sshServer.establishTunnels) == 1) {
   443  		return
   444  	}
   445  
   446  	establishFlag := int32(1)
   447  	if !establish {
   448  		establishFlag = 0
   449  	}
   450  	atomic.StoreInt32(&sshServer.establishTunnels, establishFlag)
   451  
   452  	log.WithTraceFields(
   453  		LogFields{"establish": establish}).Info("establishing tunnels")
   454  }
   455  
   456  func (sshServer *sshServer) checkEstablishTunnels() bool {
   457  	establishTunnels := atomic.LoadInt32(&sshServer.establishTunnels) == 1
   458  	if !establishTunnels {
   459  		atomic.AddInt64(&sshServer.establishLimitedCount, 1)
   460  	}
   461  	return establishTunnels
   462  }
   463  
   464  func (sshServer *sshServer) getEstablishTunnelsMetrics() (bool, int64) {
   465  	return atomic.LoadInt32(&sshServer.establishTunnels) == 1,
   466  		atomic.SwapInt64(&sshServer.establishLimitedCount, 0)
   467  }
   468  
   469  // runListener is intended to run an a goroutine; it blocks
   470  // running a particular listener. If an unrecoverable error
   471  // occurs, it will send the error to the listenerError channel.
   472  func (sshServer *sshServer) runListener(sshListener *sshListener, listenerError chan<- error) {
   473  
   474  	handleClient := func(clientTunnelProtocol string, clientConn net.Conn) {
   475  
   476  		// Note: establish tunnel limiter cannot simply stop TCP
   477  		// listeners in all cases (e.g., meek) since SSH tunnels can
   478  		// span multiple TCP connections.
   479  
   480  		if !sshServer.checkEstablishTunnels() {
   481  			log.WithTrace().Debug("not establishing tunnels")
   482  			clientConn.Close()
   483  			return
   484  		}
   485  
   486  		// tunnelProtocol is used for stats and traffic rules. In many cases, its
   487  		// value is unambiguously determined by the listener port. In certain cases,
   488  		// such as multiple fronted protocols with a single backend listener, the
   489  		// client's reported tunnel protocol value is used. The caller must validate
   490  		// clientTunnelProtocol with protocol.IsValidClientTunnelProtocol.
   491  
   492  		tunnelProtocol := sshListener.tunnelProtocol
   493  		if clientTunnelProtocol != "" {
   494  			tunnelProtocol = clientTunnelProtocol
   495  		}
   496  
   497  		// sshListener.tunnelProtocol indictes the tunnel protocol run by the
   498  		// listener. For direct protocols, this is also the client tunnel protocol.
   499  		// For fronted protocols, the client may use a different protocol to connect
   500  		// to the front and then only the front-to-Psiphon server will use the
   501  		// listener protocol.
   502  		//
   503  		// A fronted meek client, for example, reports its first hop protocol in
   504  		// protocol.MeekCookieData.ClientTunnelProtocol. Most metrics record this
   505  		// value as relay_protocol, since the first hop is the one subject to
   506  		// adversarial conditions. In some cases, such as irregular tunnels, there
   507  		// is no ClientTunnelProtocol value available and the listener tunnel
   508  		// protocol will be logged.
   509  		//
   510  		// Similarly, listenerPort indicates the listening port, which is the dialed
   511  		// port number for direct protocols; while, for fronted protocols, the
   512  		// client may dial a different port for its first hop.
   513  
   514  		// Process each client connection concurrently.
   515  		go sshServer.handleClient(sshListener, tunnelProtocol, clientConn)
   516  	}
   517  
   518  	// Note: when exiting due to a unrecoverable error, be sure
   519  	// to try to send the error to listenerError so that the outer
   520  	// TunnelServer.Run will properly shut down instead of remaining
   521  	// running.
   522  
   523  	if protocol.TunnelProtocolUsesMeekHTTP(sshListener.tunnelProtocol) ||
   524  		protocol.TunnelProtocolUsesMeekHTTPS(sshListener.tunnelProtocol) {
   525  
   526  		meekServer, err := NewMeekServer(
   527  			sshServer.support,
   528  			sshListener.Listener,
   529  			sshListener.tunnelProtocol,
   530  			sshListener.port,
   531  			protocol.TunnelProtocolUsesMeekHTTPS(sshListener.tunnelProtocol),
   532  			protocol.TunnelProtocolUsesFrontedMeek(sshListener.tunnelProtocol),
   533  			protocol.TunnelProtocolUsesObfuscatedSessionTickets(sshListener.tunnelProtocol),
   534  			handleClient,
   535  			sshServer.shutdownBroadcast)
   536  
   537  		if err == nil {
   538  			err = meekServer.Run()
   539  		}
   540  
   541  		if err != nil {
   542  			select {
   543  			case listenerError <- errors.Trace(err):
   544  			default:
   545  			}
   546  			return
   547  		}
   548  
   549  	} else {
   550  
   551  		for {
   552  			conn, err := sshListener.Listener.Accept()
   553  
   554  			select {
   555  			case <-sshServer.shutdownBroadcast:
   556  				if err == nil {
   557  					conn.Close()
   558  				}
   559  				return
   560  			default:
   561  			}
   562  
   563  			if err != nil {
   564  				if e, ok := err.(net.Error); ok && e.Temporary() {
   565  					log.WithTraceFields(LogFields{"error": err}).Error("accept failed")
   566  					// Temporary error, keep running
   567  					continue
   568  				}
   569  
   570  				select {
   571  				case listenerError <- errors.Trace(err):
   572  				default:
   573  				}
   574  				return
   575  			}
   576  
   577  			handleClient("", conn)
   578  		}
   579  	}
   580  }
   581  
   582  // An accepted client has completed a direct TCP or meek connection and has a net.Conn. Registration
   583  // is for tracking the number of connections.
   584  func (sshServer *sshServer) registerAcceptedClient(tunnelProtocol, region string) {
   585  
   586  	sshServer.clientsMutex.Lock()
   587  	defer sshServer.clientsMutex.Unlock()
   588  
   589  	if sshServer.acceptedClientCounts[tunnelProtocol] == nil {
   590  		sshServer.acceptedClientCounts[tunnelProtocol] = make(map[string]int64)
   591  	}
   592  
   593  	sshServer.acceptedClientCounts[tunnelProtocol][region] += 1
   594  }
   595  
   596  func (sshServer *sshServer) unregisterAcceptedClient(tunnelProtocol, region string) {
   597  
   598  	sshServer.clientsMutex.Lock()
   599  	defer sshServer.clientsMutex.Unlock()
   600  
   601  	sshServer.acceptedClientCounts[tunnelProtocol][region] -= 1
   602  }
   603  
// registerEstablishedClient records client as the running established client
// for client.sessionID. An established client has completed its SSH handshake
// and has a ssh.Conn. Registration is for tracking the number of fully
// established clients and for maintaining a list of running clients (for
// stopping at shutdown time).
//
// Any client already registered under the same session ID is stopped, and its
// termination awaited, before client is registered.
//
// Returns false, without registering client, in either of these cases:
//   - sshServer.stoppingClients is set;
//   - yet another client with the same session ID registered while the
//     existing duplicate was being stopped.
//
// Returns true once client is registered.
func (sshServer *sshServer) registerEstablishedClient(client *sshClient) bool {

	sshServer.clientsMutex.Lock()

	// No new registrations once stoppingClients is set (presumably by
	// stopClients at shutdown).
	if sshServer.stoppingClients {
		sshServer.clientsMutex.Unlock()
		return false
	}

	// In the case of a duplicate client sessionID, the previous client is closed.
	// - Well-behaved clients generate a random sessionID that should be unique (won't
	//   accidentally conflict) and hard to guess (can't be targeted by a malicious
	//   client).
	// - Clients reuse the same sessionID when a tunnel is unexpectedly disconnected
	//   and reestablished. In this case, when the same server is selected, this logic
	//   will be hit; closing the old, dangling client is desirable.
	// - Multi-tunnel clients should not normally use one server for multiple tunnels.

	existingClient := sshServer.clients[client.sessionID]

	// The mutex is released here, and reacquired below, so that stopping and
	// awaiting existingClient happens outside clientsMutex.
	sshServer.clientsMutex.Unlock()

	if existingClient != nil {

		// This case is expected to be common, and so logged at the lowest severity
		// level.
		log.WithTrace().Debug(
			"stopping existing client with duplicate session ID")

		existingClient.stop()

		// Block until the existingClient is fully terminated. This is necessary to
		// avoid this scenario:
		// - existingClient is invoking handshakeAPIRequestHandler
		// - sshServer.clients[client.sessionID] is updated to point to new client
		// - existingClient's handshakeAPIRequestHandler invokes
		//   SetClientHandshakeState but sets the handshake parameters for new
		//   client
		// - as a result, the new client handshake will fail (only a single handshake
		//   is permitted) and the new client server_tunnel log will contain an
		//   invalid mix of existing/new client fields
		//
		// Once existingClient.awaitStopped returns, all existingClient port
		// forwards and request handlers have terminated, so no API handler, either
		// tunneled web API or SSH API, will remain and it is safe to point
		// sshServer.clients[client.sessionID] to the new client.
		// Limitation: this scenario remains possible with _untunneled_ web API
		// requests.
		//
		// Blocking also ensures existingClient.releaseAuthorizations is invoked before
		// the new client attempts to submit the same authorizations.
		//
		// Perform blocking awaitStopped operation outside the
		// sshServer.clientsMutex mutex to avoid blocking all other clients for the
		// duration. We still expect and require that the stop process completes
		// rapidly, e.g., does not block on network I/O, allowing the new client
		// connection to proceed without delay.
		//
		// In addition, operations triggered by stop, and which must complete before
		// awaitStopped returns, will attempt to lock sshServer.clientsMutex,
		// including unregisterEstablishedClient.

		existingClient.awaitStopped()
	}

	sshServer.clientsMutex.Lock()
	defer sshServer.clientsMutex.Unlock()

	// existingClient's stop will have removed it from sshServer.clients via
	// unregisterEstablishedClient, so sshServer.clients[client.sessionID] should
	// be nil -- unless yet another client instance using the same sessionID has
	// connected in the meantime while awaiting existingClient stop. In this
	// case, it's not clear which is the most recent connection from the client,
	// so instead of this connection terminating more peers, it aborts.

	if sshServer.clients[client.sessionID] != nil {
		// As this is expected to be rare case, it's logged at a higher severity
		// level.
		log.WithTrace().Warning(
			"aborting new client with duplicate session ID")
		return false
	}

	sshServer.clients[client.sessionID] = client

	return true
}
   694  
   695  func (sshServer *sshServer) unregisterEstablishedClient(client *sshClient) {
   696  
   697  	sshServer.clientsMutex.Lock()
   698  
   699  	registeredClient := sshServer.clients[client.sessionID]
   700  
   701  	// registeredClient will differ from client when client is the existingClient
   702  	// terminated in registerEstablishedClient. In that case, registeredClient
   703  	// remains connected, and the sshServer.clients entry should be retained.
   704  	if registeredClient == client {
   705  		delete(sshServer.clients, client.sessionID)
   706  	}
   707  
   708  	sshServer.clientsMutex.Unlock()
   709  
   710  	client.stop()
   711  }
   712  
// UpstreamStats is the first result of GetLoadStats: stat values keyed by
// stat name. NOTE(review): presumably the upstream/egress-oriented stats
// built from zeroUpstreamStats in getLoadStats — confirm.
type UpstreamStats map[string]interface{}

// ProtocolStats is the second result of GetLoadStats: a map from tunnel
// protocol to stat values keyed by stat name. In getLoadStats, a map of
// this shape is built with an "ALL" entry plus one entry per configured
// tunnel protocol.
type ProtocolStats map[string]map[string]interface{}

// RegionStats is the third result of GetLoadStats: three levels of keys
// leading to stat values. NOTE(review): presumably keyed by region, then
// tunnel protocol, then stat name — confirm against getLoadStats.
type RegionStats map[string]map[string]map[string]interface{}
   716  
   717  func (sshServer *sshServer) getLoadStats() (
   718  	UpstreamStats, ProtocolStats, RegionStats) {
   719  
   720  	sshServer.clientsMutex.Lock()
   721  	defer sshServer.clientsMutex.Unlock()
   722  
   723  	// Explicitly populate with zeros to ensure 0 counts in log messages.
   724  
   725  	zeroClientStats := func() map[string]interface{} {
   726  		stats := make(map[string]interface{})
   727  		stats["accepted_clients"] = int64(0)
   728  		stats["established_clients"] = int64(0)
   729  		return stats
   730  	}
   731  
   732  	// Due to hot reload and changes to the underlying system configuration, the
   733  	// set of resolver IPs may change between getLoadStats calls, so this
   734  	// enumeration for zeroing is a best effort.
   735  	resolverIPs := sshServer.support.DNSResolver.GetAll()
   736  
   737  	// Fields which are primarily concerned with upstream/egress performance.
   738  	zeroUpstreamStats := func() map[string]interface{} {
   739  		stats := make(map[string]interface{})
   740  		stats["dialing_tcp_port_forwards"] = int64(0)
   741  		stats["tcp_port_forwards"] = int64(0)
   742  		stats["total_tcp_port_forwards"] = int64(0)
   743  		stats["udp_port_forwards"] = int64(0)
   744  		stats["total_udp_port_forwards"] = int64(0)
   745  		stats["tcp_port_forward_dialed_count"] = int64(0)
   746  		stats["tcp_port_forward_dialed_duration"] = int64(0)
   747  		stats["tcp_port_forward_failed_count"] = int64(0)
   748  		stats["tcp_port_forward_failed_duration"] = int64(0)
   749  		stats["tcp_port_forward_rejected_dialing_limit_count"] = int64(0)
   750  		stats["tcp_port_forward_rejected_disallowed_count"] = int64(0)
   751  		stats["udp_port_forward_rejected_disallowed_count"] = int64(0)
   752  		stats["tcp_ipv4_port_forward_dialed_count"] = int64(0)
   753  		stats["tcp_ipv4_port_forward_dialed_duration"] = int64(0)
   754  		stats["tcp_ipv4_port_forward_failed_count"] = int64(0)
   755  		stats["tcp_ipv4_port_forward_failed_duration"] = int64(0)
   756  		stats["tcp_ipv6_port_forward_dialed_count"] = int64(0)
   757  		stats["tcp_ipv6_port_forward_dialed_duration"] = int64(0)
   758  		stats["tcp_ipv6_port_forward_failed_count"] = int64(0)
   759  		stats["tcp_ipv6_port_forward_failed_duration"] = int64(0)
   760  
   761  		zeroDNSStats := func() map[string]int64 {
   762  			m := map[string]int64{"ALL": 0}
   763  			for _, resolverIP := range resolverIPs {
   764  				m[resolverIP.String()] = 0
   765  			}
   766  			return m
   767  		}
   768  
   769  		stats["dns_count"] = zeroDNSStats()
   770  		stats["dns_duration"] = zeroDNSStats()
   771  		stats["dns_failed_count"] = zeroDNSStats()
   772  		stats["dns_failed_duration"] = zeroDNSStats()
   773  		return stats
   774  	}
   775  
   776  	zeroProtocolStats := func() map[string]map[string]interface{} {
   777  		stats := make(map[string]map[string]interface{})
   778  		stats["ALL"] = zeroClientStats()
   779  		for tunnelProtocol := range sshServer.support.Config.TunnelProtocolPorts {
   780  			stats[tunnelProtocol] = zeroClientStats()
   781  		}
   782  		return stats
   783  	}
   784  
   785  	addInt64 := func(stats map[string]interface{}, name string, value int64) {
   786  		stats[name] = stats[name].(int64) + value
   787  	}
   788  
   789  	upstreamStats := zeroUpstreamStats()
   790  
   791  	// [<protocol or ALL>][<stat name>] -> count
   792  	protocolStats := zeroProtocolStats()
   793  
   794  	// [<region][<protocol or ALL>][<stat name>] -> count
   795  	regionStats := make(RegionStats)
   796  
   797  	// Note: as currently tracked/counted, each established client is also an accepted client
   798  
   799  	for tunnelProtocol, regionAcceptedClientCounts := range sshServer.acceptedClientCounts {
   800  		for region, acceptedClientCount := range regionAcceptedClientCounts {
   801  
   802  			if acceptedClientCount > 0 {
   803  				if regionStats[region] == nil {
   804  					regionStats[region] = zeroProtocolStats()
   805  				}
   806  
   807  				addInt64(protocolStats["ALL"], "accepted_clients", acceptedClientCount)
   808  				addInt64(protocolStats[tunnelProtocol], "accepted_clients", acceptedClientCount)
   809  
   810  				addInt64(regionStats[region]["ALL"], "accepted_clients", acceptedClientCount)
   811  				addInt64(regionStats[region][tunnelProtocol], "accepted_clients", acceptedClientCount)
   812  			}
   813  		}
   814  	}
   815  
   816  	for _, client := range sshServer.clients {
   817  
   818  		client.Lock()
   819  
   820  		tunnelProtocol := client.tunnelProtocol
   821  		region := client.geoIPData.Country
   822  
   823  		if regionStats[region] == nil {
   824  			regionStats[region] = zeroProtocolStats()
   825  		}
   826  
   827  		for _, stats := range []map[string]interface{}{
   828  			protocolStats["ALL"],
   829  			protocolStats[tunnelProtocol],
   830  			regionStats[region]["ALL"],
   831  			regionStats[region][tunnelProtocol]} {
   832  
   833  			addInt64(stats, "established_clients", 1)
   834  		}
   835  
   836  		// Note:
   837  		// - can't sum trafficState.peakConcurrentPortForwardCount to get a global peak
   838  		// - client.udpTrafficState.concurrentDialingPortForwardCount isn't meaningful
   839  
   840  		addInt64(upstreamStats, "dialing_tcp_port_forwards",
   841  			client.tcpTrafficState.concurrentDialingPortForwardCount)
   842  
   843  		addInt64(upstreamStats, "tcp_port_forwards",
   844  			client.tcpTrafficState.concurrentPortForwardCount)
   845  
   846  		addInt64(upstreamStats, "total_tcp_port_forwards",
   847  			client.tcpTrafficState.totalPortForwardCount)
   848  
   849  		addInt64(upstreamStats, "udp_port_forwards",
   850  			client.udpTrafficState.concurrentPortForwardCount)
   851  
   852  		addInt64(upstreamStats, "total_udp_port_forwards",
   853  			client.udpTrafficState.totalPortForwardCount)
   854  
   855  		addInt64(upstreamStats, "tcp_port_forward_dialed_count",
   856  			client.qualityMetrics.TCPPortForwardDialedCount)
   857  
   858  		addInt64(upstreamStats, "tcp_port_forward_dialed_duration",
   859  			int64(client.qualityMetrics.TCPPortForwardDialedDuration/time.Millisecond))
   860  
   861  		addInt64(upstreamStats, "tcp_port_forward_failed_count",
   862  			client.qualityMetrics.TCPPortForwardFailedCount)
   863  
   864  		addInt64(upstreamStats, "tcp_port_forward_failed_duration",
   865  			int64(client.qualityMetrics.TCPPortForwardFailedDuration/time.Millisecond))
   866  
   867  		addInt64(upstreamStats, "tcp_port_forward_rejected_dialing_limit_count",
   868  			client.qualityMetrics.TCPPortForwardRejectedDialingLimitCount)
   869  
   870  		addInt64(upstreamStats, "tcp_port_forward_rejected_disallowed_count",
   871  			client.qualityMetrics.TCPPortForwardRejectedDisallowedCount)
   872  
   873  		addInt64(upstreamStats, "udp_port_forward_rejected_disallowed_count",
   874  			client.qualityMetrics.UDPPortForwardRejectedDisallowedCount)
   875  
   876  		addInt64(upstreamStats, "tcp_ipv4_port_forward_dialed_count",
   877  			client.qualityMetrics.TCPIPv4PortForwardDialedCount)
   878  
   879  		addInt64(upstreamStats, "tcp_ipv4_port_forward_dialed_duration",
   880  			int64(client.qualityMetrics.TCPIPv4PortForwardDialedDuration/time.Millisecond))
   881  
   882  		addInt64(upstreamStats, "tcp_ipv4_port_forward_failed_count",
   883  			client.qualityMetrics.TCPIPv4PortForwardFailedCount)
   884  
   885  		addInt64(upstreamStats, "tcp_ipv4_port_forward_failed_duration",
   886  			int64(client.qualityMetrics.TCPIPv4PortForwardFailedDuration/time.Millisecond))
   887  
   888  		addInt64(upstreamStats, "tcp_ipv6_port_forward_dialed_count",
   889  			client.qualityMetrics.TCPIPv6PortForwardDialedCount)
   890  
   891  		addInt64(upstreamStats, "tcp_ipv6_port_forward_dialed_duration",
   892  			int64(client.qualityMetrics.TCPIPv6PortForwardDialedDuration/time.Millisecond))
   893  
   894  		addInt64(upstreamStats, "tcp_ipv6_port_forward_failed_count",
   895  			client.qualityMetrics.TCPIPv6PortForwardFailedCount)
   896  
   897  		addInt64(upstreamStats, "tcp_ipv6_port_forward_failed_duration",
   898  			int64(client.qualityMetrics.TCPIPv6PortForwardFailedDuration/time.Millisecond))
   899  
   900  		// DNS metrics limitations:
   901  		// - port forwards (sshClient.handleTCPChannel) don't know or log the resolver IP.
   902  		// - udpgw and packet tunnel transparent DNS use a heuristic to classify success/failure,
   903  		//   and there may be some delay before these code paths report DNS metrics.
   904  
   905  		// Every client.qualityMetrics DNS map has an "ALL" entry.
   906  
   907  		totalDNSCount := int64(0)
   908  		totalDNSFailedCount := int64(0)
   909  
   910  		for key, value := range client.qualityMetrics.DNSCount {
   911  			upstreamStats["dns_count"].(map[string]int64)[key] += value
   912  			totalDNSCount += value
   913  		}
   914  
   915  		for key, value := range client.qualityMetrics.DNSDuration {
   916  			upstreamStats["dns_duration"].(map[string]int64)[key] += int64(value / time.Millisecond)
   917  		}
   918  
   919  		for key, value := range client.qualityMetrics.DNSFailedCount {
   920  			upstreamStats["dns_failed_count"].(map[string]int64)[key] += value
   921  			totalDNSFailedCount += value
   922  		}
   923  
   924  		for key, value := range client.qualityMetrics.DNSFailedDuration {
   925  			upstreamStats["dns_failed_duration"].(map[string]int64)[key] += int64(value / time.Millisecond)
   926  		}
   927  
   928  		// Update client peak failure rate metrics, to be recorded in
   929  		// server_tunnel.
   930  		//
   931  		// Limitations:
   932  		//
   933  		// - This is a simple data sampling that doesn't require additional
   934  		//   timers or tracking logic. Since the rates are calculated on
   935  		//   getLoadStats events and using accumulated counts, these peaks
   936  		//   only represent the highest failure rate within a
   937  		//   Config.LoadMonitorPeriodSeconds non-sliding window. There is no
   938  		//   sample recorded for short tunnels with no overlapping
   939  		//   getLoadStats event.
   940  		//
   941  		// - There is no minimum sample window, as a getLoadStats event may
   942  		//   occur immediately after a client first connects. This may be
   943  		//   compensated for by adjusting
   944  		//   Config.PeakUpstreamFailureRateMinimumSampleSize, so as to only
   945  		//   consider failure rates with a larger number of samples.
   946  		//
   947  		// - Non-UDP "failures" are not currently tracked.
   948  
   949  		minimumSampleSize := int64(sshServer.support.Config.peakUpstreamFailureRateMinimumSampleSize)
   950  
   951  		sampleSize := client.qualityMetrics.TCPPortForwardDialedCount +
   952  			client.qualityMetrics.TCPPortForwardFailedCount
   953  
   954  		if sampleSize >= minimumSampleSize {
   955  
   956  			TCPPortForwardFailureRate := float64(client.qualityMetrics.TCPPortForwardFailedCount) /
   957  				float64(sampleSize)
   958  
   959  			if client.peakMetrics.TCPPortForwardFailureRate == nil {
   960  
   961  				client.peakMetrics.TCPPortForwardFailureRate = new(float64)
   962  				*client.peakMetrics.TCPPortForwardFailureRate = TCPPortForwardFailureRate
   963  				client.peakMetrics.TCPPortForwardFailureRateSampleSize = new(int64)
   964  				*client.peakMetrics.TCPPortForwardFailureRateSampleSize = sampleSize
   965  
   966  			} else if *client.peakMetrics.TCPPortForwardFailureRate < TCPPortForwardFailureRate {
   967  
   968  				*client.peakMetrics.TCPPortForwardFailureRate = TCPPortForwardFailureRate
   969  				*client.peakMetrics.TCPPortForwardFailureRateSampleSize = sampleSize
   970  			}
   971  		}
   972  
   973  		sampleSize = totalDNSCount + totalDNSFailedCount
   974  
   975  		if sampleSize >= minimumSampleSize {
   976  
   977  			DNSFailureRate := float64(totalDNSFailedCount) / float64(sampleSize)
   978  
   979  			if client.peakMetrics.DNSFailureRate == nil {
   980  
   981  				client.peakMetrics.DNSFailureRate = new(float64)
   982  				*client.peakMetrics.DNSFailureRate = DNSFailureRate
   983  				client.peakMetrics.DNSFailureRateSampleSize = new(int64)
   984  				*client.peakMetrics.DNSFailureRateSampleSize = sampleSize
   985  
   986  			} else if *client.peakMetrics.DNSFailureRate < DNSFailureRate {
   987  
   988  				*client.peakMetrics.DNSFailureRate = DNSFailureRate
   989  				*client.peakMetrics.DNSFailureRateSampleSize = sampleSize
   990  			}
   991  		}
   992  
   993  		// Reset quality metrics counters
   994  
   995  		client.qualityMetrics.reset()
   996  
   997  		client.Unlock()
   998  	}
   999  
  1000  	for _, client := range sshServer.clients {
  1001  
  1002  		client.Lock()
  1003  
  1004  		// Update client peak proximate (same region) concurrently connected
  1005  		// (other clients) client metrics, to be recorded in server_tunnel.
  1006  		// This operation requires a second loop over sshServer.clients since
  1007  		// established_clients is calculated in the first loop.
  1008  		//
  1009  		// Limitations:
  1010  		//
  1011  		// - This is an approximation, not a true peak, as it only samples
  1012  		//   data every Config.LoadMonitorPeriodSeconds period. There is no
  1013  		//   sample recorded for short tunnels with no overlapping
  1014  		//   getLoadStats event.
  1015  		//
  1016  		// - The "-1" calculation counts all but the current client as other
  1017  		//   clients; it can be the case that the same client has a dangling
  1018  		//   accepted connection that has yet to time-out server side. Due to
  1019  		//   NAT, we can't determine if the client is the same based on
  1020  		//   network address. For established clients,
  1021  		//   registerEstablishedClient ensures that any previous connection
  1022  		//   is first terminated, although this is only for the same
  1023  		//   session_id. Concurrent proximate clients may be considered an
  1024  		//   exact number of other _network connections_, even from the same
  1025  		//   client.
  1026  
  1027  		region := client.geoIPData.Country
  1028  		stats := regionStats[region]["ALL"]
  1029  
  1030  		n := stats["accepted_clients"].(int64) - 1
  1031  		if n >= 0 {
  1032  			if client.peakMetrics.concurrentProximateAcceptedClients == nil {
  1033  
  1034  				client.peakMetrics.concurrentProximateAcceptedClients = new(int64)
  1035  				*client.peakMetrics.concurrentProximateAcceptedClients = n
  1036  
  1037  			} else if *client.peakMetrics.concurrentProximateAcceptedClients < n {
  1038  
  1039  				*client.peakMetrics.concurrentProximateAcceptedClients = n
  1040  			}
  1041  		}
  1042  
  1043  		n = stats["established_clients"].(int64) - 1
  1044  		if n >= 0 {
  1045  			if client.peakMetrics.concurrentProximateEstablishedClients == nil {
  1046  
  1047  				client.peakMetrics.concurrentProximateEstablishedClients = new(int64)
  1048  				*client.peakMetrics.concurrentProximateEstablishedClients = n
  1049  
  1050  			} else if *client.peakMetrics.concurrentProximateEstablishedClients < n {
  1051  
  1052  				*client.peakMetrics.concurrentProximateEstablishedClients = n
  1053  			}
  1054  		}
  1055  
  1056  		client.Unlock()
  1057  	}
  1058  
  1059  	return upstreamStats, protocolStats, regionStats
  1060  }
  1061  
  1062  func (sshServer *sshServer) getEstablishedClientCount() int {
  1063  	sshServer.clientsMutex.Lock()
  1064  	defer sshServer.clientsMutex.Unlock()
  1065  	establishedClients := len(sshServer.clients)
  1066  	return establishedClients
  1067  }
  1068  
  1069  func (sshServer *sshServer) resetAllClientTrafficRules() {
  1070  
  1071  	sshServer.clientsMutex.Lock()
  1072  	clients := make(map[string]*sshClient)
  1073  	for sessionID, client := range sshServer.clients {
  1074  		clients[sessionID] = client
  1075  	}
  1076  	sshServer.clientsMutex.Unlock()
  1077  
  1078  	for _, client := range clients {
  1079  		client.setTrafficRules()
  1080  	}
  1081  }
  1082  
  1083  func (sshServer *sshServer) resetAllClientOSLConfigs() {
  1084  
  1085  	// Flush cached seed state. This has the same effect
  1086  	// and same limitations as calling setOSLConfig for
  1087  	// currently connected clients -- all progress is lost.
  1088  	sshServer.oslSessionCacheMutex.Lock()
  1089  	sshServer.oslSessionCache.Flush()
  1090  	sshServer.oslSessionCacheMutex.Unlock()
  1091  
  1092  	sshServer.clientsMutex.Lock()
  1093  	clients := make(map[string]*sshClient)
  1094  	for sessionID, client := range sshServer.clients {
  1095  		clients[sessionID] = client
  1096  	}
  1097  	sshServer.clientsMutex.Unlock()
  1098  
  1099  	for _, client := range clients {
  1100  		client.setOSLConfig()
  1101  	}
  1102  }
  1103  
  1104  func (sshServer *sshServer) setClientHandshakeState(
  1105  	sessionID string,
  1106  	state handshakeState,
  1107  	authorizations []string) (*handshakeStateInfo, error) {
  1108  
  1109  	sshServer.clientsMutex.Lock()
  1110  	client := sshServer.clients[sessionID]
  1111  	sshServer.clientsMutex.Unlock()
  1112  
  1113  	if client == nil {
  1114  		return nil, errors.TraceNew("unknown session ID")
  1115  	}
  1116  
  1117  	handshakeStateInfo, err := client.setHandshakeState(
  1118  		state, authorizations)
  1119  	if err != nil {
  1120  		return nil, errors.Trace(err)
  1121  	}
  1122  
  1123  	return handshakeStateInfo, nil
  1124  }
  1125  
  1126  func (sshServer *sshServer) getClientHandshaked(
  1127  	sessionID string) (bool, bool, error) {
  1128  
  1129  	sshServer.clientsMutex.Lock()
  1130  	client := sshServer.clients[sessionID]
  1131  	sshServer.clientsMutex.Unlock()
  1132  
  1133  	if client == nil {
  1134  		return false, false, errors.TraceNew("unknown session ID")
  1135  	}
  1136  
  1137  	completed, exhausted := client.getHandshaked()
  1138  
  1139  	return completed, exhausted, nil
  1140  }
  1141  
  1142  func (sshServer *sshServer) getClientDisableDiscovery(
  1143  	sessionID string) (bool, error) {
  1144  
  1145  	sshServer.clientsMutex.Lock()
  1146  	client := sshServer.clients[sessionID]
  1147  	sshServer.clientsMutex.Unlock()
  1148  
  1149  	if client == nil {
  1150  		return false, errors.TraceNew("unknown session ID")
  1151  	}
  1152  
  1153  	return client.getDisableDiscovery(), nil
  1154  }
  1155  
  1156  func (sshServer *sshServer) updateClientAPIParameters(
  1157  	sessionID string,
  1158  	apiParams common.APIParameters) error {
  1159  
  1160  	sshServer.clientsMutex.Lock()
  1161  	client := sshServer.clients[sessionID]
  1162  	sshServer.clientsMutex.Unlock()
  1163  
  1164  	if client == nil {
  1165  		return errors.TraceNew("unknown session ID")
  1166  	}
  1167  
  1168  	client.updateAPIParameters(apiParams)
  1169  
  1170  	return nil
  1171  }
  1172  
  1173  func (sshServer *sshServer) revokeClientAuthorizations(sessionID string) {
  1174  	sshServer.clientsMutex.Lock()
  1175  	client := sshServer.clients[sessionID]
  1176  	sshServer.clientsMutex.Unlock()
  1177  
  1178  	if client == nil {
  1179  		return
  1180  	}
  1181  
  1182  	// sshClient.handshakeState.authorizedAccessTypes is not cleared. Clearing
  1183  	// authorizedAccessTypes may cause sshClient.logTunnel to fail to log
  1184  	// access types. As the revocation may be due to legitimate use of an
  1185  	// authorization in multiple sessions by a single client, useful metrics
  1186  	// would be lost.
  1187  
  1188  	client.Lock()
  1189  	client.handshakeState.authorizationsRevoked = true
  1190  	client.Unlock()
  1191  
  1192  	// Select and apply new traffic rules, as filtered by the client's new
  1193  	// authorization state.
  1194  
  1195  	client.setTrafficRules()
  1196  }
  1197  
  1198  func (sshServer *sshServer) acceptClientDomainBytes(
  1199  	sessionID string) (bool, error) {
  1200  
  1201  	sshServer.clientsMutex.Lock()
  1202  	client := sshServer.clients[sessionID]
  1203  	sshServer.clientsMutex.Unlock()
  1204  
  1205  	if client == nil {
  1206  		return false, errors.TraceNew("unknown session ID")
  1207  	}
  1208  
  1209  	return client.acceptDomainBytes(), nil
  1210  }
  1211  
  1212  func (sshServer *sshServer) stopClients() {
  1213  
  1214  	sshServer.clientsMutex.Lock()
  1215  	sshServer.stoppingClients = true
  1216  	clients := sshServer.clients
  1217  	sshServer.clients = make(map[string]*sshClient)
  1218  	sshServer.clientsMutex.Unlock()
  1219  
  1220  	for _, client := range clients {
  1221  		client.stop()
  1222  	}
  1223  }
  1224  
  1225  func (sshServer *sshServer) handleClient(
  1226  	sshListener *sshListener, tunnelProtocol string, clientConn net.Conn) {
  1227  
  1228  	// Calling clientConn.RemoteAddr at this point, before any Read calls,
  1229  	// satisfies the constraint documented in tapdance.Listen.
  1230  
  1231  	clientAddr := clientConn.RemoteAddr()
  1232  
  1233  	// Check if there were irregularities during the network connection
  1234  	// establishment. When present, log and then behave as Obfuscated SSH does
  1235  	// when the client fails to provide a valid seed message.
  1236  	//
  1237  	// One concrete irregular case is failure to send a PROXY protocol header for
  1238  	// TAPDANCE-OSSH.
  1239  
  1240  	if indicator, ok := clientConn.(common.IrregularIndicator); ok {
  1241  
  1242  		tunnelErr := indicator.IrregularTunnelError()
  1243  
  1244  		if tunnelErr != nil {
  1245  
  1246  			logIrregularTunnel(
  1247  				sshServer.support,
  1248  				sshListener.tunnelProtocol,
  1249  				sshListener.port,
  1250  				common.IPAddressFromAddr(clientAddr),
  1251  				errors.Trace(tunnelErr),
  1252  				nil)
  1253  
  1254  			var afterFunc *time.Timer
  1255  			if sshServer.support.Config.sshHandshakeTimeout > 0 {
  1256  				afterFunc = time.AfterFunc(sshServer.support.Config.sshHandshakeTimeout, func() {
  1257  					clientConn.Close()
  1258  				})
  1259  			}
  1260  			io.Copy(ioutil.Discard, clientConn)
  1261  			clientConn.Close()
  1262  			afterFunc.Stop()
  1263  
  1264  			return
  1265  		}
  1266  	}
  1267  
  1268  	// Get any packet manipulation values from GetAppliedSpecName as soon as
  1269  	// possible due to the expiring TTL.
  1270  
  1271  	serverPacketManipulation := ""
  1272  	replayedServerPacketManipulation := false
  1273  
  1274  	if sshServer.support.Config.RunPacketManipulator &&
  1275  		protocol.TunnelProtocolMayUseServerPacketManipulation(tunnelProtocol) {
  1276  
  1277  		// A meekConn has synthetic address values, including the original client
  1278  		// address in cases where the client uses an upstream proxy to connect to
  1279  		// Psiphon. For meekConn, and any other conn implementing
  1280  		// UnderlyingTCPAddrSource, get the underlying TCP connection addresses.
  1281  		//
  1282  		// Limitation: a meek tunnel may consist of several TCP connections. The
  1283  		// server_packet_manipulation metric will reflect the packet manipulation
  1284  		// applied to the _first_ TCP connection only.
  1285  
  1286  		var localAddr, remoteAddr *net.TCPAddr
  1287  		var ok bool
  1288  		underlying, ok := clientConn.(common.UnderlyingTCPAddrSource)
  1289  		if ok {
  1290  			localAddr, remoteAddr, ok = underlying.GetUnderlyingTCPAddrs()
  1291  		} else {
  1292  			localAddr, ok = clientConn.LocalAddr().(*net.TCPAddr)
  1293  			if ok {
  1294  				remoteAddr, ok = clientConn.RemoteAddr().(*net.TCPAddr)
  1295  			}
  1296  		}
  1297  
  1298  		if ok {
  1299  			specName, extraData, err := sshServer.support.PacketManipulator.
  1300  				GetAppliedSpecName(localAddr, remoteAddr)
  1301  			if err == nil {
  1302  				serverPacketManipulation = specName
  1303  				replayedServerPacketManipulation, _ = extraData.(bool)
  1304  			}
  1305  		}
  1306  	}
  1307  
  1308  	geoIPData := sshServer.support.GeoIPService.Lookup(
  1309  		common.IPAddressFromAddr(clientAddr))
  1310  
  1311  	sshServer.registerAcceptedClient(tunnelProtocol, geoIPData.Country)
  1312  	defer sshServer.unregisterAcceptedClient(tunnelProtocol, geoIPData.Country)
  1313  
  1314  	// When configured, enforce a cap on the number of concurrent SSH
  1315  	// handshakes. This limits load spikes on busy servers when many clients
  1316  	// attempt to connect at once. Wait a short time, SSH_BEGIN_HANDSHAKE_TIMEOUT,
  1317  	// to acquire; waiting will avoid immediately creating more load on another
  1318  	// server in the network when the client tries a new candidate. Disconnect the
  1319  	// client when that wait time is exceeded.
  1320  	//
  1321  	// This mechanism limits memory allocations and CPU usage associated with the
  1322  	// SSH handshake. At this point, new direct TCP connections or new meek
  1323  	// connections, with associated resource usage, are already established. Those
  1324  	// connections are expected to be rate or load limited using other mechanisms.
  1325  	//
  1326  	// TODO:
  1327  	//
  1328  	// - deduct time spent acquiring the semaphore from SSH_HANDSHAKE_TIMEOUT in
  1329  	//   sshClient.run, since the client is also applying an SSH handshake timeout
  1330  	//   and won't exclude time spent waiting.
  1331  	// - each call to sshServer.handleClient (in sshServer.runListener) is invoked
  1332  	//   in its own goroutine, but shutdown doesn't synchronously await these
  1333  	//   goroutnes. Once this is synchronizes, the following context.WithTimeout
  1334  	//   should use an sshServer parent context to ensure blocking acquires
  1335  	//   interrupt immediately upon shutdown.
  1336  
  1337  	var onSSHHandshakeFinished func()
  1338  	if sshServer.support.Config.MaxConcurrentSSHHandshakes > 0 {
  1339  
  1340  		ctx, cancelFunc := context.WithTimeout(
  1341  			context.Background(),
  1342  			sshServer.support.Config.sshBeginHandshakeTimeout)
  1343  		defer cancelFunc()
  1344  
  1345  		err := sshServer.concurrentSSHHandshakes.Acquire(ctx, 1)
  1346  		if err != nil {
  1347  			clientConn.Close()
  1348  			// This is a debug log as the only possible error is context timeout.
  1349  			log.WithTraceFields(LogFields{"error": err}).Debug(
  1350  				"acquire SSH handshake semaphore failed")
  1351  			return
  1352  		}
  1353  
  1354  		onSSHHandshakeFinished = func() {
  1355  			sshServer.concurrentSSHHandshakes.Release(1)
  1356  		}
  1357  	}
  1358  
  1359  	sshClient := newSshClient(
  1360  		sshServer,
  1361  		sshListener,
  1362  		tunnelProtocol,
  1363  		serverPacketManipulation,
  1364  		replayedServerPacketManipulation,
  1365  		clientAddr,
  1366  		geoIPData)
  1367  
  1368  	// sshClient.run _must_ call onSSHHandshakeFinished to release the semaphore:
  1369  	// in any error case; or, as soon as the SSH handshake phase has successfully
  1370  	// completed.
  1371  
  1372  	sshClient.run(clientConn, onSSHHandshakeFinished)
  1373  }
  1374  
  1375  func (sshServer *sshServer) monitorPortForwardDialError(err error) {
  1376  
  1377  	// "err" is the error returned from a failed TCP or UDP port
  1378  	// forward dial. Certain system error codes indicate low resource
  1379  	// conditions: insufficient file descriptors, ephemeral ports, or
  1380  	// memory. For these cases, log an alert.
  1381  
  1382  	// TODO: also temporarily suspend new clients
  1383  
  1384  	// Note: don't log net.OpError.Error() as the full error string
  1385  	// may contain client destination addresses.
  1386  
  1387  	opErr, ok := err.(*net.OpError)
  1388  	if ok {
  1389  		if opErr.Err == syscall.EADDRNOTAVAIL ||
  1390  			opErr.Err == syscall.EAGAIN ||
  1391  			opErr.Err == syscall.ENOMEM ||
  1392  			opErr.Err == syscall.EMFILE ||
  1393  			opErr.Err == syscall.ENFILE {
  1394  
  1395  			log.WithTraceFields(
  1396  				LogFields{"error": opErr.Err}).Error(
  1397  				"port forward dial failed due to unavailable resource")
  1398  		}
  1399  	}
  1400  }
  1401  
// sshClient holds the server-side state of a single client tunnel. This
// includes:
//   - the SSH connection;
//   - the client's tunnel protocol, address and GeoIP data;
//   - its session ID, which is the key used in sshServer.clients;
//   - handshake state and traffic rules;
//   - separate TCP and UDP traffic state;
//   - quality metrics (reported and reset by sshServer.getLoadStats) and peak
//     metrics (sampled by sshServer.getLoadStats).
//
// The embedded Mutex guards mutable client state. For example,
// sshServer.getLoadStats holds it while reading traffic state and quality
// metrics and updating peakMetrics, and sshServer.revokeClientAuthorizations
// holds it while setting handshakeState.authorizationsRevoked.
//
// NOTE(review): the int64 fields here and in the trafficState values may be
// accessed with sync/atomic elsewhere in this file, which on 32-bit platforms
// depends on 64-bit alignment. Confirm before reordering fields.
type sshClient struct {
	sync.Mutex
	sshServer                            *sshServer
	sshListener                          *sshListener
	tunnelProtocol                       string
	sshConn                              ssh.Conn
	throttledConn                        *common.ThrottledConn
	serverPacketManipulation             string
	replayedServerPacketManipulation     bool
	clientAddr                           net.Addr
	geoIPData                            GeoIPData
	sessionID                            string
	isFirstTunnelInSession               bool
	supportsServerRequests               bool
	handshakeState                       handshakeState
	udpgwChannelHandler                  *udpgwPortForwardMultiplexer
	totalUdpgwChannelCount               int
	packetTunnelChannel                  ssh.Channel
	totalPacketTunnelChannelCount        int
	trafficRules                         TrafficRules
	tcpTrafficState                      trafficState
	udpTrafficState                      trafficState
	qualityMetrics                       *qualityMetrics
	tcpPortForwardLRU                    *common.LRUConns
	oslClientSeedState                   *osl.ClientSeedState
	signalIssueSLOKs                     chan struct{}
	runCtx                               context.Context
	stopRunning                          context.CancelFunc
	stopped                              chan struct{}
	tcpPortForwardDialingAvailableSignal context.CancelFunc
	releaseAuthorizations                func()
	stopTimer                            *time.Timer
	preHandshakeRandomStreamMetrics      randomStreamMetrics
	postHandshakeRandomStreamMetrics     randomStreamMetrics
	sendAlertRequests                    chan protocol.AlertRequest
	sentAlertRequests                    map[string]bool
	peakMetrics                          peakMetrics
	destinationBytesMetricsASN           string
	tcpDestinationBytesMetrics           destinationBytesMetrics
	udpDestinationBytesMetrics           destinationBytesMetrics
}
  1443  
// trafficState holds port forward counters for one traffic type. Each
// sshClient has one for TCP (tcpTrafficState) and one for UDP
// (udpTrafficState).
//
// sshServer.getLoadStats sums concurrentDialingPortForwardCount (TCP only),
// concurrentPortForwardCount and totalPortForwardCount across clients. The
// peak counts are per client and can't be summed into a global peak. For UDP,
// concurrentDialingPortForwardCount isn't meaningful.
//
// NOTE(review): availablePortForwardCond presumably signals waiters when port
// forward capacity frees up. Confirm against its users.
type trafficState struct {
	bytesUp                               int64
	bytesDown                             int64
	concurrentDialingPortForwardCount     int64
	peakConcurrentDialingPortForwardCount int64
	concurrentPortForwardCount            int64
	peakConcurrentPortForwardCount        int64
	totalPortForwardCount                 int64
	availablePortForwardCond              *sync.Cond
}
  1454  
// randomStreamMetrics aggregates random stream counters. sshClient keeps
// separate instances for the pre-handshake and post-handshake phases.
//
// NOTE(review): presumably upstreamBytes/downstreamBytes are requested byte
// counts and receivedUpstreamBytes/sentDownstreamBytes are the bytes actually
// transferred. Confirm against the code that updates these fields.
type randomStreamMetrics struct {
	count                 int64
	upstreamBytes         int64
	receivedUpstreamBytes int64
	downstreamBytes       int64
	sentDownstreamBytes   int64
}
  1462  
// peakMetrics records per-client peak values, sampled on each
// sshServer.getLoadStats call, to be recorded in server_tunnel. Each pointer
// stays nil until the first sample is taken.
//
//   - concurrentProximateAcceptedClients/EstablishedClients: the peak number
//     of other accepted/established clients in the same region.
//   - TCPPortForwardFailureRate and DNSFailureRate: the peak failure rate
//     observed in one sampling period. The corresponding *SampleSize field
//     holds the sample size of that peak. Only periods whose sample size
//     meets the configured minimum are considered.
type peakMetrics struct {
	concurrentProximateAcceptedClients    *int64
	concurrentProximateEstablishedClients *int64
	TCPPortForwardFailureRate             *float64
	TCPPortForwardFailureRateSampleSize   *int64
	DNSFailureRate                        *float64
	DNSFailureRateSampleSize              *int64
}
  1471  
// qualityMetrics records upstream TCP dial attempts and
// elapsed time. Elapsed time includes the full TCP handshake
// and, in aggregate, is a measure of the quality of the
// upstream link. These stats are recorded by each sshClient
// and then reported and reset in sshServer.getLoadStats().
//
// Besides the aggregate TCP dial counts and durations, it tracks:
//   - per-IP-version (IPv4/IPv6) dial results;
//   - counts of rejected TCP/UDP port forwards;
//   - DNS request counts and durations.
//
// Durations are reported by getLoadStats in milliseconds. The DNS maps must be
// allocated before use (newQualityMetrics does this), and reset empties them
// without reallocating. Each DNS map has an "ALL" entry. Other keys are
// presumably resolver IP strings, since getLoadStats zeroes an entry per
// resolverIP.String(); confirm against the writers.
type qualityMetrics struct {
	TCPPortForwardDialedCount               int64
	TCPPortForwardDialedDuration            time.Duration
	TCPPortForwardFailedCount               int64
	TCPPortForwardFailedDuration            time.Duration
	TCPPortForwardRejectedDialingLimitCount int64
	TCPPortForwardRejectedDisallowedCount   int64
	UDPPortForwardRejectedDisallowedCount   int64
	TCPIPv4PortForwardDialedCount           int64
	TCPIPv4PortForwardDialedDuration        time.Duration
	TCPIPv4PortForwardFailedCount           int64
	TCPIPv4PortForwardFailedDuration        time.Duration
	TCPIPv6PortForwardDialedCount           int64
	TCPIPv6PortForwardDialedDuration        time.Duration
	TCPIPv6PortForwardFailedCount           int64
	TCPIPv6PortForwardFailedDuration        time.Duration
	DNSCount                                map[string]int64
	DNSDuration                             map[string]time.Duration
	DNSFailedCount                          map[string]int64
	DNSFailedDuration                       map[string]time.Duration
}
  1498  
  1499  func newQualityMetrics() *qualityMetrics {
  1500  	return &qualityMetrics{
  1501  		DNSCount:          make(map[string]int64),
  1502  		DNSDuration:       make(map[string]time.Duration),
  1503  		DNSFailedCount:    make(map[string]int64),
  1504  		DNSFailedDuration: make(map[string]time.Duration),
  1505  	}
  1506  }
  1507  
  1508  func (q *qualityMetrics) reset() {
  1509  
  1510  	q.TCPPortForwardDialedCount = 0
  1511  	q.TCPPortForwardDialedDuration = 0
  1512  	q.TCPPortForwardFailedCount = 0
  1513  	q.TCPPortForwardFailedDuration = 0
  1514  	q.TCPPortForwardRejectedDialingLimitCount = 0
  1515  	q.TCPPortForwardRejectedDisallowedCount = 0
  1516  
  1517  	q.UDPPortForwardRejectedDisallowedCount = 0
  1518  
  1519  	q.TCPIPv4PortForwardDialedCount = 0
  1520  	q.TCPIPv4PortForwardDialedDuration = 0
  1521  	q.TCPIPv4PortForwardFailedCount = 0
  1522  	q.TCPIPv4PortForwardFailedDuration = 0
  1523  
  1524  	q.TCPIPv6PortForwardDialedCount = 0
  1525  	q.TCPIPv6PortForwardDialedDuration = 0
  1526  	q.TCPIPv6PortForwardFailedCount = 0
  1527  	q.TCPIPv6PortForwardFailedDuration = 0
  1528  
  1529  	// Retain existing maps to avoid memory churn. The Go compiler optimizes map
  1530  	// clearing operations of the following form.
  1531  
  1532  	for k := range q.DNSCount {
  1533  		delete(q.DNSCount, k)
  1534  	}
  1535  	for k := range q.DNSDuration {
  1536  		delete(q.DNSDuration, k)
  1537  	}
  1538  	for k := range q.DNSFailedCount {
  1539  		delete(q.DNSFailedCount, k)
  1540  	}
  1541  	for k := range q.DNSFailedDuration {
  1542  		delete(q.DNSFailedDuration, k)
  1543  	}
  1544  }
  1545  
// handshakeStateInfo holds the authorization and bandwidth values associated
// with a handshake: the active authorization IDs, the authorized access
// types, and upstream/downstream bytes-per-second values.
//
// NOTE(review): the producer and consumers of this struct are outside this
// chunk. Presumably it is the result passed back after an API handshake —
// confirm at the call sites.
type handshakeStateInfo struct {
	activeAuthorizationIDs   []string
	authorizedAccessTypes    []string
	upstreamBytesPerSecond   int64
	downstreamBytesPerSecond int64
}
  1552  
// handshakeState is per-client state captured from the client's API
// handshake: its API protocol and parameters, authorizations, a domain bytes
// checksum, the established tunnels count, and an optional split tunnel
// region lookup (see newSplitTunnelLookup).
//
// NOTE(review): completed appears to be what sshClient.getHandshaked reports
// (run treats "not completed" as a failure to complete the API handshake);
// the getter itself is outside this chunk — confirm.
type handshakeState struct {
	completed               bool
	apiProtocol             string
	apiParams               common.APIParameters
	activeAuthorizationIDs  []string
	authorizedAccessTypes   []string
	authorizationsRevoked   bool
	domainBytesChecksum     []byte
	establishedTunnelsCount int
	splitTunnelLookup       *splitTunnelLookup
}
  1564  
  1565  type destinationBytesMetrics struct {
  1566  	bytesUp   int64
  1567  	bytesDown int64
  1568  }
  1569  
  1570  func (d *destinationBytesMetrics) UpdateProgress(
  1571  	downstreamBytes, upstreamBytes, _ int64) {
  1572  
  1573  	// Concurrency: UpdateProgress may be called without holding the sshClient
  1574  	// lock; all accesses to bytesUp/bytesDown must use atomic operations.
  1575  
  1576  	atomic.AddInt64(&d.bytesUp, upstreamBytes)
  1577  	atomic.AddInt64(&d.bytesDown, downstreamBytes)
  1578  }
  1579  
  1580  func (d *destinationBytesMetrics) getBytesUp() int64 {
  1581  	return atomic.LoadInt64(&d.bytesUp)
  1582  }
  1583  
  1584  func (d *destinationBytesMetrics) getBytesDown() int64 {
  1585  	return atomic.LoadInt64(&d.bytesDown)
  1586  }
  1587  
// splitTunnelLookup is a set of region strings that supports membership
// tests via lookup. newSplitTunnelLookup populates exactly one field:
// regionsLookup (map keys) when the list has at least stringLookupThreshold
// entries, and regions (scanned linearly) otherwise. lookup uses the map
// whenever it is non-nil.
type splitTunnelLookup struct {
	regions       []string
	regionsLookup map[string]bool
}
  1592  
  1593  func newSplitTunnelLookup(
  1594  	ownRegion string,
  1595  	otherRegions []string) (*splitTunnelLookup, error) {
  1596  
  1597  	length := len(otherRegions)
  1598  	if ownRegion != "" {
  1599  		length += 1
  1600  	}
  1601  
  1602  	// This length check is a sanity check and prevents clients shipping
  1603  	// excessively long lists which could impact performance.
  1604  	if length > 250 {
  1605  		return nil, errors.Tracef("too many regions: %d", length)
  1606  	}
  1607  
  1608  	// Create map lookups for lists where the number of values to compare
  1609  	// against exceeds a threshold where benchmarks show maps are faster than
  1610  	// looping through a slice. Otherwise use a slice for lookups. In both
  1611  	// cases, the input slice is no longer referenced.
  1612  
  1613  	if length >= stringLookupThreshold {
  1614  		regionsLookup := make(map[string]bool)
  1615  		if ownRegion != "" {
  1616  			regionsLookup[ownRegion] = true
  1617  		}
  1618  		for _, region := range otherRegions {
  1619  			regionsLookup[region] = true
  1620  		}
  1621  		return &splitTunnelLookup{
  1622  			regionsLookup: regionsLookup,
  1623  		}, nil
  1624  	} else {
  1625  		regions := []string{}
  1626  		if ownRegion != "" && !common.Contains(otherRegions, ownRegion) {
  1627  			regions = append(regions, ownRegion)
  1628  		}
  1629  		// TODO: check for other duplicate regions?
  1630  		regions = append(regions, otherRegions...)
  1631  		return &splitTunnelLookup{
  1632  			regions: regions,
  1633  		}, nil
  1634  	}
  1635  }
  1636  
  1637  func (lookup *splitTunnelLookup) lookup(region string) bool {
  1638  	if lookup.regionsLookup != nil {
  1639  		return lookup.regionsLookup[region]
  1640  	} else {
  1641  		return common.Contains(lookup.regions, region)
  1642  	}
  1643  }
  1644  
  1645  func newSshClient(
  1646  	sshServer *sshServer,
  1647  	sshListener *sshListener,
  1648  	tunnelProtocol string,
  1649  	serverPacketManipulation string,
  1650  	replayedServerPacketManipulation bool,
  1651  	clientAddr net.Addr,
  1652  	geoIPData GeoIPData) *sshClient {
  1653  
  1654  	runCtx, stopRunning := context.WithCancel(context.Background())
  1655  
  1656  	// isFirstTunnelInSession is defaulted to true so that the pre-handshake
  1657  	// traffic rules won't apply UnthrottleFirstTunnelOnly and negate any
  1658  	// unthrottled bytes during the initial protocol negotiation.
  1659  
  1660  	client := &sshClient{
  1661  		sshServer:                        sshServer,
  1662  		sshListener:                      sshListener,
  1663  		tunnelProtocol:                   tunnelProtocol,
  1664  		serverPacketManipulation:         serverPacketManipulation,
  1665  		replayedServerPacketManipulation: replayedServerPacketManipulation,
  1666  		clientAddr:                       clientAddr,
  1667  		geoIPData:                        geoIPData,
  1668  		isFirstTunnelInSession:           true,
  1669  		qualityMetrics:                   newQualityMetrics(),
  1670  		tcpPortForwardLRU:                common.NewLRUConns(),
  1671  		signalIssueSLOKs:                 make(chan struct{}, 1),
  1672  		runCtx:                           runCtx,
  1673  		stopRunning:                      stopRunning,
  1674  		stopped:                          make(chan struct{}),
  1675  		sendAlertRequests:                make(chan protocol.AlertRequest, ALERT_REQUEST_QUEUE_BUFFER_SIZE),
  1676  		sentAlertRequests:                make(map[string]bool),
  1677  	}
  1678  
  1679  	client.tcpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
  1680  	client.udpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex))
  1681  
  1682  	return client
  1683  }
  1684  
  1685  func (sshClient *sshClient) run(
  1686  	baseConn net.Conn, onSSHHandshakeFinished func()) {
  1687  
  1688  	// When run returns, the client has fully stopped, with all SSH state torn
  1689  	// down and no port forwards or API requests in progress.
  1690  	defer close(sshClient.stopped)
  1691  
  1692  	// onSSHHandshakeFinished must be called even if the SSH handshake is aborted.
  1693  	defer func() {
  1694  		if onSSHHandshakeFinished != nil {
  1695  			onSSHHandshakeFinished()
  1696  		}
  1697  	}()
  1698  
  1699  	// Set initial traffic rules, pre-handshake, based on currently known info.
  1700  	sshClient.setTrafficRules()
  1701  
  1702  	conn := baseConn
  1703  
  1704  	// Wrap the base client connection with an ActivityMonitoredConn which will
  1705  	// terminate the connection if no data is received before the deadline. This
  1706  	// timeout is in effect for the entire duration of the SSH connection. Clients
  1707  	// must actively use the connection or send SSH keep alive requests to keep
  1708  	// the connection active. Writes are not considered reliable activity indicators
  1709  	// due to buffering.
  1710  
  1711  	activityConn, err := common.NewActivityMonitoredConn(
  1712  		conn,
  1713  		SSH_CONNECTION_READ_DEADLINE,
  1714  		false,
  1715  		nil)
  1716  	if err != nil {
  1717  		conn.Close()
  1718  		if !isExpectedTunnelIOError(err) {
  1719  			log.WithTraceFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
  1720  		}
  1721  		return
  1722  	}
  1723  	conn = activityConn
  1724  
  1725  	// Further wrap the connection with burst monitoring, when enabled.
  1726  	//
  1727  	// Limitation: burst parameters are fixed for the duration of the tunnel
  1728  	// and do not change after a tactics hot reload.
  1729  
  1730  	var burstConn *common.BurstMonitoredConn
  1731  
  1732  	p, err := sshClient.sshServer.support.ServerTacticsParametersCache.Get(sshClient.geoIPData)
  1733  	if err != nil {
  1734  		log.WithTraceFields(LogFields{"error": errors.Trace(err)}).Warning(
  1735  			"ServerTacticsParametersCache.Get failed")
  1736  		return
  1737  	}
  1738  
  1739  	if !p.IsNil() {
  1740  		upstreamTargetBytes := int64(p.Int(parameters.ServerBurstUpstreamTargetBytes))
  1741  		upstreamDeadline := p.Duration(parameters.ServerBurstUpstreamDeadline)
  1742  		downstreamTargetBytes := int64(p.Int(parameters.ServerBurstDownstreamTargetBytes))
  1743  		downstreamDeadline := p.Duration(parameters.ServerBurstDownstreamDeadline)
  1744  
  1745  		if (upstreamDeadline != 0 && upstreamTargetBytes != 0) ||
  1746  			(downstreamDeadline != 0 && downstreamTargetBytes != 0) {
  1747  
  1748  			burstConn = common.NewBurstMonitoredConn(
  1749  				conn,
  1750  				true,
  1751  				upstreamTargetBytes, upstreamDeadline,
  1752  				downstreamTargetBytes, downstreamDeadline)
  1753  			conn = burstConn
  1754  		}
  1755  	}
  1756  
  1757  	// Allow garbage collection.
  1758  	p.Close()
  1759  
  1760  	// Further wrap the connection in a rate limiting ThrottledConn.
  1761  
  1762  	throttledConn := common.NewThrottledConn(conn, sshClient.rateLimits())
  1763  	conn = throttledConn
  1764  
  1765  	// Replay of server-side parameters is set or extended after a new tunnel
  1766  	// meets duration and bytes transferred targets. Set a timer now that expires
  1767  	// shortly after the target duration. When the timer fires, check the time of
  1768  	// last byte read (a read indicating a live connection with the client),
  1769  	// along with total bytes transferred and set or extend replay if the targets
  1770  	// are met.
  1771  	//
  1772  	// Both target checks are conservative: the tunnel may be healthy, but a byte
  1773  	// may not have been read in the last second when the timer fires. Or bytes
  1774  	// may be transferring, but not at the target level. Only clients that meet
  1775  	// the strict targets at the single check time will trigger replay; however,
  1776  	// this replay will impact all clients with similar GeoIP data.
  1777  	//
  1778  	// A deferred function cancels the timer and also increments the replay
  1779  	// failure counter, which will ultimately clear replay parameters, when the
  1780  	// tunnel fails before the API handshake is completed (this includes any
  1781  	// liveness test).
  1782  	//
  1783  	// A tunnel which fails to meet the targets but successfully completes any
  1784  	// liveness test and the API handshake is ignored in terms of replay scoring.
  1785  
  1786  	isReplayCandidate, replayWaitDuration, replayTargetDuration :=
  1787  		sshClient.sshServer.support.ReplayCache.GetReplayTargetDuration(sshClient.geoIPData)
  1788  
  1789  	if isReplayCandidate {
  1790  
  1791  		getFragmentorSeed := func() *prng.Seed {
  1792  			fragmentor, ok := baseConn.(common.FragmentorReplayAccessor)
  1793  			if ok {
  1794  				fragmentorSeed, _ := fragmentor.GetReplay()
  1795  				return fragmentorSeed
  1796  			}
  1797  			return nil
  1798  		}
  1799  
  1800  		setReplayAfterFunc := time.AfterFunc(
  1801  			replayWaitDuration,
  1802  			func() {
  1803  				if activityConn.GetActiveDuration() >= replayTargetDuration {
  1804  
  1805  					sshClient.Lock()
  1806  					bytesUp := sshClient.tcpTrafficState.bytesUp + sshClient.udpTrafficState.bytesUp
  1807  					bytesDown := sshClient.tcpTrafficState.bytesDown + sshClient.udpTrafficState.bytesDown
  1808  					sshClient.Unlock()
  1809  
  1810  					sshClient.sshServer.support.ReplayCache.SetReplayParameters(
  1811  						sshClient.tunnelProtocol,
  1812  						sshClient.geoIPData,
  1813  						sshClient.serverPacketManipulation,
  1814  						getFragmentorSeed(),
  1815  						bytesUp,
  1816  						bytesDown)
  1817  				}
  1818  			})
  1819  
  1820  		defer func() {
  1821  			setReplayAfterFunc.Stop()
  1822  			completed, _ := sshClient.getHandshaked()
  1823  			if !completed {
  1824  
  1825  				// Count a replay failure case when a tunnel used replay parameters
  1826  				// (excluding OSSH fragmentation, which doesn't use the ReplayCache) and
  1827  				// failed to complete the API handshake.
  1828  
  1829  				replayedFragmentation := false
  1830  				if sshClient.tunnelProtocol != protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH {
  1831  					fragmentor, ok := baseConn.(common.FragmentorReplayAccessor)
  1832  					if ok {
  1833  						_, replayedFragmentation = fragmentor.GetReplay()
  1834  					}
  1835  				}
  1836  				usedReplay := replayedFragmentation || sshClient.replayedServerPacketManipulation
  1837  
  1838  				if usedReplay {
  1839  					sshClient.sshServer.support.ReplayCache.FailedReplayParameters(
  1840  						sshClient.tunnelProtocol,
  1841  						sshClient.geoIPData,
  1842  						sshClient.serverPacketManipulation,
  1843  						getFragmentorSeed())
  1844  				}
  1845  			}
  1846  		}()
  1847  	}
  1848  
  1849  	// Run the initial [obfuscated] SSH handshake in a goroutine so we can both
  1850  	// respect shutdownBroadcast and implement a specific handshake timeout.
  1851  	// The timeout is to reclaim network resources in case the handshake takes
  1852  	// too long.
  1853  
  1854  	type sshNewServerConnResult struct {
  1855  		obfuscatedSSHConn *obfuscator.ObfuscatedSSHConn
  1856  		sshConn           *ssh.ServerConn
  1857  		channels          <-chan ssh.NewChannel
  1858  		requests          <-chan *ssh.Request
  1859  		err               error
  1860  	}
  1861  
  1862  	resultChannel := make(chan *sshNewServerConnResult, 2)
  1863  
  1864  	var sshHandshakeAfterFunc *time.Timer
  1865  	if sshClient.sshServer.support.Config.sshHandshakeTimeout > 0 {
  1866  		sshHandshakeAfterFunc = time.AfterFunc(sshClient.sshServer.support.Config.sshHandshakeTimeout, func() {
  1867  			resultChannel <- &sshNewServerConnResult{err: std_errors.New("ssh handshake timeout")}
  1868  		})
  1869  	}
  1870  
  1871  	go func(baseConn, conn net.Conn) {
  1872  		sshServerConfig := &ssh.ServerConfig{
  1873  			PasswordCallback: sshClient.passwordCallback,
  1874  			AuthLogCallback:  sshClient.authLogCallback,
  1875  			ServerVersion:    sshClient.sshServer.support.Config.SSHServerVersion,
  1876  		}
  1877  		sshServerConfig.AddHostKey(sshClient.sshServer.sshHostKey)
  1878  
  1879  		var err error
  1880  
  1881  		if protocol.TunnelProtocolUsesObfuscatedSSH(sshClient.tunnelProtocol) {
  1882  			// With Encrypt-then-MAC hash algorithms, packet length is
  1883  			// transmitted in plaintext, which aids in traffic analysis;
  1884  			// clients may still send Encrypt-then-MAC algorithms in their
  1885  			// KEX_INIT message, but do not select these algorithms.
  1886  			//
  1887  			// The exception is TUNNEL_PROTOCOL_SSH, which is intended to appear
  1888  			// like SSH on the wire.
  1889  			sshServerConfig.NoEncryptThenMACHash = true
  1890  
  1891  		} else {
  1892  			// For TUNNEL_PROTOCOL_SSH only, randomize KEX.
  1893  			if sshClient.sshServer.support.Config.ObfuscatedSSHKey != "" {
  1894  				sshServerConfig.KEXPRNGSeed, err = protocol.DeriveSSHServerKEXPRNGSeed(
  1895  					sshClient.sshServer.support.Config.ObfuscatedSSHKey)
  1896  				if err != nil {
  1897  					err = errors.Trace(err)
  1898  				}
  1899  			}
  1900  		}
  1901  
  1902  		result := &sshNewServerConnResult{}
  1903  
  1904  		// Wrap the connection in an SSH deobfuscator when required.
  1905  
  1906  		if err == nil && protocol.TunnelProtocolUsesObfuscatedSSH(sshClient.tunnelProtocol) {
  1907  
  1908  			// Note: NewServerObfuscatedSSHConn blocks on network I/O
  1909  			// TODO: ensure this won't block shutdown
  1910  			result.obfuscatedSSHConn, err = obfuscator.NewServerObfuscatedSSHConn(
  1911  				conn,
  1912  				sshClient.sshServer.support.Config.ObfuscatedSSHKey,
  1913  				sshClient.sshServer.obfuscatorSeedHistory,
  1914  				func(clientIP string, err error, logFields common.LogFields) {
  1915  					logIrregularTunnel(
  1916  						sshClient.sshServer.support,
  1917  						sshClient.sshListener.tunnelProtocol,
  1918  						sshClient.sshListener.port,
  1919  						clientIP,
  1920  						errors.Trace(err),
  1921  						LogFields(logFields))
  1922  				})
  1923  
  1924  			if err != nil {
  1925  				err = errors.Trace(err)
  1926  			} else {
  1927  				conn = result.obfuscatedSSHConn
  1928  			}
  1929  
  1930  			// Seed the fragmentor, when present, with seed derived from initial
  1931  			// obfuscator message. See tactics.Listener.Accept. This must preceed
  1932  			// ssh.NewServerConn to ensure fragmentor is seeded before downstream bytes
  1933  			// are written.
  1934  			if err == nil && sshClient.tunnelProtocol == protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH {
  1935  				fragmentor, ok := baseConn.(common.FragmentorReplayAccessor)
  1936  				if ok {
  1937  					var fragmentorPRNG *prng.PRNG
  1938  					fragmentorPRNG, err = result.obfuscatedSSHConn.GetDerivedPRNG("server-side-fragmentor")
  1939  					if err != nil {
  1940  						err = errors.Trace(err)
  1941  					} else {
  1942  						fragmentor.SetReplay(fragmentorPRNG)
  1943  					}
  1944  				}
  1945  			}
  1946  		}
  1947  
  1948  		if err == nil {
  1949  			result.sshConn, result.channels, result.requests, err =
  1950  				ssh.NewServerConn(conn, sshServerConfig)
  1951  			if err != nil {
  1952  				err = errors.Trace(err)
  1953  			}
  1954  		}
  1955  
  1956  		result.err = err
  1957  
  1958  		resultChannel <- result
  1959  
  1960  	}(baseConn, conn)
  1961  
  1962  	var result *sshNewServerConnResult
  1963  	select {
  1964  	case result = <-resultChannel:
  1965  	case <-sshClient.sshServer.shutdownBroadcast:
  1966  		// Close() will interrupt an ongoing handshake
  1967  		// TODO: wait for SSH handshake goroutines to exit before returning?
  1968  		conn.Close()
  1969  		return
  1970  	}
  1971  
  1972  	if sshHandshakeAfterFunc != nil {
  1973  		sshHandshakeAfterFunc.Stop()
  1974  	}
  1975  
  1976  	if result.err != nil {
  1977  		conn.Close()
  1978  		// This is a Debug log due to noise. The handshake often fails due to I/O
  1979  		// errors as clients frequently interrupt connections in progress when
  1980  		// client-side load balancing completes a connection to a different server.
  1981  		log.WithTraceFields(LogFields{"error": result.err}).Debug("SSH handshake failed")
  1982  		return
  1983  	}
  1984  
  1985  	// The SSH handshake has finished successfully; notify now to allow other
  1986  	// blocked SSH handshakes to proceed.
  1987  	if onSSHHandshakeFinished != nil {
  1988  		onSSHHandshakeFinished()
  1989  	}
  1990  	onSSHHandshakeFinished = nil
  1991  
  1992  	sshClient.Lock()
  1993  	sshClient.sshConn = result.sshConn
  1994  	sshClient.throttledConn = throttledConn
  1995  	sshClient.Unlock()
  1996  
  1997  	if !sshClient.sshServer.registerEstablishedClient(sshClient) {
  1998  		conn.Close()
  1999  		log.WithTrace().Warning("register failed")
  2000  		return
  2001  	}
  2002  
  2003  	sshClient.runTunnel(result.channels, result.requests)
  2004  
  2005  	// Note: sshServer.unregisterEstablishedClient calls sshClient.stop(),
  2006  	// which also closes underlying transport Conn.
  2007  
  2008  	sshClient.sshServer.unregisterEstablishedClient(sshClient)
  2009  
  2010  	// Log tunnel metrics.
  2011  
  2012  	var additionalMetrics []LogFields
  2013  
  2014  	// Add activity and burst metrics.
  2015  	//
  2016  	// The reported duration is based on last confirmed data transfer, which for
  2017  	// sshClient.activityConn.GetActiveDuration() is time of last read byte and
  2018  	// not conn close time. This is important for protocols such as meek. For
  2019  	// meek, the connection remains open until the HTTP session expires, which
  2020  	// may be some time after the tunnel has closed. (The meek protocol has no
  2021  	// allowance for signalling payload EOF, and even if it did the client may
  2022  	// not have the opportunity to send a final request with an EOF flag set.)
  2023  
  2024  	activityMetrics := make(LogFields)
  2025  	activityMetrics["start_time"] = activityConn.GetStartTime()
  2026  	activityMetrics["duration"] = int64(activityConn.GetActiveDuration() / time.Millisecond)
  2027  	additionalMetrics = append(additionalMetrics, activityMetrics)
  2028  
  2029  	if burstConn != nil {
  2030  		// Any outstanding burst should be recorded by burstConn.Close which should
  2031  		// be called by unregisterEstablishedClient.
  2032  		additionalMetrics = append(
  2033  			additionalMetrics, LogFields(burstConn.GetMetrics(activityConn.GetStartTime())))
  2034  	}
  2035  
  2036  	// Some conns report additional metrics. Meek conns report resiliency
  2037  	// metrics and fragmentor.Conns report fragmentor configs.
  2038  
  2039  	if metricsSource, ok := baseConn.(common.MetricsSource); ok {
  2040  		additionalMetrics = append(
  2041  			additionalMetrics, LogFields(metricsSource.GetMetrics()))
  2042  	}
  2043  	if result.obfuscatedSSHConn != nil {
  2044  		additionalMetrics = append(
  2045  			additionalMetrics, LogFields(result.obfuscatedSSHConn.GetMetrics()))
  2046  	}
  2047  
  2048  	// Add server-replay metrics.
  2049  
  2050  	replayMetrics := make(LogFields)
  2051  	replayedFragmentation := false
  2052  	fragmentor, ok := baseConn.(common.FragmentorReplayAccessor)
  2053  	if ok {
  2054  		_, replayedFragmentation = fragmentor.GetReplay()
  2055  	}
  2056  	replayMetrics["server_replay_fragmentation"] = replayedFragmentation
  2057  	replayMetrics["server_replay_packet_manipulation"] = sshClient.replayedServerPacketManipulation
  2058  	additionalMetrics = append(additionalMetrics, replayMetrics)
  2059  
  2060  	// Limitation: there's only one log per tunnel with bytes transferred
  2061  	// metrics, so the byte count can't be attributed to a certain day for
  2062  	// tunnels that remain connected for well over 24h. In practise, most
  2063  	// tunnels are short-lived, especially on mobile devices.
  2064  
  2065  	sshClient.logTunnel(additionalMetrics)
  2066  
  2067  	// Transfer OSL seed state -- the OSL progress -- from the closing
  2068  	// client to the session cache so the client can resume its progress
  2069  	// if it reconnects to this same server.
  2070  	// Note: following setOSLConfig order of locking.
  2071  
  2072  	sshClient.Lock()
  2073  	if sshClient.oslClientSeedState != nil {
  2074  		sshClient.sshServer.oslSessionCacheMutex.Lock()
  2075  		sshClient.oslClientSeedState.Hibernate()
  2076  		sshClient.sshServer.oslSessionCache.Set(
  2077  			sshClient.sessionID, sshClient.oslClientSeedState, cache.DefaultExpiration)
  2078  		sshClient.sshServer.oslSessionCacheMutex.Unlock()
  2079  		sshClient.oslClientSeedState = nil
  2080  	}
  2081  	sshClient.Unlock()
  2082  
  2083  	// Initiate cleanup of the GeoIP session cache. To allow for post-tunnel
  2084  	// final status requests, the lifetime of cached GeoIP records exceeds the
  2085  	// lifetime of the sshClient.
  2086  	sshClient.sshServer.support.GeoIPService.MarkSessionCacheToExpire(sshClient.sessionID)
  2087  }
  2088  
  2089  func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
  2090  
  2091  	expectedSessionIDLength := 2 * protocol.PSIPHON_API_CLIENT_SESSION_ID_LENGTH
  2092  	expectedSSHPasswordLength := 2 * SSH_PASSWORD_BYTE_LENGTH
  2093  
  2094  	var sshPasswordPayload protocol.SSHPasswordPayload
  2095  	err := json.Unmarshal(password, &sshPasswordPayload)
  2096  	if err != nil {
  2097  
  2098  		// Backwards compatibility case: instead of a JSON payload, older clients
  2099  		// send the hex encoded session ID prepended to the SSH password.
  2100  		// Note: there's an even older case where clients don't send any session ID,
  2101  		// but that's no longer supported.
  2102  		if len(password) == expectedSessionIDLength+expectedSSHPasswordLength {
  2103  			sshPasswordPayload.SessionId = string(password[0:expectedSessionIDLength])
  2104  			sshPasswordPayload.SshPassword = string(password[expectedSessionIDLength:])
  2105  		} else {
  2106  			return nil, errors.Tracef("invalid password payload for %q", conn.User())
  2107  		}
  2108  	}
  2109  
  2110  	if !isHexDigits(sshClient.sshServer.support.Config, sshPasswordPayload.SessionId) ||
  2111  		len(sshPasswordPayload.SessionId) != expectedSessionIDLength {
  2112  		return nil, errors.Tracef("invalid session ID for %q", conn.User())
  2113  	}
  2114  
  2115  	userOk := (subtle.ConstantTimeCompare(
  2116  		[]byte(conn.User()), []byte(sshClient.sshServer.support.Config.SSHUserName)) == 1)
  2117  
  2118  	passwordOk := (subtle.ConstantTimeCompare(
  2119  		[]byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.support.Config.SSHPassword)) == 1)
  2120  
  2121  	if !userOk || !passwordOk {
  2122  		return nil, errors.Tracef("invalid password for %q", conn.User())
  2123  	}
  2124  
  2125  	sessionID := sshPasswordPayload.SessionId
  2126  
  2127  	// The GeoIP session cache will be populated if there was a previous tunnel
  2128  	// with this session ID. This will be true up to GEOIP_SESSION_CACHE_TTL, which
  2129  	// is currently much longer than the OSL session cache, another option to use if
  2130  	// the GeoIP session cache is retired (the GeoIP session cache currently only
  2131  	// supports legacy use cases).
  2132  	isFirstTunnelInSession := !sshClient.sshServer.support.GeoIPService.InSessionCache(sessionID)
  2133  
  2134  	supportsServerRequests := common.Contains(
  2135  		sshPasswordPayload.ClientCapabilities, protocol.CLIENT_CAPABILITY_SERVER_REQUESTS)
  2136  
  2137  	sshClient.Lock()
  2138  
  2139  	// After this point, these values are read-only as they are read
  2140  	// without obtaining sshClient.Lock.
  2141  	sshClient.sessionID = sessionID
  2142  	sshClient.isFirstTunnelInSession = isFirstTunnelInSession
  2143  	sshClient.supportsServerRequests = supportsServerRequests
  2144  
  2145  	geoIPData := sshClient.geoIPData
  2146  
  2147  	sshClient.Unlock()
  2148  
  2149  	// Store the GeoIP data associated with the session ID. This makes
  2150  	// the GeoIP data available to the web server for web API requests.
  2151  	// A cache that's distinct from the sshClient record is used to allow
  2152  	// for or post-tunnel final status requests.
  2153  	// If the client is reconnecting with the same session ID, this call
  2154  	// will undo the expiry set by MarkSessionCacheToExpire.
  2155  	sshClient.sshServer.support.GeoIPService.SetSessionCache(sessionID, geoIPData)
  2156  
  2157  	return nil, nil
  2158  }
  2159  
  2160  func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {
  2161  
  2162  	if err != nil {
  2163  
  2164  		if method == "none" && err.Error() == "ssh: no auth passed yet" {
  2165  			// In this case, the callback invocation is noise from auth negotiation
  2166  			return
  2167  		}
  2168  
  2169  		// Note: here we previously logged messages for fail2ban to act on. This is no longer
  2170  		// done as the complexity outweighs the benefits.
  2171  		//
  2172  		// - The SSH credential is not secret -- it's in the server entry. Attackers targeting
  2173  		//   the server likely already have the credential. On the other hand, random scanning and
  2174  		//   brute forcing is mitigated with high entropy random passwords, rate limiting
  2175  		//   (implemented on the host via iptables), and limited capabilities (the SSH session can
  2176  		//   only port forward).
  2177  		//
  2178  		// - fail2ban coverage was inconsistent; in the case of an unfronted meek protocol through
  2179  		//   an upstream proxy, the remote address is the upstream proxy, which should not be blocked.
  2180  		//   The X-Forwarded-For header cant be used instead as it may be forged and used to get IPs
  2181  		//   deliberately blocked; and in any case fail2ban adds iptables rules which can only block
  2182  		//   by direct remote IP, not by original client IP. Fronted meek has the same iptables issue.
  2183  		//
  2184  		// Random scanning and brute forcing of port 22 will result in log noise. To mitigate this,
  2185  		// not every authentication failure is logged. A summary log is emitted periodically to
  2186  		// retain some record of this activity in case this is relevant to, e.g., a performance
  2187  		// investigation.
  2188  
  2189  		atomic.AddInt64(&sshClient.sshServer.authFailedCount, 1)
  2190  
  2191  		lastAuthLog := monotime.Time(atomic.LoadInt64(&sshClient.sshServer.lastAuthLog))
  2192  		if monotime.Since(lastAuthLog) > SSH_AUTH_LOG_PERIOD {
  2193  			now := int64(monotime.Now())
  2194  			if atomic.CompareAndSwapInt64(&sshClient.sshServer.lastAuthLog, int64(lastAuthLog), now) {
  2195  				count := atomic.SwapInt64(&sshClient.sshServer.authFailedCount, 0)
  2196  				log.WithTraceFields(
  2197  					LogFields{"lastError": err, "failedCount": count}).Warning("authentication failures")
  2198  			}
  2199  		}
  2200  
  2201  		log.WithTraceFields(LogFields{"error": err, "method": method}).Debug("authentication failed")
  2202  
  2203  	} else {
  2204  
  2205  		log.WithTraceFields(LogFields{"error": err, "method": method}).Debug("authentication success")
  2206  	}
  2207  }
  2208  
  2209  // stop signals the ssh connection to shutdown. After sshConn.Wait returns,
  2210  // the SSH connection has terminated but sshClient.run may still be running and
  2211  // in the process of exiting.
  2212  //
  2213  // The shutdown process must complete rapidly and not, e.g., block on network
  2214  // I/O, as newly connecting clients need to await stop completion of any
  2215  // existing connection that shares the same session ID.
  2216  func (sshClient *sshClient) stop() {
  2217  	sshClient.sshConn.Close()
  2218  	sshClient.sshConn.Wait()
  2219  }
  2220  
  2221  // awaitStopped will block until sshClient.run has exited, at which point all
  2222  // worker goroutines associated with the sshClient, including any in-flight
  2223  // API handlers, will have exited.
  2224  func (sshClient *sshClient) awaitStopped() {
  2225  	<-sshClient.stopped
  2226  }
  2227  
// runTunnel handles/dispatches new channels and new requests from the client.
// When the SSH client connection closes, both the channels and requests channels
// will close and runTunnel will exit.
//
// Worker goroutines started here are tracked by a single WaitGroup:
//   - the SSH request handler;
//   - the OSL and alert senders, when the client supports server requests;
//   - the TCP port forward manager;
//   - per-channel handlers, which receive the same WaitGroup.
//
// On exit, runTunnel performs these steps in order:
//  1. closes the TCP port forward queue;
//  2. calls stopRunning;
//  3. notifies the packet tunnel server, when RunPacketTunnel is configured;
//  4. waits for all tracked workers;
//  5. calls cleanupAuthorizations.
func (sshClient *sshClient) runTunnel(
	channels <-chan ssh.NewChannel,
	requests <-chan *ssh.Request) {

	waitGroup := new(sync.WaitGroup)

	// Start client SSH API request handler

	waitGroup.Add(1)
	go func() {
		defer waitGroup.Done()
		sshClient.handleSSHRequests(requests)
	}()

	// Start request senders

	if sshClient.supportsServerRequests {

		waitGroup.Add(1)
		go func() {
			defer waitGroup.Done()
			sshClient.runOSLSender()
		}()

		waitGroup.Add(1)
		go func() {
			defer waitGroup.Done()
			sshClient.runAlertSender()
		}()
	}

	// Start the TCP port forward manager

	// The queue size is set to the traffic rules (MaxTCPPortForwardCount +
	// MaxTCPDialingPortForwardCount), which is a reasonable indication of resource
	// limits per client; when that value is not set, a default is used.
	// A limitation: this queue size is set once and doesn't change, for this client,
	// when traffic rules are reloaded.
	queueSize := sshClient.getTCPPortForwardQueueSize()
	if queueSize == 0 {
		queueSize = SSH_TCP_PORT_FORWARD_QUEUE_SIZE
	}
	newTCPPortForwards := make(chan *newTCPPortForward, queueSize)

	waitGroup.Add(1)
	go func() {
		defer waitGroup.Done()
		sshClient.handleTCPPortForwards(waitGroup, newTCPPortForwards)
	}()

	// Handle new channel (port forward) requests from the client.

	for newChannel := range channels {
		switch newChannel.ChannelType() {
		case protocol.RANDOM_STREAM_CHANNEL_TYPE:
			sshClient.handleNewRandomStreamChannel(waitGroup, newChannel)
		case protocol.PACKET_TUNNEL_CHANNEL_TYPE:
			sshClient.handleNewPacketTunnelChannel(waitGroup, newChannel)
		case protocol.TCP_PORT_FORWARD_NO_SPLIT_TUNNEL_TYPE:
			// The protocol.TCP_PORT_FORWARD_NO_SPLIT_TUNNEL_TYPE is the same as
			// "direct-tcpip", except split tunnel channel rejections are disallowed
			// even if the client has enabled split tunnel. This channel type allows
			// the client to ensure tunneling for certain cases while split tunnel is
			// enabled.
			sshClient.handleNewTCPPortForwardChannel(waitGroup, newChannel, false, newTCPPortForwards)
		case "direct-tcpip":
			sshClient.handleNewTCPPortForwardChannel(waitGroup, newChannel, true, newTCPPortForwards)
		default:
			sshClient.rejectNewChannel(newChannel,
				fmt.Sprintf("unknown or unsupported channel type: %s", newChannel.ChannelType()))
		}
	}

	// The channel loop is interrupted by a client
	// disconnect or by calling sshClient.stop().

	// Stop the TCP port forward manager. Closing the queue ends the range loop
	// in handleTCPPortForwards. No more sends can happen: sends only occur
	// from the channel loop above, which has exited.
	close(newTCPPortForwards)

	// Stop all other worker goroutines.
	// NOTE(review): presumably stopRunning cancels sshClient.runCtx, which also
	// releases a TCP port forward manager blocked at the dial limit — confirm.
	sshClient.stopRunning()

	if sshClient.sshServer.support.Config.RunPacketTunnel {
		// PacketTunnelServer.ClientDisconnected stops packet tunnel workers.
		sshClient.sshServer.support.PacketTunnelServer.ClientDisconnected(
			sshClient.sessionID)
	}

	// Wait for every goroutine tracked by waitGroup, including per-channel
	// workers, before releasing authorizations.
	waitGroup.Wait()

	sshClient.cleanupAuthorizations()
}
  2323  
  2324  func (sshClient *sshClient) handleSSHRequests(requests <-chan *ssh.Request) {
  2325  
  2326  	for request := range requests {
  2327  
  2328  		// Requests are processed serially; API responses must be sent in request order.
  2329  
  2330  		var responsePayload []byte
  2331  		var err error
  2332  
  2333  		if request.Type == "keepalive@openssh.com" {
  2334  
  2335  			// SSH keep alive round trips are used as speed test samples.
  2336  			responsePayload, err = tactics.MakeSpeedTestResponse(
  2337  				SSH_KEEP_ALIVE_PAYLOAD_MIN_BYTES, SSH_KEEP_ALIVE_PAYLOAD_MAX_BYTES)
  2338  
  2339  		} else {
  2340  
  2341  			// All other requests are assumed to be API requests.
  2342  
  2343  			sshClient.Lock()
  2344  			authorizedAccessTypes := sshClient.handshakeState.authorizedAccessTypes
  2345  			sshClient.Unlock()
  2346  
  2347  			// Note: unlock before use is only safe as long as referenced sshClient data,
  2348  			// such as slices in handshakeState, is read-only after initially set.
  2349  
  2350  			clientAddr := ""
  2351  			if sshClient.clientAddr != nil {
  2352  				clientAddr = sshClient.clientAddr.String()
  2353  			}
  2354  
  2355  			responsePayload, err = sshAPIRequestHandler(
  2356  				sshClient.sshServer.support,
  2357  				clientAddr,
  2358  				sshClient.geoIPData,
  2359  				authorizedAccessTypes,
  2360  				request.Type,
  2361  				request.Payload)
  2362  		}
  2363  
  2364  		if err == nil {
  2365  			err = request.Reply(true, responsePayload)
  2366  		} else {
  2367  			log.WithTraceFields(LogFields{"error": err}).Warning("request failed")
  2368  			err = request.Reply(false, nil)
  2369  		}
  2370  		if err != nil {
  2371  			if !isExpectedTunnelIOError(err) {
  2372  				log.WithTraceFields(LogFields{"error": err}).Warning("response failed")
  2373  			}
  2374  		}
  2375  
  2376  	}
  2377  
  2378  }
  2379  
// newTCPPortForward is a TCP port forward request queued by
// handleNewTCPPortForwardChannel for the TCP port forward manager
// (handleTCPPortForwards). Its fields are:
//   - enqueueTime: when the request was queued. The manager deducts the time
//     elapsed since then from the port forward's dial timeout.
//   - hostToConnect, portToConnect: the dial target requested by the client.
//   - doSplitTunnel: whether split tunnel logic applies to this port forward.
//   - newChannel: the pending SSH channel, later handed to handleTCPChannel
//     or rejected.
type newTCPPortForward struct {
	enqueueTime   time.Time
	hostToConnect string
	portToConnect int
	doSplitTunnel bool
	newChannel    ssh.NewChannel
}
  2387  
// handleTCPPortForwards is the TCP port forward manager. It dequeues requests
// from newTCPPortForwards until that channel is closed. Each request is either
// rejected or dispatched to handleTCPChannel in its own goroutine, which is
// tracked by waitGroup.
func (sshClient *sshClient) handleTCPPortForwards(
	waitGroup *sync.WaitGroup,
	newTCPPortForwards chan *newTCPPortForward) {

	// Lifecycle of a TCP port forward:
	//
	// 1. A "direct-tcpip" SSH request is received from the client.
	//
	//    A new TCP port forward request is enqueued. The queue delivers TCP port
	//    forward requests to the TCP port forward manager, which enforces the TCP
	//    port forward dial limit.
	//
	//    Enqueuing new requests allows for reading further SSH requests from the
	//    client without blocking when the dial limit is hit; this is to permit new
	//    UDP/udpgw port forwards to be re-established without delay. The maximum size
	//    of the queue enforces a hard cap on resources consumed by a client in the
	//    pre-dial phase. When the queue is full, new TCP port forwards are
	//    immediately rejected.
	//
	// 2. The TCP port forward manager dequeues the request.
	//
	//    The manager calls dialingTCPPortForward(), which increments
	//    concurrentDialingPortForwardCount, and calls
	//    isTCPDialingPortForwardLimitExceeded() to check the concurrent dialing
	//    count.
	//
	//    The manager enforces the concurrent TCP dial limit: when at the limit, the
	//    manager blocks waiting for the number of dials to drop below the limit before
	//    dispatching the request to handleTCPChannel(), which will run in its own
	//    goroutine and will dial and relay the port forward.
	//
	//    The block delays the current request and also halts dequeuing of subsequent
	//    requests and could ultimately cause requests to be immediately rejected if
	//    the queue fills. These actions are intended to apply back pressure when
	//    upstream network resources are impaired.
	//
	//    The time spent in the queue is deducted from the port forward's dial timeout.
	//    The time spent blocking while at the dial limit is similarly deducted from
	//    the dial timeout. If the dial timeout has expired before the dial begins, the
	//    port forward is rejected and a stat is recorded.
	//
	// 3. handleTCPChannel() performs the port forward dial and relaying.
	//
	//    a. Dial the target, using the dial timeout remaining after queue and blocking
	//       time is deducted.
	//
	//    b. If the dial fails, call abortedTCPPortForward() to decrement
	//       concurrentDialingPortForwardCount, freeing up a dial slot.
	//
	//    c. If the dial succeeds, call establishedPortForward(), which decrements
	//       concurrentDialingPortForwardCount and increments concurrentPortForwardCount,
	//       the "established" port forward count.
	//
	//    d. Check isPortForwardLimitExceeded(), which enforces the configurable limit on
	//       concurrentPortForwardCount, the number of _established_ TCP port forwards.
	//       If the limit is exceeded, the LRU established TCP port forward is closed and
	//       the newly established TCP port forward proceeds. This LRU logic allows some
	//       dangling resource consumption (e.g., TIME_WAIT) while providing a better
	//       experience for clients.
	//
	//    e. Relay data.
	//
	//    f. Call closedPortForward() which decrements concurrentPortForwardCount and
	//       records bytes transferred.

	for newPortForward := range newTCPPortForwards {

		// Deduct the time already spent waiting in the queue from the
		// configured dial timeout.
		remainingDialTimeout :=
			time.Duration(sshClient.getDialTCPPortForwardTimeoutMilliseconds())*time.Millisecond -
				time.Since(newPortForward.enqueueTime)

		if remainingDialTimeout <= 0 {
			sshClient.updateQualityMetricsWithRejectedDialingLimit()
			sshClient.rejectNewChannel(
				newPortForward.newChannel, "TCP port forward timed out in queue")
			continue
		}

		// Reserve a TCP dialing slot.
		//
		// TOCTOU note: important to increment counts _before_ checking limits; otherwise,
		// the client could potentially consume excess resources by initiating many port
		// forwards concurrently.

		sshClient.dialingTCPPortForward()

		// When max dials are in progress, wait up to remainingDialTimeout for dialing
		// to become available. This blocks all dequeuing.
		//
		// The wait ends when one of these happens:
		//   - the timeout elapses;
		//   - sshClient.runCtx is done;
		//   - the registered cancelCtx is invoked via the dialing-available signal.

		if sshClient.isTCPDialingPortForwardLimitExceeded() {
			blockStartTime := time.Now()
			ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
			sshClient.setTCPPortForwardDialingAvailableSignal(cancelCtx)
			<-ctx.Done()
			sshClient.setTCPPortForwardDialingAvailableSignal(nil)
			cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"
			remainingDialTimeout -= time.Since(blockStartTime)
		}

		if remainingDialTimeout <= 0 {

			// Release the dialing slot here since handleTCPChannel() won't be called.
			sshClient.abortedTCPPortForward()

			sshClient.updateQualityMetricsWithRejectedDialingLimit()
			sshClient.rejectNewChannel(
				newPortForward.newChannel, "TCP port forward timed out before dialing")
			continue
		}

		// Dial and relay the TCP port forward. handleTCPChannel is run in its own worker goroutine.
		// handleTCPChannel will release the dialing slot reserved by dialingTCPPortForward(); and
		// will deal with remainingDialTimeout <= 0.

		waitGroup.Add(1)
		go func(remainingDialTimeout time.Duration, newPortForward *newTCPPortForward) {
			defer waitGroup.Done()
			sshClient.handleTCPChannel(
				remainingDialTimeout,
				newPortForward.hostToConnect,
				newPortForward.portToConnect,
				newPortForward.doSplitTunnel,
				newPortForward.newChannel)
		}(remainingDialTimeout, newPortForward)
	}
}
  2514  
  2515  func (sshClient *sshClient) handleNewRandomStreamChannel(
  2516  	waitGroup *sync.WaitGroup, newChannel ssh.NewChannel) {
  2517  
  2518  	// A random stream channel returns the requested number of bytes -- random
  2519  	// bytes -- to the client while also consuming and discarding bytes sent
  2520  	// by the client.
  2521  	//
  2522  	// One use case for the random stream channel is a liveness test that the
  2523  	// client performs to confirm that the tunnel is live. As the liveness
  2524  	// test is performed in the concurrent establishment phase, before
  2525  	// selecting a single candidate for handshake, the random stream channel
  2526  	// is available pre-handshake, albeit with additional restrictions.
  2527  	//
  2528  	// The random stream is subject to throttling in traffic rules; for
  2529  	// unthrottled liveness tests, set EstablishmentRead/WriteBytesPerSecond as
  2530  	// required. The random stream maximum count and response size cap mitigate
  2531  	// clients abusing the facility to waste server resources.
  2532  	//
  2533  	// Like all other channels, this channel type is handled asynchronously,
  2534  	// so it's possible to run at any point in the tunnel lifecycle.
  2535  	//
  2536  	// Up/downstream byte counts don't include SSH packet and request
  2537  	// marshalling overhead.
  2538  
  2539  	var request protocol.RandomStreamRequest
  2540  	err := json.Unmarshal(newChannel.ExtraData(), &request)
  2541  	if err != nil {
  2542  		sshClient.rejectNewChannel(newChannel, fmt.Sprintf("invalid request: %s", err))
  2543  		return
  2544  	}
  2545  
  2546  	if request.UpstreamBytes > RANDOM_STREAM_MAX_BYTES {
  2547  		sshClient.rejectNewChannel(newChannel,
  2548  			fmt.Sprintf("invalid upstream bytes: %d", request.UpstreamBytes))
  2549  		return
  2550  	}
  2551  
  2552  	if request.DownstreamBytes > RANDOM_STREAM_MAX_BYTES {
  2553  		sshClient.rejectNewChannel(newChannel,
  2554  			fmt.Sprintf("invalid downstream bytes: %d", request.DownstreamBytes))
  2555  		return
  2556  	}
  2557  
  2558  	var metrics *randomStreamMetrics
  2559  
  2560  	sshClient.Lock()
  2561  
  2562  	if !sshClient.handshakeState.completed {
  2563  		metrics = &sshClient.preHandshakeRandomStreamMetrics
  2564  	} else {
  2565  		metrics = &sshClient.postHandshakeRandomStreamMetrics
  2566  	}
  2567  
  2568  	countOk := true
  2569  	if !sshClient.handshakeState.completed &&
  2570  		metrics.count >= PRE_HANDSHAKE_RANDOM_STREAM_MAX_COUNT {
  2571  		countOk = false
  2572  	} else {
  2573  		metrics.count++
  2574  	}
  2575  
  2576  	sshClient.Unlock()
  2577  
  2578  	if !countOk {
  2579  		sshClient.rejectNewChannel(newChannel, "max count exceeded")
  2580  		return
  2581  	}
  2582  
  2583  	channel, requests, err := newChannel.Accept()
  2584  	if err != nil {
  2585  		if !isExpectedTunnelIOError(err) {
  2586  			log.WithTraceFields(LogFields{"error": err}).Warning("accept new channel failed")
  2587  		}
  2588  		return
  2589  	}
  2590  	go ssh.DiscardRequests(requests)
  2591  
  2592  	waitGroup.Add(1)
  2593  	go func() {
  2594  		defer waitGroup.Done()
  2595  
  2596  		upstream := new(sync.WaitGroup)
  2597  		received := 0
  2598  		sent := 0
  2599  
  2600  		if request.UpstreamBytes > 0 {
  2601  
  2602  			// Process streams concurrently to minimize elapsed time. This also
  2603  			// avoids a unidirectional flow burst early in the tunnel lifecycle.
  2604  
  2605  			upstream.Add(1)
  2606  			go func() {
  2607  				defer upstream.Done()
  2608  				n, err := io.CopyN(ioutil.Discard, channel, int64(request.UpstreamBytes))
  2609  				received = int(n)
  2610  				if err != nil {
  2611  					if !isExpectedTunnelIOError(err) {
  2612  						log.WithTraceFields(LogFields{"error": err}).Warning("receive failed")
  2613  					}
  2614  				}
  2615  			}()
  2616  		}
  2617  
  2618  		if request.DownstreamBytes > 0 {
  2619  			n, err := io.CopyN(channel, rand.Reader, int64(request.DownstreamBytes))
  2620  			sent = int(n)
  2621  			if err != nil {
  2622  				if !isExpectedTunnelIOError(err) {
  2623  					log.WithTraceFields(LogFields{"error": err}).Warning("send failed")
  2624  				}
  2625  			}
  2626  		}
  2627  
  2628  		upstream.Wait()
  2629  
  2630  		sshClient.Lock()
  2631  		metrics.upstreamBytes += int64(request.UpstreamBytes)
  2632  		metrics.receivedUpstreamBytes += int64(received)
  2633  		metrics.downstreamBytes += int64(request.DownstreamBytes)
  2634  		metrics.sentDownstreamBytes += int64(sent)
  2635  		sshClient.Unlock()
  2636  
  2637  		channel.Close()
  2638  	}()
  2639  }
  2640  
  2641  func (sshClient *sshClient) handleNewPacketTunnelChannel(
  2642  	waitGroup *sync.WaitGroup, newChannel ssh.NewChannel) {
  2643  
  2644  	// packet tunnel channels are handled by the packet tunnel server
  2645  	// component. Each client may have at most one packet tunnel channel.
  2646  
  2647  	if !sshClient.sshServer.support.Config.RunPacketTunnel {
  2648  		sshClient.rejectNewChannel(newChannel, "unsupported packet tunnel channel type")
  2649  		return
  2650  	}
  2651  
  2652  	// Accept this channel immediately. This channel will replace any
  2653  	// previously existing packet tunnel channel for this client.
  2654  
  2655  	packetTunnelChannel, requests, err := newChannel.Accept()
  2656  	if err != nil {
  2657  		if !isExpectedTunnelIOError(err) {
  2658  			log.WithTraceFields(LogFields{"error": err}).Warning("accept new channel failed")
  2659  		}
  2660  		return
  2661  	}
  2662  	go ssh.DiscardRequests(requests)
  2663  
  2664  	sshClient.setPacketTunnelChannel(packetTunnelChannel)
  2665  
  2666  	// PacketTunnelServer will run the client's packet tunnel. If necessary, ClientConnected
  2667  	// will stop packet tunnel workers for any previous packet tunnel channel.
  2668  
  2669  	checkAllowedTCPPortFunc := func(upstreamIPAddress net.IP, port int) bool {
  2670  		return sshClient.isPortForwardPermitted(portForwardTypeTCP, upstreamIPAddress, port)
  2671  	}
  2672  
  2673  	checkAllowedUDPPortFunc := func(upstreamIPAddress net.IP, port int) bool {
  2674  		return sshClient.isPortForwardPermitted(portForwardTypeUDP, upstreamIPAddress, port)
  2675  	}
  2676  
  2677  	checkAllowedDomainFunc := func(domain string) bool {
  2678  		ok, _ := sshClient.isDomainPermitted(domain)
  2679  		return ok
  2680  	}
  2681  
  2682  	flowActivityUpdaterMaker := func(
  2683  		isTCP bool, upstreamHostname string, upstreamIPAddress net.IP) []tun.FlowActivityUpdater {
  2684  
  2685  		trafficType := portForwardTypeTCP
  2686  		if !isTCP {
  2687  			trafficType = portForwardTypeUDP
  2688  		}
  2689  
  2690  		activityUpdaters := sshClient.getActivityUpdaters(trafficType, upstreamIPAddress)
  2691  
  2692  		flowUpdaters := make([]tun.FlowActivityUpdater, len(activityUpdaters))
  2693  		for i, activityUpdater := range activityUpdaters {
  2694  			flowUpdaters[i] = activityUpdater
  2695  		}
  2696  
  2697  		return flowUpdaters
  2698  	}
  2699  
  2700  	metricUpdater := func(
  2701  		TCPApplicationBytesDown, TCPApplicationBytesUp,
  2702  		UDPApplicationBytesDown, UDPApplicationBytesUp int64) {
  2703  
  2704  		sshClient.Lock()
  2705  		sshClient.tcpTrafficState.bytesDown += TCPApplicationBytesDown
  2706  		sshClient.tcpTrafficState.bytesUp += TCPApplicationBytesUp
  2707  		sshClient.udpTrafficState.bytesDown += UDPApplicationBytesDown
  2708  		sshClient.udpTrafficState.bytesUp += UDPApplicationBytesUp
  2709  		sshClient.Unlock()
  2710  	}
  2711  
  2712  	dnsQualityReporter := sshClient.updateQualityMetricsWithDNSResult
  2713  
  2714  	err = sshClient.sshServer.support.PacketTunnelServer.ClientConnected(
  2715  		sshClient.sessionID,
  2716  		packetTunnelChannel,
  2717  		checkAllowedTCPPortFunc,
  2718  		checkAllowedUDPPortFunc,
  2719  		checkAllowedDomainFunc,
  2720  		flowActivityUpdaterMaker,
  2721  		metricUpdater,
  2722  		dnsQualityReporter)
  2723  	if err != nil {
  2724  		log.WithTraceFields(LogFields{"error": err}).Warning("start packet tunnel client failed")
  2725  		sshClient.setPacketTunnelChannel(nil)
  2726  	}
  2727  }
  2728  
  2729  func (sshClient *sshClient) handleNewTCPPortForwardChannel(
  2730  	waitGroup *sync.WaitGroup,
  2731  	newChannel ssh.NewChannel,
  2732  	allowSplitTunnel bool,
  2733  	newTCPPortForwards chan *newTCPPortForward) {
  2734  
  2735  	// udpgw client connections are dispatched immediately (clients use this for
  2736  	// DNS, so it's essential to not block; and only one udpgw connection is
  2737  	// retained at a time).
  2738  	//
  2739  	// All other TCP port forwards are dispatched via the TCP port forward
  2740  	// manager queue.
  2741  
  2742  	// http://tools.ietf.org/html/rfc4254#section-7.2
  2743  	var directTcpipExtraData struct {
  2744  		HostToConnect       string
  2745  		PortToConnect       uint32
  2746  		OriginatorIPAddress string
  2747  		OriginatorPort      uint32
  2748  	}
  2749  
  2750  	err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData)
  2751  	if err != nil {
  2752  		sshClient.rejectNewChannel(newChannel, "invalid extra data")
  2753  		return
  2754  	}
  2755  
  2756  	// Intercept TCP port forwards to a specified udpgw server and handle directly.
  2757  	// TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type?
  2758  	isUdpgwChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" &&
  2759  		sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress ==
  2760  			net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect)))
  2761  
  2762  	if isUdpgwChannel {
  2763  
  2764  		// Dispatch immediately. handleUDPChannel runs the udpgw protocol in its
  2765  		// own worker goroutine.
  2766  
  2767  		waitGroup.Add(1)
  2768  		go func(channel ssh.NewChannel) {
  2769  			defer waitGroup.Done()
  2770  			sshClient.handleUdpgwChannel(channel)
  2771  		}(newChannel)
  2772  
  2773  	} else {
  2774  
  2775  		// Dispatch via TCP port forward manager. When the queue is full, the channel
  2776  		// is immediately rejected.
  2777  		//
  2778  		// Split tunnel logic is enabled for this TCP port forward when the client
  2779  		// has enabled split tunnel mode and the channel type allows it.
  2780  
  2781  		doSplitTunnel := sshClient.handshakeState.splitTunnelLookup != nil && allowSplitTunnel
  2782  
  2783  		tcpPortForward := &newTCPPortForward{
  2784  			enqueueTime:   time.Now(),
  2785  			hostToConnect: directTcpipExtraData.HostToConnect,
  2786  			portToConnect: int(directTcpipExtraData.PortToConnect),
  2787  			doSplitTunnel: doSplitTunnel,
  2788  			newChannel:    newChannel,
  2789  		}
  2790  
  2791  		select {
  2792  		case newTCPPortForwards <- tcpPortForward:
  2793  		default:
  2794  			sshClient.updateQualityMetricsWithRejectedDialingLimit()
  2795  			sshClient.rejectNewChannel(newChannel, "TCP port forward dial queue full")
  2796  		}
  2797  	}
  2798  }
  2799  
  2800  func (sshClient *sshClient) cleanupAuthorizations() {
  2801  	sshClient.Lock()
  2802  
  2803  	if sshClient.releaseAuthorizations != nil {
  2804  		sshClient.releaseAuthorizations()
  2805  	}
  2806  
  2807  	if sshClient.stopTimer != nil {
  2808  		sshClient.stopTimer.Stop()
  2809  	}
  2810  
  2811  	sshClient.Unlock()
  2812  }
  2813  
  2814  // setPacketTunnelChannel sets the single packet tunnel channel
  2815  // for this sshClient. Any existing packet tunnel channel is
  2816  // closed.
  2817  func (sshClient *sshClient) setPacketTunnelChannel(channel ssh.Channel) {
  2818  	sshClient.Lock()
  2819  	if sshClient.packetTunnelChannel != nil {
  2820  		sshClient.packetTunnelChannel.Close()
  2821  	}
  2822  	sshClient.packetTunnelChannel = channel
  2823  	sshClient.totalPacketTunnelChannelCount += 1
  2824  	sshClient.Unlock()
  2825  }
  2826  
// setUdpgwChannelHandler sets the single udpgw channel handler for this
// sshClient. Each sshClient may have only one concurrent udpgw
// channel/handler. Each udpgw channel multiplexes many UDP port forwards via
// the udpgw protocol. Any existing udpgw channel/handler is closed.
//
// Returns true when udpgwChannelHandler was installed; it is then counted
// in totalUdpgwChannelCount. Returns false when another caller installed a
// handler first, which can happen while the previous handler was being
// stopped with the mutex released. In that case udpgwChannelHandler was not
// installed and the caller should discard it.
func (sshClient *sshClient) setUdpgwChannelHandler(udpgwChannelHandler *udpgwPortForwardMultiplexer) bool {
	sshClient.Lock()
	if sshClient.udpgwChannelHandler != nil {
		previousHandler := sshClient.udpgwChannelHandler
		sshClient.udpgwChannelHandler = nil

		// stop must be run without holding the sshClient mutex lock, as the
		// udpgw goroutines may attempt to lock the same mutex. For example,
		// udpgwPortForwardMultiplexer.run calls sshClient.establishedPortForward
		// which calls sshClient.allocatePortForward.
		sshClient.Unlock()
		previousHandler.stop()
		sshClient.Lock()

		// In case some other channel has set the sshClient.udpgwChannelHandler
		// in the meantime, fail. The caller should discard this channel/handler.
		if sshClient.udpgwChannelHandler != nil {
			sshClient.Unlock()
			return false
		}
	}
	sshClient.udpgwChannelHandler = udpgwChannelHandler
	sshClient.totalUdpgwChannelCount += 1
	sshClient.Unlock()
	return true
}
  2857  
  2858  var serverTunnelStatParams = append(
  2859  	[]requestParamSpec{
  2860  		{"last_connected", isLastConnected, requestParamOptional},
  2861  		{"establishment_duration", isIntString, requestParamOptional}},
  2862  	baseSessionAndDialParams...)
  2863  
  2864  func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) {
  2865  
  2866  	sshClient.Lock()
  2867  
  2868  	logFields := getRequestLogFields(
  2869  		"server_tunnel",
  2870  		sshClient.geoIPData,
  2871  		sshClient.handshakeState.authorizedAccessTypes,
  2872  		sshClient.handshakeState.apiParams,
  2873  		serverTunnelStatParams)
  2874  
  2875  	// "relay_protocol" is sent with handshake API parameters. In pre-
  2876  	// handshake logTunnel cases, this value is not yet known. As
  2877  	// sshClient.tunnelProtocol is authoritative, set this value
  2878  	// unconditionally, overwriting any value from handshake.
  2879  	logFields["relay_protocol"] = sshClient.tunnelProtocol
  2880  
  2881  	if sshClient.serverPacketManipulation != "" {
  2882  		logFields["server_packet_manipulation"] = sshClient.serverPacketManipulation
  2883  	}
  2884  	if sshClient.sshListener.BPFProgramName != "" {
  2885  		logFields["server_bpf"] = sshClient.sshListener.BPFProgramName
  2886  	}
  2887  	logFields["session_id"] = sshClient.sessionID
  2888  	logFields["is_first_tunnel_in_session"] = sshClient.isFirstTunnelInSession
  2889  	logFields["handshake_completed"] = sshClient.handshakeState.completed
  2890  	logFields["bytes_up_tcp"] = sshClient.tcpTrafficState.bytesUp
  2891  	logFields["bytes_down_tcp"] = sshClient.tcpTrafficState.bytesDown
  2892  	logFields["peak_concurrent_dialing_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentDialingPortForwardCount
  2893  	logFields["peak_concurrent_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentPortForwardCount
  2894  	logFields["total_port_forward_count_tcp"] = sshClient.tcpTrafficState.totalPortForwardCount
  2895  	logFields["bytes_up_udp"] = sshClient.udpTrafficState.bytesUp
  2896  	logFields["bytes_down_udp"] = sshClient.udpTrafficState.bytesDown
  2897  	// sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful
  2898  	logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount
  2899  	logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount
  2900  	logFields["total_udpgw_channel_count"] = sshClient.totalUdpgwChannelCount
  2901  	logFields["total_packet_tunnel_channel_count"] = sshClient.totalPacketTunnelChannelCount
  2902  
  2903  	logFields["pre_handshake_random_stream_count"] = sshClient.preHandshakeRandomStreamMetrics.count
  2904  	logFields["pre_handshake_random_stream_upstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.upstreamBytes
  2905  	logFields["pre_handshake_random_stream_received_upstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.receivedUpstreamBytes
  2906  	logFields["pre_handshake_random_stream_downstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.downstreamBytes
  2907  	logFields["pre_handshake_random_stream_sent_downstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.sentDownstreamBytes
  2908  	logFields["random_stream_count"] = sshClient.postHandshakeRandomStreamMetrics.count
  2909  	logFields["random_stream_upstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.upstreamBytes
  2910  	logFields["random_stream_received_upstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.receivedUpstreamBytes
  2911  	logFields["random_stream_downstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.downstreamBytes
  2912  	logFields["random_stream_sent_downstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.sentDownstreamBytes
  2913  
  2914  	if sshClient.destinationBytesMetricsASN != "" {
  2915  
  2916  		// Check if the configured DestinationBytesMetricsASN has changed
  2917  		// (or been cleared). If so, don't log and discard the accumulated
  2918  		// bytes to ensure we don't continue to record stats as previously
  2919  		// configured.
  2920  		//
  2921  		// Any counts accumulated before the DestinationBytesMetricsASN change
  2922  		// are lost. At this time we can't change
  2923  		// sshClient.destinationBytesMetricsASN dynamically, after a tactics
  2924  		// hot reload, as there may be destination bytes port forwards that
  2925  		// were in place before the change, which will continue to count.
  2926  
  2927  		logDestBytes := true
  2928  		if sshClient.sshServer.support.ServerTacticsParametersCache != nil {
  2929  			p, err := sshClient.sshServer.support.ServerTacticsParametersCache.Get(sshClient.geoIPData)
  2930  			if err != nil || p.IsNil() ||
  2931  				sshClient.destinationBytesMetricsASN != p.String(parameters.DestinationBytesMetricsASN) {
  2932  				logDestBytes = false
  2933  			}
  2934  		}
  2935  
  2936  		if logDestBytes {
  2937  			bytesUpTCP := sshClient.tcpDestinationBytesMetrics.getBytesUp()
  2938  			bytesDownTCP := sshClient.tcpDestinationBytesMetrics.getBytesDown()
  2939  			bytesUpUDP := sshClient.udpDestinationBytesMetrics.getBytesUp()
  2940  			bytesDownUDP := sshClient.udpDestinationBytesMetrics.getBytesDown()
  2941  
  2942  			logFields["dest_bytes_asn"] = sshClient.destinationBytesMetricsASN
  2943  			logFields["dest_bytes_up_tcp"] = bytesUpTCP
  2944  			logFields["dest_bytes_down_tcp"] = bytesDownTCP
  2945  			logFields["dest_bytes_up_udp"] = bytesUpUDP
  2946  			logFields["dest_bytes_down_udp"] = bytesDownUDP
  2947  			logFields["dest_bytes"] = bytesUpTCP + bytesDownTCP + bytesUpUDP + bytesDownUDP
  2948  		}
  2949  	}
  2950  
  2951  	// Only log fields for peakMetrics when there is data recorded, otherwise
  2952  	// omit the field.
  2953  	if sshClient.peakMetrics.concurrentProximateAcceptedClients != nil {
  2954  		logFields["peak_concurrent_proximate_accepted_clients"] = *sshClient.peakMetrics.concurrentProximateAcceptedClients
  2955  	}
  2956  	if sshClient.peakMetrics.concurrentProximateEstablishedClients != nil {
  2957  		logFields["peak_concurrent_proximate_established_clients"] = *sshClient.peakMetrics.concurrentProximateEstablishedClients
  2958  	}
  2959  	if sshClient.peakMetrics.TCPPortForwardFailureRate != nil && sshClient.peakMetrics.TCPPortForwardFailureRateSampleSize != nil {
  2960  		logFields["peak_tcp_port_forward_failure_rate"] = *sshClient.peakMetrics.TCPPortForwardFailureRate
  2961  		logFields["peak_tcp_port_forward_failure_rate_sample_size"] = *sshClient.peakMetrics.TCPPortForwardFailureRateSampleSize
  2962  	}
  2963  	if sshClient.peakMetrics.DNSFailureRate != nil && sshClient.peakMetrics.DNSFailureRateSampleSize != nil {
  2964  		logFields["peak_dns_failure_rate"] = *sshClient.peakMetrics.DNSFailureRate
  2965  		logFields["peak_dns_failure_rate_sample_size"] = *sshClient.peakMetrics.DNSFailureRateSampleSize
  2966  	}
  2967  
  2968  	// Pre-calculate a total-tunneled-bytes field. This total is used
  2969  	// extensively in analytics and is more performant when pre-calculated.
  2970  	logFields["bytes"] = sshClient.tcpTrafficState.bytesUp +
  2971  		sshClient.tcpTrafficState.bytesDown +
  2972  		sshClient.udpTrafficState.bytesUp +
  2973  		sshClient.udpTrafficState.bytesDown
  2974  
  2975  	// Merge in additional metrics from the optional metrics source
  2976  	for _, metrics := range additionalMetrics {
  2977  		for name, value := range metrics {
  2978  			// Don't overwrite any basic fields
  2979  			if logFields[name] == nil {
  2980  				logFields[name] = value
  2981  			}
  2982  		}
  2983  	}
  2984  
  2985  	// Retain lock when invoking LogRawFieldsWithTimestamp to block any
  2986  	// concurrent writes to variables referenced by logFields.
  2987  	log.LogRawFieldsWithTimestamp(logFields)
  2988  
  2989  	sshClient.Unlock()
  2990  }
  2991  
// blocklistHitsStatParams lists the handshake API parameters that
// logBlocklistHits copies, via getRequestLogFields, into each
// "server_blocklist_hit" log record. Each entry names a parameter, its
// validator, and its requestParam* flags. Note that logBlocklistHits
// overwrites the "session_id" field with sshClient.sessionID after
// getRequestLogFields returns.
var blocklistHitsStatParams = []requestParamSpec{
	{"propagation_channel_id", isHexDigits, 0},
	{"sponsor_id", isHexDigits, 0},
	{"client_version", isIntString, requestParamLogStringAsInt},
	{"client_platform", isClientPlatform, 0},
	{"client_features", isAnyString, requestParamOptional | requestParamArray},
	{"client_build_rev", isHexDigits, requestParamOptional},
	{"device_region", isAnyString, requestParamOptional},
	{"egress_region", isRegionCode, requestParamOptional},
	{"session_id", isHexDigits, 0},
	{"last_connected", isLastConnected, requestParamOptional},
}
  3004  
  3005  func (sshClient *sshClient) logBlocklistHits(IP net.IP, domain string, tags []BlocklistTag) {
  3006  
  3007  	sshClient.Lock()
  3008  
  3009  	logFields := getRequestLogFields(
  3010  		"server_blocklist_hit",
  3011  		sshClient.geoIPData,
  3012  		sshClient.handshakeState.authorizedAccessTypes,
  3013  		sshClient.handshakeState.apiParams,
  3014  		blocklistHitsStatParams)
  3015  
  3016  	logFields["session_id"] = sshClient.sessionID
  3017  
  3018  	// Note: see comment in logTunnel regarding unlock and concurrent access.
  3019  
  3020  	sshClient.Unlock()
  3021  
  3022  	for _, tag := range tags {
  3023  		if IP != nil {
  3024  			logFields["blocklist_ip_address"] = IP.String()
  3025  		}
  3026  		if domain != "" {
  3027  			logFields["blocklist_domain"] = domain
  3028  		}
  3029  		logFields["blocklist_source"] = tag.Source
  3030  		logFields["blocklist_subject"] = tag.Subject
  3031  
  3032  		log.LogRawFieldsWithTimestamp(logFields)
  3033  	}
  3034  }
  3035  
  3036  func (sshClient *sshClient) runOSLSender() {
  3037  
  3038  	for {
  3039  		// Await a signal that there are SLOKs to send
  3040  		// TODO: use reflect.SelectCase, and optionally await timer here?
  3041  		select {
  3042  		case <-sshClient.signalIssueSLOKs:
  3043  		case <-sshClient.runCtx.Done():
  3044  			return
  3045  		}
  3046  
  3047  		retryDelay := SSH_SEND_OSL_INITIAL_RETRY_DELAY
  3048  		for {
  3049  			err := sshClient.sendOSLRequest()
  3050  			if err == nil {
  3051  				break
  3052  			}
  3053  			if !isExpectedTunnelIOError(err) {
  3054  				log.WithTraceFields(LogFields{"error": err}).Warning("sendOSLRequest failed")
  3055  			}
  3056  
  3057  			// If the request failed, retry after a delay (with exponential backoff)
  3058  			// or when signaled that there are additional SLOKs to send
  3059  			retryTimer := time.NewTimer(retryDelay)
  3060  			select {
  3061  			case <-retryTimer.C:
  3062  			case <-sshClient.signalIssueSLOKs:
  3063  			case <-sshClient.runCtx.Done():
  3064  				retryTimer.Stop()
  3065  				return
  3066  			}
  3067  			retryTimer.Stop()
  3068  			retryDelay *= SSH_SEND_OSL_RETRY_FACTOR
  3069  		}
  3070  	}
  3071  }
  3072  
  3073  // sendOSLRequest will invoke osl.GetSeedPayload to issue SLOKs and
  3074  // generate a payload, and send an OSL request to the client when
  3075  // there are new SLOKs in the payload.
  3076  func (sshClient *sshClient) sendOSLRequest() error {
  3077  
  3078  	seedPayload := sshClient.getOSLSeedPayload()
  3079  
  3080  	// Don't send when no SLOKs. This will happen when signalIssueSLOKs
  3081  	// is received but no new SLOKs are issued.
  3082  	if len(seedPayload.SLOKs) == 0 {
  3083  		return nil
  3084  	}
  3085  
  3086  	oslRequest := protocol.OSLRequest{
  3087  		SeedPayload: seedPayload,
  3088  	}
  3089  	requestPayload, err := json.Marshal(oslRequest)
  3090  	if err != nil {
  3091  		return errors.Trace(err)
  3092  	}
  3093  
  3094  	ok, _, err := sshClient.sshConn.SendRequest(
  3095  		protocol.PSIPHON_API_OSL_REQUEST_NAME,
  3096  		true,
  3097  		requestPayload)
  3098  	if err != nil {
  3099  		return errors.Trace(err)
  3100  	}
  3101  	if !ok {
  3102  		return errors.TraceNew("client rejected request")
  3103  	}
  3104  
  3105  	sshClient.clearOSLSeedPayload()
  3106  
  3107  	return nil
  3108  }
  3109  
  3110  // runAlertSender dequeues and sends alert requests to the client. As these
  3111  // alerts are informational, there is no retry logic and no SSH client
  3112  // acknowledgement (wantReply) is requested. This worker scheme allows
  3113  // nonconcurrent components including udpgw and packet tunnel to enqueue
  3114  // alerts without blocking their traffic processing.
  3115  func (sshClient *sshClient) runAlertSender() {
  3116  	for {
  3117  		select {
  3118  		case <-sshClient.runCtx.Done():
  3119  			return
  3120  
  3121  		case request := <-sshClient.sendAlertRequests:
  3122  			payload, err := json.Marshal(request)
  3123  			if err != nil {
  3124  				log.WithTraceFields(LogFields{"error": err}).Warning("Marshal failed")
  3125  				break
  3126  			}
  3127  			_, _, err = sshClient.sshConn.SendRequest(
  3128  				protocol.PSIPHON_API_ALERT_REQUEST_NAME,
  3129  				false,
  3130  				payload)
  3131  			if err != nil && !isExpectedTunnelIOError(err) {
  3132  				log.WithTraceFields(LogFields{"error": err}).Warning("SendRequest failed")
  3133  				break
  3134  			}
  3135  			sshClient.Lock()
  3136  			sshClient.sentAlertRequests[fmt.Sprintf("%+v", request)] = true
  3137  			sshClient.Unlock()
  3138  		}
  3139  	}
  3140  }
  3141  
  3142  // enqueueAlertRequest enqueues an alert request to be sent to the client.
  3143  // Only one request is sent per tunnel per protocol.AlertRequest value;
  3144  // subsequent alerts with the same value are dropped. enqueueAlertRequest will
  3145  // not block until the queue exceeds ALERT_REQUEST_QUEUE_BUFFER_SIZE.
  3146  func (sshClient *sshClient) enqueueAlertRequest(request protocol.AlertRequest) {
  3147  	sshClient.Lock()
  3148  	if sshClient.sentAlertRequests[fmt.Sprintf("%+v", request)] {
  3149  		sshClient.Unlock()
  3150  		return
  3151  	}
  3152  	sshClient.Unlock()
  3153  	select {
  3154  	case <-sshClient.runCtx.Done():
  3155  	case sshClient.sendAlertRequests <- request:
  3156  	}
  3157  }
  3158  
  3159  func (sshClient *sshClient) enqueueDisallowedTrafficAlertRequest() {
  3160  
  3161  	reason := protocol.PSIPHON_API_ALERT_DISALLOWED_TRAFFIC
  3162  	actionURLs := sshClient.getAlertActionURLs(reason)
  3163  
  3164  	sshClient.enqueueAlertRequest(
  3165  		protocol.AlertRequest{
  3166  			Reason:     reason,
  3167  			ActionURLs: actionURLs,
  3168  		})
  3169  }
  3170  
  3171  func (sshClient *sshClient) enqueueUnsafeTrafficAlertRequest(tags []BlocklistTag) {
  3172  
  3173  	reason := protocol.PSIPHON_API_ALERT_UNSAFE_TRAFFIC
  3174  	actionURLs := sshClient.getAlertActionURLs(reason)
  3175  
  3176  	for _, tag := range tags {
  3177  		sshClient.enqueueAlertRequest(
  3178  			protocol.AlertRequest{
  3179  				Reason:     reason,
  3180  				Subject:    tag.Subject,
  3181  				ActionURLs: actionURLs,
  3182  			})
  3183  	}
  3184  }
  3185  
  3186  func (sshClient *sshClient) getAlertActionURLs(alertReason string) []string {
  3187  
  3188  	sshClient.Lock()
  3189  	sponsorID, _ := getStringRequestParam(
  3190  		sshClient.handshakeState.apiParams, "sponsor_id")
  3191  	sshClient.Unlock()
  3192  
  3193  	return sshClient.sshServer.support.PsinetDatabase.GetAlertActionURLs(
  3194  		alertReason,
  3195  		sponsorID,
  3196  		sshClient.geoIPData.Country,
  3197  		sshClient.geoIPData.ASN)
  3198  }
  3199  
  3200  func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, logMessage string) {
  3201  
  3202  	// We always return the reject reason "Prohibited":
  3203  	// - Traffic rules and connection limits may prohibit the connection.
  3204  	// - External firewall rules may prohibit the connection, and this is not currently
  3205  	//   distinguishable from other failure modes.
  3206  	// - We limit the failure information revealed to the client.
  3207  	reason := ssh.Prohibited
  3208  
  3209  	// Note: Debug level, as logMessage may contain user traffic destination address information
  3210  	log.WithTraceFields(
  3211  		LogFields{
  3212  			"channelType":  newChannel.ChannelType(),
  3213  			"logMessage":   logMessage,
  3214  			"rejectReason": reason.String(),
  3215  		}).Debug("reject new channel")
  3216  
  3217  	// Note: logMessage is internal, for logging only; just the reject reason is sent to the client.
  3218  	newChannel.Reject(reason, reason.String())
  3219  }
  3220  
// setHandshakeState records that a client has completed a handshake API request.
// Some parameters from the handshake request may be used in future traffic rule
// selection. Port forwards are disallowed until a handshake is complete. The
// handshake parameters are included in the session summary log recorded in
// sshClient.stop().
//
// state is installed as sshClient.handshakeState only if no handshake has
// been recorded yet. Otherwise a "handshake already completed" error is
// returned and nothing is changed.
//
// authorizations are verified against the access control key ring:
//   - at most MAX_AUTHORIZATIONS are examined;
//   - authorizations that fail verification are logged and skipped;
//   - only the first verified authorization per access type is kept.
//
// The returned handshakeStateInfo carries the kept authorization IDs and
// access types, plus the upstream/downstream bytes-per-second limits from
// the traffic rules selected here.
//
// Other side effects: the traffic rules, OSL config, and destination bytes
// metrics are (re)initialized. When any authorization is kept, a
// releaseAuthorizations callback is installed and a stopTimer is armed to
// stop the client at the earliest expiry among the kept authorizations.
func (sshClient *sshClient) setHandshakeState(
	state handshakeState,
	authorizations []string) (*handshakeStateInfo, error) {

	// Check and set under one lock acquisition so that only the first of
	// any concurrent callers installs its state.
	sshClient.Lock()
	completed := sshClient.handshakeState.completed
	if !completed {
		sshClient.handshakeState = state
	}
	sshClient.Unlock()

	// Client must only perform one handshake
	if completed {
		return nil, errors.TraceNew("handshake already completed")
	}

	// Verify the authorizations submitted by the client. Verified, active
	// (non-expired) access types will be available for traffic rules
	// filtering.
	//
	// When an authorization is active but expires while the client is
	// connected, the client is disconnected to ensure the access is reset.
	// This is implemented by setting a timer to perform the disconnect at the
	// expiry time of the soonest expiring authorization.
	//
	// sshServer.authorizationSessionIDs tracks the unique mapping of active
	// authorization IDs to client session IDs  and is used to detect and
	// prevent multiple malicious clients from reusing a single authorization
	// (within the scope of this server).

	// authorizationIDs and authorizedAccessTypes are returned to the client
	// and logged, respectively; initialize to empty lists so the
	// protocol/logs don't need to handle 'null' values.
	authorizationIDs := make([]string, 0)
	authorizedAccessTypes := make([]string, 0)
	// stopTime tracks the earliest Expires among the kept authorizations;
	// the zero value means "none kept yet".
	var stopTime time.Time

	for i, authorization := range authorizations {

		// This sanity check mitigates malicious clients causing excess CPU use.
		if i >= MAX_AUTHORIZATIONS {
			log.WithTrace().Warning("too many authorizations")
			break
		}

		verifiedAuthorization, err := accesscontrol.VerifyAuthorization(
			&sshClient.sshServer.support.Config.AccessControlVerificationKeyRing,
			authorization)

		if err != nil {
			log.WithTraceFields(
				LogFields{"error": err}).Warning("verify authorization failed")
			continue
		}

		authorizationID := base64.StdEncoding.EncodeToString(verifiedAuthorization.ID)

		// Keep at most one authorization per access type.
		if common.Contains(authorizedAccessTypes, verifiedAuthorization.AccessType) {
			log.WithTraceFields(
				LogFields{"accessType": verifiedAuthorization.AccessType}).Warning("duplicate authorization access type")
			continue
		}

		authorizationIDs = append(authorizationIDs, authorizationID)
		authorizedAccessTypes = append(authorizedAccessTypes, verifiedAuthorization.AccessType)

		if stopTime.IsZero() || stopTime.After(verifiedAuthorization.Expires) {
			stopTime = verifiedAuthorization.Expires
		}
	}

	// Associate all verified authorizationIDs with this client's session ID.
	// Handle cases where previous associations exist:
	//
	// - Multiple malicious clients reusing a single authorization. In this
	//   case, authorizations are revoked from the previous client.
	//
	// - The client reconnected with a new session ID due to user toggling.
	//   This case is expected due to server affinity. This cannot be
	//   distinguished from the previous case and the same action is taken;
	//   this will have no impact on a legitimate client as the previous
	//   session is dangling.
	//
	// - The client automatically reconnected with the same session ID. This
	//   case is not expected as sshServer.registerEstablishedClient
	//   synchronously calls sshClient.releaseAuthorizations; as a safe guard,
	//   this case is distinguished and no revocation action is taken.

	sshClient.sshServer.authorizationSessionIDsMutex.Lock()
	for _, authorizationID := range authorizationIDs {
		sessionID, ok := sshClient.sshServer.authorizationSessionIDs[authorizationID]
		if ok && sessionID != sshClient.sessionID {

			logFields := LogFields{
				"event_name":                 "irregular_tunnel",
				"tunnel_error":               "duplicate active authorization",
				"duplicate_authorization_id": authorizationID,
			}
			sshClient.geoIPData.SetLogFields(logFields)
			// Add the other session's GeoIP fields, prefixed, only when they
			// differ from this client's.
			duplicateGeoIPData := sshClient.sshServer.support.GeoIPService.GetSessionCache(sessionID)
			if duplicateGeoIPData != sshClient.geoIPData {
				duplicateGeoIPData.SetLogFieldsWithPrefix("duplicate_authorization_", logFields)
			}
			log.LogRawFieldsWithTimestamp(logFields)

			// Invoke asynchronously to avoid deadlocks.
			// TODO: invoke only once for each distinct sessionID?
			go sshClient.sshServer.revokeClientAuthorizations(sessionID)
		}
		// Whatever the previous association, this session now owns the ID.
		sshClient.sshServer.authorizationSessionIDs[authorizationID] = sshClient.sessionID
	}
	sshClient.sshServer.authorizationSessionIDsMutex.Unlock()

	if len(authorizationIDs) > 0 {

		sshClient.Lock()

		// Make the authorizedAccessTypes available for traffic rules filtering.

		sshClient.handshakeState.activeAuthorizationIDs = authorizationIDs
		sshClient.handshakeState.authorizedAccessTypes = authorizedAccessTypes

		// On exit, sshClient.runTunnel will call releaseAuthorizations, which
		// will release the authorization IDs so the client can reconnect and
		// present the same authorizations again. sshClient.runTunnel will
		// also cancel the stopTimer in case it has not yet fired.
		// Note: termination of the stopTimer goroutine is not synchronized.

		sshClient.releaseAuthorizations = func() {
			sshClient.sshServer.authorizationSessionIDsMutex.Lock()
			for _, authorizationID := range authorizationIDs {
				sessionID, ok := sshClient.sshServer.authorizationSessionIDs[authorizationID]
				// Only release IDs still owned by this session; another
				// session may have taken them over in the meantime.
				if ok && sessionID == sshClient.sessionID {
					delete(sshClient.sshServer.authorizationSessionIDs, authorizationID)
				}
			}
			sshClient.sshServer.authorizationSessionIDsMutex.Unlock()
		}

		// Stop the client when the earliest kept authorization expires.
		sshClient.stopTimer = time.AfterFunc(
			time.Until(stopTime),
			func() {
				sshClient.stop()
			})

		sshClient.Unlock()
	}

	upstreamBytesPerSecond, downstreamBytesPerSecond := sshClient.setTrafficRules()

	sshClient.setOSLConfig()

	// Set destination bytes metrics.
	//
	// Limitation: this is a one-time operation and doesn't get reset when
	// tactics are hot-reloaded. This allows us to simply retain any
	// destination byte counts accumulated and eventually log in
	// server_tunnel, without having to deal with a destination change
	// mid-tunnel. As typical tunnels are short, and destination changes can
	// be applied gradually, handling mid-tunnel changes is not a priority.
	sshClient.setDestinationBytesMetrics()

	return &handshakeStateInfo{
		activeAuthorizationIDs:   authorizationIDs,
		authorizedAccessTypes:    authorizedAccessTypes,
		upstreamBytesPerSecond:   upstreamBytesPerSecond,
		downstreamBytesPerSecond: downstreamBytesPerSecond,
	}, nil
}
  3395  
  3396  // getHandshaked returns whether the client has completed a handshake API
  3397  // request and whether the traffic rules that were selected after the
  3398  // handshake immediately exhaust the client.
  3399  //
  3400  // When the client is immediately exhausted it will be closed; but this
  3401  // takes effect asynchronously. The "exhausted" return value is used to
  3402  // prevent API requests by clients that will close.
  3403  func (sshClient *sshClient) getHandshaked() (bool, bool) {
  3404  	sshClient.Lock()
  3405  	defer sshClient.Unlock()
  3406  
  3407  	completed := sshClient.handshakeState.completed
  3408  
  3409  	exhausted := false
  3410  
  3411  	// Notes:
  3412  	// - "Immediately exhausted" is when CloseAfterExhausted is set and
  3413  	//   either ReadUnthrottledBytes or WriteUnthrottledBytes starts from
  3414  	//   0, so no bytes would be read or written. This check does not
  3415  	//   examine whether 0 bytes _remain_ in the ThrottledConn.
  3416  	// - This check is made against the current traffic rules, which
  3417  	//   could have changed in a hot reload since the handshake.
  3418  
  3419  	if completed &&
  3420  		*sshClient.trafficRules.RateLimits.CloseAfterExhausted &&
  3421  		(*sshClient.trafficRules.RateLimits.ReadUnthrottledBytes == 0 ||
  3422  			*sshClient.trafficRules.RateLimits.WriteUnthrottledBytes == 0) {
  3423  
  3424  		exhausted = true
  3425  	}
  3426  
  3427  	return completed, exhausted
  3428  }
  3429  
  3430  func (sshClient *sshClient) getDisableDiscovery() bool {
  3431  	sshClient.Lock()
  3432  	defer sshClient.Unlock()
  3433  
  3434  	return *sshClient.trafficRules.DisableDiscovery
  3435  }
  3436  
  3437  func (sshClient *sshClient) updateAPIParameters(
  3438  	apiParams common.APIParameters) {
  3439  
  3440  	sshClient.Lock()
  3441  	defer sshClient.Unlock()
  3442  
  3443  	// Only update after handshake has initialized API params.
  3444  	if !sshClient.handshakeState.completed {
  3445  		return
  3446  	}
  3447  
  3448  	for name, value := range apiParams {
  3449  		sshClient.handshakeState.apiParams[name] = value
  3450  	}
  3451  }
  3452  
  3453  func (sshClient *sshClient) acceptDomainBytes() bool {
  3454  	sshClient.Lock()
  3455  	defer sshClient.Unlock()
  3456  
  3457  	// When the domain bytes checksum differs from the checksum sent to the
  3458  	// client in the handshake response, the psinet regex configuration has
  3459  	// changed. In this case, drop the stats so we don't continue to record
  3460  	// stats as previously configured.
  3461  	//
  3462  	// Limitations:
  3463  	// - The checksum comparison may result in dropping some stats for a
  3464  	//   domain that remains in the new configuration.
  3465  	// - We don't push new regexs to the clients, so clients that remain
  3466  	//   connected will continue to send stats that will be dropped; and
  3467  	//   those clients will not send stats as newly configured until after
  3468  	//   reconnecting.
  3469  	// - Due to the design of
  3470  	//   transferstats.ReportRecentBytesTransferredForServer in the client,
  3471  	//   the client may accumulate stats, reconnect before its next status
  3472  	//   request, get a new regex configuration, and then send the previously
  3473  	//   accumulated stats in its next status request. The checksum scheme
  3474  	//   won't prevent the reporting of those stats.
  3475  
  3476  	sponsorID, _ := getStringRequestParam(sshClient.handshakeState.apiParams, "sponsor_id")
  3477  
  3478  	domainBytesChecksum := sshClient.sshServer.support.PsinetDatabase.GetDomainBytesChecksum(sponsorID)
  3479  
  3480  	return bytes.Equal(sshClient.handshakeState.domainBytesChecksum, domainBytesChecksum)
  3481  }
  3482  
  3483  // setOSLConfig resets the client's OSL seed state based on the latest OSL config
  3484  // As sshClient.oslClientSeedState may be reset by a concurrent goroutine,
  3485  // oslClientSeedState must only be accessed within the sshClient mutex.
  3486  func (sshClient *sshClient) setOSLConfig() {
  3487  	sshClient.Lock()
  3488  	defer sshClient.Unlock()
  3489  
  3490  	propagationChannelID, err := getStringRequestParam(
  3491  		sshClient.handshakeState.apiParams, "propagation_channel_id")
  3492  	if err != nil {
  3493  		// This should not fail as long as client has sent valid handshake
  3494  		return
  3495  	}
  3496  
  3497  	// Use a cached seed state if one is found for the client's
  3498  	// session ID. This enables resuming progress made in a previous
  3499  	// tunnel.
  3500  	// Note: go-cache is already concurency safe; the additional mutex
  3501  	// is necessary to guarantee that Get/Delete is atomic; although in
  3502  	// practice no two concurrent clients should ever supply the same
  3503  	// session ID.
  3504  
  3505  	sshClient.sshServer.oslSessionCacheMutex.Lock()
  3506  	oslClientSeedState, found := sshClient.sshServer.oslSessionCache.Get(sshClient.sessionID)
  3507  	if found {
  3508  		sshClient.sshServer.oslSessionCache.Delete(sshClient.sessionID)
  3509  		sshClient.sshServer.oslSessionCacheMutex.Unlock()
  3510  		sshClient.oslClientSeedState = oslClientSeedState.(*osl.ClientSeedState)
  3511  		sshClient.oslClientSeedState.Resume(sshClient.signalIssueSLOKs)
  3512  		return
  3513  	}
  3514  	sshClient.sshServer.oslSessionCacheMutex.Unlock()
  3515  
  3516  	// Two limitations when setOSLConfig() is invoked due to an
  3517  	// OSL config hot reload:
  3518  	//
  3519  	// 1. any partial progress towards SLOKs is lost.
  3520  	//
  3521  	// 2. all existing osl.ClientSeedPortForwards for existing
  3522  	//    port forwards will not send progress to the new client
  3523  	//    seed state.
  3524  
  3525  	sshClient.oslClientSeedState = sshClient.sshServer.support.OSLConfig.NewClientSeedState(
  3526  		sshClient.geoIPData.Country,
  3527  		propagationChannelID,
  3528  		sshClient.signalIssueSLOKs)
  3529  }
  3530  
  3531  // newClientSeedPortForward will return nil when no seeding is
  3532  // associated with the specified ipAddress.
  3533  func (sshClient *sshClient) newClientSeedPortForward(IPAddress net.IP) *osl.ClientSeedPortForward {
  3534  	sshClient.Lock()
  3535  	defer sshClient.Unlock()
  3536  
  3537  	// Will not be initialized before handshake.
  3538  	if sshClient.oslClientSeedState == nil {
  3539  		return nil
  3540  	}
  3541  
  3542  	return sshClient.oslClientSeedState.NewClientSeedPortForward(IPAddress)
  3543  }
  3544  
  3545  // getOSLSeedPayload returns a payload containing all seeded SLOKs for
  3546  // this client's session.
  3547  func (sshClient *sshClient) getOSLSeedPayload() *osl.SeedPayload {
  3548  	sshClient.Lock()
  3549  	defer sshClient.Unlock()
  3550  
  3551  	// Will not be initialized before handshake.
  3552  	if sshClient.oslClientSeedState == nil {
  3553  		return &osl.SeedPayload{SLOKs: make([]*osl.SLOK, 0)}
  3554  	}
  3555  
  3556  	return sshClient.oslClientSeedState.GetSeedPayload()
  3557  }
  3558  
  3559  func (sshClient *sshClient) clearOSLSeedPayload() {
  3560  	sshClient.Lock()
  3561  	defer sshClient.Unlock()
  3562  
  3563  	sshClient.oslClientSeedState.ClearSeedPayload()
  3564  }
  3565  
  3566  func (sshClient *sshClient) setDestinationBytesMetrics() {
  3567  	sshClient.Lock()
  3568  	defer sshClient.Unlock()
  3569  
  3570  	// Limitation: the server-side tactics cache is used to avoid the overhead
  3571  	// of an additional tactics filtering per tunnel. As this cache is
  3572  	// designed for GeoIP filtering only, handshake API parameters are not
  3573  	// applied to tactics filtering in this case.
  3574  
  3575  	tacticsCache := sshClient.sshServer.support.ServerTacticsParametersCache
  3576  	if tacticsCache == nil {
  3577  		return
  3578  	}
  3579  
  3580  	p, err := tacticsCache.Get(sshClient.geoIPData)
  3581  	if err != nil {
  3582  		log.WithTraceFields(LogFields{"error": err}).Warning("get tactics failed")
  3583  		return
  3584  	}
  3585  	if p.IsNil() {
  3586  		return
  3587  	}
  3588  
  3589  	sshClient.destinationBytesMetricsASN = p.String(parameters.DestinationBytesMetricsASN)
  3590  }
  3591  
  3592  func (sshClient *sshClient) newDestinationBytesMetricsUpdater(portForwardType int, IPAddress net.IP) *destinationBytesMetrics {
  3593  	sshClient.Lock()
  3594  	defer sshClient.Unlock()
  3595  
  3596  	if sshClient.destinationBytesMetricsASN == "" {
  3597  		return nil
  3598  	}
  3599  
  3600  	if sshClient.sshServer.support.GeoIPService.LookupISPForIP(IPAddress).ASN != sshClient.destinationBytesMetricsASN {
  3601  		return nil
  3602  	}
  3603  
  3604  	if portForwardType == portForwardTypeTCP {
  3605  		return &sshClient.tcpDestinationBytesMetrics
  3606  	}
  3607  
  3608  	return &sshClient.udpDestinationBytesMetrics
  3609  }
  3610  
  3611  func (sshClient *sshClient) getActivityUpdaters(portForwardType int, IPAddress net.IP) []common.ActivityUpdater {
  3612  	var updaters []common.ActivityUpdater
  3613  
  3614  	clientSeedPortForward := sshClient.newClientSeedPortForward(IPAddress)
  3615  	if clientSeedPortForward != nil {
  3616  		updaters = append(updaters, clientSeedPortForward)
  3617  	}
  3618  
  3619  	destinationBytesMetrics := sshClient.newDestinationBytesMetricsUpdater(portForwardType, IPAddress)
  3620  	if destinationBytesMetrics != nil {
  3621  		updaters = append(updaters, destinationBytesMetrics)
  3622  	}
  3623  
  3624  	return updaters
  3625  }
  3626  
  3627  // setTrafficRules resets the client's traffic rules based on the latest server config
  3628  // and client properties. As sshClient.trafficRules may be reset by a concurrent
  3629  // goroutine, trafficRules must only be accessed within the sshClient mutex.
  3630  func (sshClient *sshClient) setTrafficRules() (int64, int64) {
  3631  	sshClient.Lock()
  3632  	defer sshClient.Unlock()
  3633  
  3634  	isFirstTunnelInSession := sshClient.isFirstTunnelInSession &&
  3635  		sshClient.handshakeState.establishedTunnelsCount == 0
  3636  
  3637  	sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules(
  3638  		isFirstTunnelInSession,
  3639  		sshClient.tunnelProtocol,
  3640  		sshClient.geoIPData,
  3641  		sshClient.handshakeState)
  3642  
  3643  	if sshClient.throttledConn != nil {
  3644  		// Any existing throttling state is reset.
  3645  		sshClient.throttledConn.SetLimits(
  3646  			sshClient.trafficRules.RateLimits.CommonRateLimits(
  3647  				sshClient.handshakeState.completed))
  3648  	}
  3649  
  3650  	return *sshClient.trafficRules.RateLimits.ReadBytesPerSecond,
  3651  		*sshClient.trafficRules.RateLimits.WriteBytesPerSecond
  3652  }
  3653  
  3654  func (sshClient *sshClient) rateLimits() common.RateLimits {
  3655  	sshClient.Lock()
  3656  	defer sshClient.Unlock()
  3657  
  3658  	return sshClient.trafficRules.RateLimits.CommonRateLimits(
  3659  		sshClient.handshakeState.completed)
  3660  }
  3661  
  3662  func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration {
  3663  	sshClient.Lock()
  3664  	defer sshClient.Unlock()
  3665  
  3666  	return time.Duration(*sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds) * time.Millisecond
  3667  }
  3668  
  3669  func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration {
  3670  	sshClient.Lock()
  3671  	defer sshClient.Unlock()
  3672  
  3673  	return time.Duration(*sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds) * time.Millisecond
  3674  }
  3675  
  3676  func (sshClient *sshClient) setTCPPortForwardDialingAvailableSignal(signal context.CancelFunc) {
  3677  	sshClient.Lock()
  3678  	defer sshClient.Unlock()
  3679  
  3680  	sshClient.tcpPortForwardDialingAvailableSignal = signal
  3681  }
  3682  
// Port forward type identifiers. These select TCP versus UDP handling, for
// example in isPortForwardPermitted and newDestinationBytesMetricsUpdater.
const (
	portForwardTypeTCP = iota
	portForwardTypeUDP
)
  3687  
  3688  func (sshClient *sshClient) isPortForwardPermitted(
  3689  	portForwardType int,
  3690  	remoteIP net.IP,
  3691  	port int) bool {
  3692  
  3693  	// Disallow connection to bogons.
  3694  	//
  3695  	// As a security measure, this is a failsafe. The server should be run on a
  3696  	// host with correctly configured firewall rules.
  3697  	//
  3698  	// This check also avoids spurious disallowed traffic alerts for destinations
  3699  	// that are impossible to reach.
  3700  
  3701  	if !sshClient.sshServer.support.Config.AllowBogons && common.IsBogon(remoteIP) {
  3702  		return false
  3703  	}
  3704  
  3705  	// Blocklist check.
  3706  	//
  3707  	// Limitation: isPortForwardPermitted is not called in transparent DNS
  3708  	// forwarding cases. As the destination IP address is rewritten in these
  3709  	// cases, a blocklist entry won't be dialed in any case. However, no logs
  3710  	// will be recorded.
  3711  
  3712  	if !sshClient.isIPPermitted(remoteIP) {
  3713  		return false
  3714  	}
  3715  
  3716  	// Don't lock before calling logBlocklistHits.
  3717  	// Unlock before calling enqueueDisallowedTrafficAlertRequest/log.
  3718  
  3719  	sshClient.Lock()
  3720  
  3721  	allowed := true
  3722  
  3723  	// Client must complete handshake before port forwards are permitted.
  3724  	if !sshClient.handshakeState.completed {
  3725  		allowed = false
  3726  	}
  3727  
  3728  	if allowed {
  3729  		// Traffic rules checks.
  3730  		switch portForwardType {
  3731  		case portForwardTypeTCP:
  3732  			if !sshClient.trafficRules.AllowTCPPort(remoteIP, port) {
  3733  				allowed = false
  3734  			}
  3735  		case portForwardTypeUDP:
  3736  			if !sshClient.trafficRules.AllowUDPPort(remoteIP, port) {
  3737  				allowed = false
  3738  			}
  3739  		}
  3740  	}
  3741  
  3742  	sshClient.Unlock()
  3743  
  3744  	if allowed {
  3745  		return true
  3746  	}
  3747  
  3748  	switch portForwardType {
  3749  	case portForwardTypeTCP:
  3750  		sshClient.updateQualityMetricsWithTCPRejectedDisallowed()
  3751  	case portForwardTypeUDP:
  3752  		sshClient.updateQualityMetricsWithUDPRejectedDisallowed()
  3753  	}
  3754  
  3755  	sshClient.enqueueDisallowedTrafficAlertRequest()
  3756  
  3757  	log.WithTraceFields(
  3758  		LogFields{
  3759  			"type": portForwardType,
  3760  			"port": port,
  3761  		}).Debug("port forward denied by traffic rules")
  3762  
  3763  	return false
  3764  }
  3765  
  3766  // isDomainPermitted returns true when the specified domain may be resolved
  3767  // and returns false and a reject reason otherwise.
  3768  func (sshClient *sshClient) isDomainPermitted(domain string) (bool, string) {
  3769  
  3770  	// We're not doing comprehensive validation, to avoid overhead per port
  3771  	// forward. This is a simple sanity check to ensure we don't process
  3772  	// blantantly invalid input.
  3773  	//
  3774  	// TODO: validate with dns.IsDomainName?
  3775  	if len(domain) > 255 {
  3776  		return false, "invalid domain name"
  3777  	}
  3778  
  3779  	tags := sshClient.sshServer.support.Blocklist.LookupDomain(domain)
  3780  	if len(tags) > 0 {
  3781  
  3782  		sshClient.logBlocklistHits(nil, domain, tags)
  3783  
  3784  		if sshClient.sshServer.support.Config.BlocklistActive {
  3785  			// Actively alert and block
  3786  			sshClient.enqueueUnsafeTrafficAlertRequest(tags)
  3787  			return false, "port forward not permitted"
  3788  		}
  3789  	}
  3790  
  3791  	return true, ""
  3792  }
  3793  
  3794  func (sshClient *sshClient) isIPPermitted(remoteIP net.IP) bool {
  3795  
  3796  	tags := sshClient.sshServer.support.Blocklist.LookupIP(remoteIP)
  3797  	if len(tags) > 0 {
  3798  
  3799  		sshClient.logBlocklistHits(remoteIP, "", tags)
  3800  
  3801  		if sshClient.sshServer.support.Config.BlocklistActive {
  3802  			// Actively alert and block
  3803  			sshClient.enqueueUnsafeTrafficAlertRequest(tags)
  3804  			return false
  3805  		}
  3806  	}
  3807  
  3808  	return true
  3809  }
  3810  
  3811  func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool {
  3812  
  3813  	sshClient.Lock()
  3814  	defer sshClient.Unlock()
  3815  
  3816  	state := &sshClient.tcpTrafficState
  3817  	max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
  3818  
  3819  	if max > 0 && state.concurrentDialingPortForwardCount >= int64(max) {
  3820  		return true
  3821  	}
  3822  	return false
  3823  }
  3824  
  3825  func (sshClient *sshClient) getTCPPortForwardQueueSize() int {
  3826  
  3827  	sshClient.Lock()
  3828  	defer sshClient.Unlock()
  3829  
  3830  	return *sshClient.trafficRules.MaxTCPPortForwardCount +
  3831  		*sshClient.trafficRules.MaxTCPDialingPortForwardCount
  3832  }
  3833  
  3834  func (sshClient *sshClient) getDialTCPPortForwardTimeoutMilliseconds() int {
  3835  
  3836  	sshClient.Lock()
  3837  	defer sshClient.Unlock()
  3838  
  3839  	return *sshClient.trafficRules.DialTCPPortForwardTimeoutMilliseconds
  3840  }
  3841  
  3842  func (sshClient *sshClient) dialingTCPPortForward() {
  3843  
  3844  	sshClient.Lock()
  3845  	defer sshClient.Unlock()
  3846  
  3847  	state := &sshClient.tcpTrafficState
  3848  
  3849  	state.concurrentDialingPortForwardCount += 1
  3850  	if state.concurrentDialingPortForwardCount > state.peakConcurrentDialingPortForwardCount {
  3851  		state.peakConcurrentDialingPortForwardCount = state.concurrentDialingPortForwardCount
  3852  	}
  3853  }
  3854  
  3855  func (sshClient *sshClient) abortedTCPPortForward() {
  3856  
  3857  	sshClient.Lock()
  3858  	defer sshClient.Unlock()
  3859  
  3860  	sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1
  3861  }
  3862  
  3863  func (sshClient *sshClient) allocatePortForward(portForwardType int) bool {
  3864  
  3865  	sshClient.Lock()
  3866  	defer sshClient.Unlock()
  3867  
  3868  	// Check if at port forward limit. The subsequent counter
  3869  	// changes must be atomic with the limit check to ensure
  3870  	// the counter never exceeds the limit in the case of
  3871  	// concurrent allocations.
  3872  
  3873  	var max int
  3874  	var state *trafficState
  3875  	if portForwardType == portForwardTypeTCP {
  3876  		max = *sshClient.trafficRules.MaxTCPPortForwardCount
  3877  		state = &sshClient.tcpTrafficState
  3878  	} else {
  3879  		max = *sshClient.trafficRules.MaxUDPPortForwardCount
  3880  		state = &sshClient.udpTrafficState
  3881  	}
  3882  
  3883  	if max > 0 && state.concurrentPortForwardCount >= int64(max) {
  3884  		return false
  3885  	}
  3886  
  3887  	// Update port forward counters.
  3888  
  3889  	if portForwardType == portForwardTypeTCP {
  3890  
  3891  		// Assumes TCP port forwards called dialingTCPPortForward
  3892  		state.concurrentDialingPortForwardCount -= 1
  3893  
  3894  		if sshClient.tcpPortForwardDialingAvailableSignal != nil {
  3895  
  3896  			max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount
  3897  			if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) {
  3898  				sshClient.tcpPortForwardDialingAvailableSignal()
  3899  			}
  3900  		}
  3901  	}
  3902  
  3903  	state.concurrentPortForwardCount += 1
  3904  	if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount {
  3905  		state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount
  3906  	}
  3907  	state.totalPortForwardCount += 1
  3908  
  3909  	return true
  3910  }
  3911  
  3912  // establishedPortForward increments the concurrent port
  3913  // forward counter. closedPortForward decrements it, so it
  3914  // must always be called for each establishedPortForward
  3915  // call.
  3916  //
  3917  // When at the limit of established port forwards, the LRU
  3918  // existing port forward is closed to make way for the newly
  3919  // established one. There can be a minor delay as, in addition
  3920  // to calling Close() on the port forward net.Conn,
  3921  // establishedPortForward waits for the LRU's closedPortForward()
  3922  // call which will decrement the concurrent counter. This
  3923  // ensures all resources associated with the LRU (socket,
  3924  // goroutine) are released or will very soon be released before
  3925  // proceeding.
  3926  func (sshClient *sshClient) establishedPortForward(
  3927  	portForwardType int, portForwardLRU *common.LRUConns) {
  3928  
  3929  	// Do not lock sshClient here.
  3930  
  3931  	var state *trafficState
  3932  	if portForwardType == portForwardTypeTCP {
  3933  		state = &sshClient.tcpTrafficState
  3934  	} else {
  3935  		state = &sshClient.udpTrafficState
  3936  	}
  3937  
  3938  	// When the maximum number of port forwards is already
  3939  	// established, close the LRU. CloseOldest will call
  3940  	// Close on the port forward net.Conn. Both TCP and
  3941  	// UDP port forwards have handler goroutines that may
  3942  	// be blocked calling Read on the net.Conn. Close will
  3943  	// eventually interrupt the Read and cause the handlers
  3944  	// to exit, but not immediately. So the following logic
  3945  	// waits for a LRU handler to be interrupted and signal
  3946  	// availability.
  3947  	//
  3948  	// Notes:
  3949  	//
  3950  	// - the port forward limit can change via a traffic
  3951  	//   rules hot reload; the condition variable handles
  3952  	//   this case whereas a channel-based semaphore would
  3953  	//   not.
  3954  	//
  3955  	// - if a number of goroutines exceeding the total limit
  3956  	//   arrive here all concurrently, some CloseOldest() calls
  3957  	//   will have no effect as there can be less existing port
  3958  	//   forwards than new ones. In this case, the new port
  3959  	//   forward will be delayed. This is highly unlikely in
  3960  	//   practise since UDP calls to establishedPortForward are
  3961  	//   serialized and TCP calls are limited by the dial
  3962  	//   queue/count.
  3963  
  3964  	if !sshClient.allocatePortForward(portForwardType) {
  3965  
  3966  		portForwardLRU.CloseOldest()
  3967  		log.WithTrace().Debug("closed LRU port forward")
  3968  
  3969  		state.availablePortForwardCond.L.Lock()
  3970  		for !sshClient.allocatePortForward(portForwardType) {
  3971  			state.availablePortForwardCond.Wait()
  3972  		}
  3973  		state.availablePortForwardCond.L.Unlock()
  3974  	}
  3975  }
  3976  
  3977  func (sshClient *sshClient) closedPortForward(
  3978  	portForwardType int, bytesUp, bytesDown int64) {
  3979  
  3980  	sshClient.Lock()
  3981  
  3982  	var state *trafficState
  3983  	if portForwardType == portForwardTypeTCP {
  3984  		state = &sshClient.tcpTrafficState
  3985  	} else {
  3986  		state = &sshClient.udpTrafficState
  3987  	}
  3988  
  3989  	state.concurrentPortForwardCount -= 1
  3990  	state.bytesUp += bytesUp
  3991  	state.bytesDown += bytesDown
  3992  
  3993  	sshClient.Unlock()
  3994  
  3995  	// Signal any goroutine waiting in establishedPortForward
  3996  	// that an established port forward slot is available.
  3997  	state.availablePortForwardCond.Signal()
  3998  }
  3999  
  4000  func (sshClient *sshClient) updateQualityMetricsWithDialResult(
  4001  	tcpPortForwardDialSuccess bool, dialDuration time.Duration, IP net.IP) {
  4002  
  4003  	sshClient.Lock()
  4004  	defer sshClient.Unlock()
  4005  
  4006  	if tcpPortForwardDialSuccess {
  4007  		sshClient.qualityMetrics.TCPPortForwardDialedCount += 1
  4008  		sshClient.qualityMetrics.TCPPortForwardDialedDuration += dialDuration
  4009  		if IP.To4() != nil {
  4010  			sshClient.qualityMetrics.TCPIPv4PortForwardDialedCount += 1
  4011  			sshClient.qualityMetrics.TCPIPv4PortForwardDialedDuration += dialDuration
  4012  		} else if IP != nil {
  4013  			sshClient.qualityMetrics.TCPIPv6PortForwardDialedCount += 1
  4014  			sshClient.qualityMetrics.TCPIPv6PortForwardDialedDuration += dialDuration
  4015  		}
  4016  	} else {
  4017  		sshClient.qualityMetrics.TCPPortForwardFailedCount += 1
  4018  		sshClient.qualityMetrics.TCPPortForwardFailedDuration += dialDuration
  4019  		if IP.To4() != nil {
  4020  			sshClient.qualityMetrics.TCPIPv4PortForwardFailedCount += 1
  4021  			sshClient.qualityMetrics.TCPIPv4PortForwardFailedDuration += dialDuration
  4022  		} else if IP != nil {
  4023  			sshClient.qualityMetrics.TCPIPv6PortForwardFailedCount += 1
  4024  			sshClient.qualityMetrics.TCPIPv6PortForwardFailedDuration += dialDuration
  4025  		}
  4026  	}
  4027  }
  4028  
  4029  func (sshClient *sshClient) updateQualityMetricsWithRejectedDialingLimit() {
  4030  
  4031  	sshClient.Lock()
  4032  	defer sshClient.Unlock()
  4033  
  4034  	sshClient.qualityMetrics.TCPPortForwardRejectedDialingLimitCount += 1
  4035  }
  4036  
  4037  func (sshClient *sshClient) updateQualityMetricsWithTCPRejectedDisallowed() {
  4038  
  4039  	sshClient.Lock()
  4040  	defer sshClient.Unlock()
  4041  
  4042  	sshClient.qualityMetrics.TCPPortForwardRejectedDisallowedCount += 1
  4043  }
  4044  
  4045  func (sshClient *sshClient) updateQualityMetricsWithUDPRejectedDisallowed() {
  4046  
  4047  	sshClient.Lock()
  4048  	defer sshClient.Unlock()
  4049  
  4050  	sshClient.qualityMetrics.UDPPortForwardRejectedDisallowedCount += 1
  4051  }
  4052  
  4053  func (sshClient *sshClient) updateQualityMetricsWithDNSResult(
  4054  	success bool, duration time.Duration, resolverIP net.IP) {
  4055  
  4056  	sshClient.Lock()
  4057  	defer sshClient.Unlock()
  4058  
  4059  	resolver := ""
  4060  	if resolverIP != nil {
  4061  		resolver = resolverIP.String()
  4062  	}
  4063  	if success {
  4064  		sshClient.qualityMetrics.DNSCount["ALL"] += 1
  4065  		sshClient.qualityMetrics.DNSDuration["ALL"] += duration
  4066  		if resolver != "" {
  4067  			sshClient.qualityMetrics.DNSCount[resolver] += 1
  4068  			sshClient.qualityMetrics.DNSDuration[resolver] += duration
  4069  		}
  4070  	} else {
  4071  		sshClient.qualityMetrics.DNSFailedCount["ALL"] += 1
  4072  		sshClient.qualityMetrics.DNSFailedDuration["ALL"] += duration
  4073  		if resolver != "" {
  4074  			sshClient.qualityMetrics.DNSFailedCount[resolver] += 1
  4075  			sshClient.qualityMetrics.DNSFailedDuration[resolver] += duration
  4076  		}
  4077  	}
  4078  }
  4079  
// handleTCPChannel services a single client TCP port forward channel. In
// order, it:
//   - transparently redirects web API port forwards;
//   - checks the domain blocklist;
//   - resolves the destination hostname, preferring IPv4;
//   - applies split tunnel classification and the traffic rules;
//   - dials the destination;
//   - relays data between the SSH channel and the upstream TCP connection
//     until either side finishes.
//
// The caller must already have reserved a dialing slot. On success, that
// slot is converted into an established slot by establishedPortForward. On
// every earlier return path it is released by the deferred
// abortedTCPPortForward. Established slots are released via
// closedPortForward when relaying ends.
//
// remainingDialTimeout is a single budget shared by hostname resolution and
// the TCP dial: time spent resolving is deducted before dialing.
func (sshClient *sshClient) handleTCPChannel(
	remainingDialTimeout time.Duration,
	hostToConnect string,
	portToConnect int,
	doSplitTunnel bool,
	newChannel ssh.NewChannel) {

	// Assumptions:
	// - sshClient.dialingTCPPortForward() has been called
	// - remainingDialTimeout > 0

	// Until establishedPortForward has taken over the dialing slot, any
	// return releases it via abortedTCPPortForward.
	established := false
	defer func() {
		if !established {
			sshClient.abortedTCPPortForward()
		}
	}()

	// Transparently redirect web API request connections.
	//
	// Web server port forwards skip the domain blocklist check and
	// isPortForwardPermitted below.

	isWebServerPortForward := false
	config := sshClient.sshServer.support.Config
	if config.WebServerPortForwardAddress != "" {
		destination := net.JoinHostPort(hostToConnect, strconv.Itoa(portToConnect))
		if destination == config.WebServerPortForwardAddress {
			isWebServerPortForward = true
			if config.WebServerPortForwardRedirectAddress != "" {
				// Note: redirect format is validated when config is loaded
				host, portStr, _ := net.SplitHostPort(config.WebServerPortForwardRedirectAddress)
				port, _ := strconv.Atoi(portStr)
				hostToConnect = host
				portToConnect = port
			}
		}
	}

	// Validate the domain name and check the domain blocklist before dialing.
	//
	// The IP blocklist is checked in isPortForwardPermitted, which also provides
	// IP blocklist checking for the packet tunnel code path. When hostToConnect
	// is an IP address, the following hostname resolution step effectively
	// performs no actions and the next immediate step is the
	// isPortForwardPermitted check.
	//
	// Limitation: this case handles port forwards where the client sends the
	// destination domain in the SSH port forward request but does not currently
	// handle DNS-over-TCP; in the DNS-over-TCP case, a client may bypass the
	// block list check.

	if !isWebServerPortForward &&
		net.ParseIP(hostToConnect) == nil {

		ok, rejectMessage := sshClient.isDomainPermitted(hostToConnect)
		if !ok {
			// Note: not recording a port forward failure in this case
			sshClient.rejectNewChannel(newChannel, rejectMessage)
			return
		}
	}

	// Dial the remote address.
	//
	// Hostname resolution is performed explicitly, as a separate step, as the
	// target IP address is used for traffic rules (AllowSubnets), OSL seed
	// progress, and IP address blocklists.
	//
	// Contexts are used for cancellation (via sshClient.runCtx, which is
	// cancelled when the client is stopping) and timeouts.

	// dialStartTime covers resolution plus the dial, so the recorded dial
	// durations below include any resolution time.
	dialStartTime := time.Now()

	IP := net.ParseIP(hostToConnect)

	if IP == nil {

		// Resolve the hostname

		log.WithTraceFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving")

		ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
		IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect)
		cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"

		resolveElapsedTime := time.Since(dialStartTime)

		// Record DNS metrics. If LookupIPAddr returns net.DNSError.IsNotFound, this
		// is "no such host" and not a DNS failure. Limitation: the resolver IP is
		// not known.

		dnsErr, ok := err.(*net.DNSError)
		dnsNotFound := ok && dnsErr.IsNotFound
		dnsSuccess := err == nil || dnsNotFound
		sshClient.updateQualityMetricsWithDNSResult(dnsSuccess, resolveElapsedTime, nil)

		// IPv4 is preferred in case the host has limited IPv6 routing. IPv6 is
		// selected and attempted only when there's no IPv4 option.
		// TODO: shuffle list to try other IPs?

		for _, ip := range IPs {
			if ip.IP.To4() != nil {
				IP = ip.IP
				break
			}
		}
		if IP == nil && len(IPs) > 0 {
			// If there are no IPv4 IPs, the first IP is IPv6.
			IP = IPs[0].IP
		}

		if err == nil && IP == nil {
			err = std_errors.New("no IP address")
		}

		if err != nil {

			// Record a port forward failure
			sshClient.updateQualityMetricsWithDialResult(false, resolveElapsedTime, IP)

			sshClient.rejectNewChannel(newChannel, fmt.Sprintf("LookupIP failed: %s", err))
			return
		}

		// Resolution time is deducted from the shared dial budget.
		remainingDialTimeout -= resolveElapsedTime
	}

	if remainingDialTimeout <= 0 {
		sshClient.rejectNewChannel(newChannel, "TCP port forward timed out resolving")
		return
	}

	// When the client has indicated split tunnel mode and when the channel is
	// not of type protocol.TCP_PORT_FORWARD_NO_SPLIT_TUNNEL_TYPE, check if the
	// client and the port forward destination are in the same GeoIP country. If
	// so, reject the port forward with a distinct response code that indicates
	// to the client that this port forward should be performed locally, direct
	// and untunneled.
	//
	// Clients are expected to cache untunneled responses to avoid this round
	// trip in the immediate future and reduce server load.
	//
	// When the countries differ, immediately proceed with the standard port
	// forward. No additional round trip is required.
	//
	// If either GeoIP country is "None", one or both countries are unknown
	// and there is no match.
	//
	// Traffic rules, such as allowed ports, are not enforced for port forward
	// destinations classified as untunneled.
	//
	// Domain and IP blocklists still apply to port forward destinations
	// classified as untunneled.
	//
	// The client's use of split tunnel mode is logged in server_tunnel metrics
	// as the boolean value split_tunnel. As they may indicate some information
	// about browsing activity, no other split tunnel metrics are logged.

	if doSplitTunnel {

		destinationGeoIPData := sshClient.sshServer.support.GeoIPService.LookupIP(IP)

		if sshClient.geoIPData.Country != GEOIP_UNKNOWN_VALUE &&
			sshClient.handshakeState.splitTunnelLookup.lookup(
				destinationGeoIPData.Country) {

			// Since isPortForwardPermitted is not called in this case, explicitly
			// call isIPPermitted for the IP blocklist check. The domain blocklist
			// case is handled above.
			if !sshClient.isIPPermitted(IP) {
				// Note: not recording a port forward failure in this case
				sshClient.rejectNewChannel(newChannel, "port forward not permitted")
				return
			}

			newChannel.Reject(protocol.CHANNEL_REJECT_REASON_SPLIT_TUNNEL, "")
			return
		}
	}

	// Enforce traffic rules, using the resolved IP address.

	if !isWebServerPortForward &&
		!sshClient.isPortForwardPermitted(
			portForwardTypeTCP,
			IP,
			portToConnect) {
		// Note: not recording a port forward failure in this case
		sshClient.rejectNewChannel(newChannel, "port forward not permitted")
		return
	}

	// TCP dial.

	remoteAddr := net.JoinHostPort(IP.String(), strconv.Itoa(portToConnect))

	log.WithTraceFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing")

	ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout)
	fwdConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteAddr)
	cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled"

	// Record port forward success or failure
	sshClient.updateQualityMetricsWithDialResult(err == nil, time.Since(dialStartTime), IP)

	if err != nil {

		// Monitor for low resource error conditions
		sshClient.sshServer.monitorPortForwardDialError(err)

		sshClient.rejectNewChannel(newChannel, fmt.Sprintf("DialTimeout failed: %s", err))
		return
	}

	// The upstream TCP port forward connection has been established. Schedule
	// some cleanup and notify the SSH client that the channel is accepted.

	// This defer captures the raw dialed conn. fwdConn is reassigned to the
	// ActivityMonitoredConn wrapper below, and the deferred Close still
	// applies to this original connection.
	defer fwdConn.Close()

	fwdChannel, requests, err := newChannel.Accept()
	if err != nil {
		if !isExpectedTunnelIOError(err) {
			log.WithTraceFields(LogFields{"error": err}).Warning("accept new channel failed")
		}
		return
	}
	go ssh.DiscardRequests(requests)
	defer fwdChannel.Close()

	// Release the dialing slot and acquire an established slot.
	//
	// establishedPortForward increments the concurrent TCP port
	// forward counter and closes the LRU existing TCP port forward
	// when already at the limit.
	//
	// Known limitations:
	//
	// - Closed LRU TCP sockets will enter the TIME_WAIT state,
	//   continuing to consume some resources.

	sshClient.establishedPortForward(portForwardTypeTCP, sshClient.tcpPortForwardLRU)

	// "established = true" cancels the deferred abortedTCPPortForward()
	established = true

	// TODO: 64-bit alignment? https://golang.org/pkg/sync/atomic/#pkg-note-BUG
	var bytesUp, bytesDown int64
	defer func() {
		sshClient.closedPortForward(
			portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown))
	}()

	// The LRU tracks the raw dialed conn, so CloseOldest on this entry closes
	// the upstream socket directly.
	lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn)
	defer lruEntry.Remove()

	// ActivityMonitoredConn monitors the TCP port forward I/O and updates
	// its LRU status. ActivityMonitoredConn also times out I/O on the port
	// forward if both reads and writes have been idle for the specified
	// duration.

	fwdConn, err = common.NewActivityMonitoredConn(
		fwdConn,
		sshClient.idleTCPPortForwardTimeout(),
		true,
		lruEntry,
		sshClient.getActivityUpdaters(portForwardTypeTCP, IP)...)
	if err != nil {
		log.WithTraceFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed")
		return
	}

	// Relay channel to forwarded connection.

	log.WithTraceFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying")

	// TODO: relay errors to fwdChannel.Stderr()?
	relayWaitGroup := new(sync.WaitGroup)
	relayWaitGroup.Add(1)
	go func() {
		defer relayWaitGroup.Done()
		// io.Copy allocates a 32K temporary buffer, and each port forward relay
		// uses two of these buffers; using common.CopyBuffer with a smaller buffer
		// reduces the overall memory footprint.
		//
		// NOTE(review): the local name "bytes" shadows the imported bytes
		// package within this scope.
		bytes, err := common.CopyBuffer(
			fwdChannel, fwdConn, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
		atomic.AddInt64(&bytesDown, bytes)
		if err != nil && err != io.EOF {
			// Debug since errors such as "connection reset by peer" occur during normal operation
			log.WithTraceFields(LogFields{"error": err}).Debug("downstream TCP relay failed")
		}
		// Interrupt upstream io.Copy when downstream is shutting down.
		// TODO: this is done to quickly cleanup the port forward when
		// fwdConn has a read timeout, but is it clean -- upstream may still
		// be flowing?
		fwdChannel.Close()
	}()
	// NOTE(review): "bytes" shadows the imported bytes package for the rest
	// of this function.
	bytes, err := common.CopyBuffer(
		fwdConn, fwdChannel, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE))
	atomic.AddInt64(&bytesUp, bytes)
	if err != nil && err != io.EOF {
		log.WithTraceFields(LogFields{"error": err}).Debug("upstream TCP relay failed")
	}
	// Shutdown special case: fwdChannel will be closed and return EOF when
	// the SSH connection is closed, but we need to explicitly close fwdConn
	// to interrupt the downstream io.Copy, which may be blocked on a
	// fwdConn.Read().
	fwdConn.Close()

	relayWaitGroup.Wait()

	log.WithTraceFields(
		LogFields{
			"remoteAddr": remoteAddr,
			"bytesUp":    atomic.LoadInt64(&bytesUp),
			"bytesDown":  atomic.LoadInt64(&bytesDown)}).Debug("exiting")
}