// github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/server/tunnelServer.go (about)

/*
 * Copyright (c) 2016, Psiphon Inc.
 * All rights reserved.
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program. If not, see <http://www.gnu.org/licenses/>.
 *
 */

package server

import (
	"bytes"
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"encoding/json"
	std_errors "errors"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"strconv"
	"sync"
	"sync/atomic"
	"syscall"
	"time"

	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/accesscontrol"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/monotime"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/obfuscator"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/refraction"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tactics"
	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/tun"
	"github.com/marusama/semaphore"
	cache "github.com/patrickmn/go-cache"
)

// Tunnel server timeouts, sizes, and limits. Most are consumed outside this
// section of the file.
const (
	SSH_AUTH_LOG_PERIOD                   = 30 * time.Minute
	SSH_HANDSHAKE_TIMEOUT                 = 30 * time.Second
	SSH_BEGIN_HANDSHAKE_TIMEOUT           = 1 * time.Second
	SSH_CONNECTION_READ_DEADLINE          = 5 * time.Minute
	SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE = 8192
	SSH_TCP_PORT_FORWARD_QUEUE_SIZE       = 1024
	SSH_KEEP_ALIVE_PAYLOAD_MIN_BYTES      = 0
	SSH_KEEP_ALIVE_PAYLOAD_MAX_BYTES      = 256
	SSH_SEND_OSL_INITIAL_RETRY_DELAY      = 30 * time.Second
	SSH_SEND_OSL_RETRY_FACTOR             = 2
	// OSL_SESSION_CACHE_TTL is how long newSSHServer's OSL session cache
	// retains a disconnected client's OSL progress.
	OSL_SESSION_CACHE_TTL                 = 5 * time.Minute
	MAX_AUTHORIZATIONS                    = 16
	PRE_HANDSHAKE_RANDOM_STREAM_MAX_COUNT = 1
	RANDOM_STREAM_MAX_BYTES               = 10485760
	ALERT_REQUEST_QUEUE_BUFFER_SIZE       = 16
)

// TunnelServer is the main server that accepts Psiphon client
// connections, via various obfuscation protocols, and provides
// port forwarding (TCP and UDP) services to the Psiphon client.
// At its core, TunnelServer is an SSH server. SSH is the base
// protocol that provides port forward multiplexing, and transport
// security. Layered on top of SSH, optionally, is Obfuscated SSH
// and meek protocols, which provide further circumvention
// capabilities.
type TunnelServer struct {
	// runWaitGroup tracks the per-listener goroutines started by Run;
	// Run waits on it before returning.
	runWaitGroup *sync.WaitGroup
	// listenerError carries an unrecoverable listener error to Run, which
	// then shuts down and returns that error.
	listenerError chan error
	// shutdownBroadcast, when signaled, makes Run shut down.
	shutdownBroadcast <-chan struct{}
	sshServer         *sshServer
}

// sshListener is a bound listener together with the metadata used to run it
// and to log and attribute the connections it accepts.
type sshListener struct {
	// Listener is, as constructed in Run, the listener returned by
	// NewTacticsListener wrapping the underlying protocol listener.
	net.Listener
	// localAddress is the "host:port" the listener is bound to.
	localAddress string
	// tunnelProtocol is the protocol run by this listener.
	tunnelProtocol string
	// port is the listening port.
	port int
	// BPFProgramName is the name returned by newTCPListenerWithBPF for
	// direct TCP listeners; it is left empty for QUIC, refraction, and
	// fronted meek listeners.
	BPFProgramName string
}

// NewTunnelServer initializes a new tunnel server, creating its underlying
// SSH server from support. No listeners are bound until Run is called.
100 func NewTunnelServer( 101 support *SupportServices, 102 shutdownBroadcast <-chan struct{}) (*TunnelServer, error) { 103 104 sshServer, err := newSSHServer(support, shutdownBroadcast) 105 if err != nil { 106 return nil, errors.Trace(err) 107 } 108 109 return &TunnelServer{ 110 runWaitGroup: new(sync.WaitGroup), 111 listenerError: make(chan error), 112 shutdownBroadcast: shutdownBroadcast, 113 sshServer: sshServer, 114 }, nil 115 } 116 117 // Run runs the tunnel server; this function blocks while running a selection of 118 // listeners that handle connection using various obfuscation protocols. 119 // 120 // Run listens on each designated tunnel port and spawns new goroutines to handle 121 // each client connection. It halts when shutdownBroadcast is signaled. A list of active 122 // clients is maintained, and when halting all clients are cleanly shutdown. 123 // 124 // Each client goroutine handles its own obfuscation (optional), SSH handshake, SSH 125 // authentication, and then looping on client new channel requests. "direct-tcpip" 126 // channels, dynamic port fowards, are supported. When the UDPInterceptUdpgwServerAddress 127 // config parameter is configured, UDP port forwards over a TCP stream, following 128 // the udpgw protocol, are handled. 129 // 130 // A new goroutine is spawned to handle each port forward for each client. Each port 131 // forward tracks its bytes transferred. Overall per-client stats for connection duration, 132 // GeoIP, number of port forwards, and bytes transferred are tracked and logged when the 133 // client shuts down. 134 // 135 // Note: client handler goroutines may still be shutting down after Run() returns. See 136 // comment in sshClient.stop(). TODO: fully synchronized shutdown. 137 func (server *TunnelServer) Run() error { 138 139 // TODO: should TunnelServer hold its own support pointer? 
140 support := server.sshServer.support 141 142 // First bind all listeners; once all are successful, 143 // start accepting connections on each. 144 145 var listeners []*sshListener 146 147 for tunnelProtocol, listenPort := range support.Config.TunnelProtocolPorts { 148 149 localAddress := net.JoinHostPort( 150 support.Config.ServerIPAddress, strconv.Itoa(listenPort)) 151 152 var listener net.Listener 153 var BPFProgramName string 154 var err error 155 156 if protocol.TunnelProtocolUsesFrontedMeekQUIC(tunnelProtocol) { 157 158 // For FRONTED-MEEK-QUIC-OSSH, no listener implemented. The edge-to-server 159 // hop uses HTTPS and the client tunnel protocol is distinguished using 160 // protocol.MeekCookieData.ClientTunnelProtocol. 161 continue 162 163 } else if protocol.TunnelProtocolUsesQUIC(tunnelProtocol) { 164 165 logTunnelProtocol := tunnelProtocol 166 listener, err = quic.Listen( 167 CommonLogger(log), 168 func(clientAddress string, err error, logFields common.LogFields) { 169 logIrregularTunnel( 170 support, logTunnelProtocol, listenPort, clientAddress, 171 errors.Trace(err), LogFields(logFields)) 172 }, 173 localAddress, 174 support.Config.ObfuscatedSSHKey, 175 support.Config.EnableGQUIC) 176 177 } else if protocol.TunnelProtocolUsesRefractionNetworking(tunnelProtocol) { 178 179 listener, err = refraction.Listen(localAddress) 180 181 } else if protocol.TunnelProtocolUsesFrontedMeek(tunnelProtocol) { 182 183 listener, err = net.Listen("tcp", localAddress) 184 185 } else { 186 187 // Only direct, unfronted protocol listeners use TCP BPF circumvention 188 // programs. 
189 listener, BPFProgramName, err = newTCPListenerWithBPF(support, localAddress) 190 } 191 192 if err != nil { 193 for _, existingListener := range listeners { 194 existingListener.Listener.Close() 195 } 196 return errors.Trace(err) 197 } 198 199 tacticsListener := NewTacticsListener( 200 support, 201 listener, 202 tunnelProtocol, 203 func(IP string) GeoIPData { return support.GeoIPService.Lookup(IP) }) 204 205 log.WithTraceFields( 206 LogFields{ 207 "localAddress": localAddress, 208 "tunnelProtocol": tunnelProtocol, 209 "BPFProgramName": BPFProgramName, 210 }).Info("listening") 211 212 listeners = append( 213 listeners, 214 &sshListener{ 215 Listener: tacticsListener, 216 localAddress: localAddress, 217 port: listenPort, 218 tunnelProtocol: tunnelProtocol, 219 BPFProgramName: BPFProgramName, 220 }) 221 } 222 223 for _, listener := range listeners { 224 server.runWaitGroup.Add(1) 225 go func(listener *sshListener) { 226 defer server.runWaitGroup.Done() 227 228 log.WithTraceFields( 229 LogFields{ 230 "localAddress": listener.localAddress, 231 "tunnelProtocol": listener.tunnelProtocol, 232 }).Info("running") 233 234 server.sshServer.runListener( 235 listener, 236 server.listenerError) 237 238 log.WithTraceFields( 239 LogFields{ 240 "localAddress": listener.localAddress, 241 "tunnelProtocol": listener.tunnelProtocol, 242 }).Info("stopped") 243 244 }(listener) 245 } 246 247 var err error 248 select { 249 case <-server.shutdownBroadcast: 250 case err = <-server.listenerError: 251 } 252 253 for _, listener := range listeners { 254 listener.Close() 255 } 256 server.sshServer.stopClients() 257 server.runWaitGroup.Wait() 258 259 log.WithTrace().Info("stopped") 260 261 return err 262 } 263 264 // GetLoadStats returns load stats for the tunnel server. The stats are 265 // broken down by protocol ("SSH", "OSSH", etc.) and type. Types of stats 266 // include current connected client count, total number of current port 267 // forwards. 
268 func (server *TunnelServer) GetLoadStats() ( 269 UpstreamStats, ProtocolStats, RegionStats) { 270 271 return server.sshServer.getLoadStats() 272 } 273 274 // GetEstablishedClientCount returns the number of currently established 275 // clients. 276 func (server *TunnelServer) GetEstablishedClientCount() int { 277 return server.sshServer.getEstablishedClientCount() 278 } 279 280 // ResetAllClientTrafficRules resets all established client traffic rules 281 // to use the latest config and client properties. Any existing traffic 282 // rule state is lost, including throttling state. 283 func (server *TunnelServer) ResetAllClientTrafficRules() { 284 server.sshServer.resetAllClientTrafficRules() 285 } 286 287 // ResetAllClientOSLConfigs resets all established client OSL state to use 288 // the latest OSL config. Any existing OSL state is lost, including partial 289 // progress towards SLOKs. 290 func (server *TunnelServer) ResetAllClientOSLConfigs() { 291 server.sshServer.resetAllClientOSLConfigs() 292 } 293 294 // SetClientHandshakeState sets the handshake state -- that it completed and 295 // what parameters were passed -- in sshClient. This state is used for allowing 296 // port forwards and for future traffic rule selection. SetClientHandshakeState 297 // also triggers an immediate traffic rule re-selection, as the rules selected 298 // upon tunnel establishment may no longer apply now that handshake values are 299 // set. 300 // 301 // The authorizations received from the client handshake are verified and the 302 // resulting list of authorized access types are applied to the client's tunnel 303 // and traffic rules. 304 // 305 // A list of active authorization IDs, authorized access types, and traffic 306 // rate limits are returned for responding to the client and logging. 
307 func (server *TunnelServer) SetClientHandshakeState( 308 sessionID string, 309 state handshakeState, 310 authorizations []string) (*handshakeStateInfo, error) { 311 312 return server.sshServer.setClientHandshakeState(sessionID, state, authorizations) 313 } 314 315 // GetClientHandshaked indicates whether the client has completed a handshake 316 // and whether its traffic rules are immediately exhausted. 317 func (server *TunnelServer) GetClientHandshaked( 318 sessionID string) (bool, bool, error) { 319 320 return server.sshServer.getClientHandshaked(sessionID) 321 } 322 323 // GetClientDisableDiscovery indicates whether discovery is disabled for the 324 // client corresponding to sessionID. 325 func (server *TunnelServer) GetClientDisableDiscovery( 326 sessionID string) (bool, error) { 327 328 return server.sshServer.getClientDisableDiscovery(sessionID) 329 } 330 331 // UpdateClientAPIParameters updates the recorded handshake API parameters for 332 // the client corresponding to sessionID. 333 func (server *TunnelServer) UpdateClientAPIParameters( 334 sessionID string, 335 apiParams common.APIParameters) error { 336 337 return server.sshServer.updateClientAPIParameters(sessionID, apiParams) 338 } 339 340 // AcceptClientDomainBytes indicates whether to accept domain bytes reported 341 // by the client. 342 func (server *TunnelServer) AcceptClientDomainBytes( 343 sessionID string) (bool, error) { 344 345 return server.sshServer.acceptClientDomainBytes(sessionID) 346 } 347 348 // SetEstablishTunnels sets whether new tunnels may be established or not. 349 // When not establishing, incoming connections are immediately closed. 350 func (server *TunnelServer) SetEstablishTunnels(establish bool) { 351 server.sshServer.setEstablishTunnels(establish) 352 } 353 354 // CheckEstablishTunnels returns whether new tunnels may be established or 355 // not, and increments a metrics counter when establishment is disallowed. 
356 func (server *TunnelServer) CheckEstablishTunnels() bool { 357 return server.sshServer.checkEstablishTunnels() 358 } 359 360 // GetEstablishTunnelsMetrics returns whether tunnel establishment is 361 // currently allowed and the number of tunnels rejected since due to not 362 // establishing since the last GetEstablishTunnelsMetrics call. 363 func (server *TunnelServer) GetEstablishTunnelsMetrics() (bool, int64) { 364 return server.sshServer.getEstablishTunnelsMetrics() 365 } 366 367 type sshServer struct { 368 // Note: 64-bit ints used with atomic operations are placed 369 // at the start of struct to ensure 64-bit alignment. 370 // (https://golang.org/pkg/sync/atomic/#pkg-note-BUG) 371 lastAuthLog int64 372 authFailedCount int64 373 establishLimitedCount int64 374 support *SupportServices 375 establishTunnels int32 376 concurrentSSHHandshakes semaphore.Semaphore 377 shutdownBroadcast <-chan struct{} 378 sshHostKey ssh.Signer 379 clientsMutex sync.Mutex 380 stoppingClients bool 381 acceptedClientCounts map[string]map[string]int64 382 clients map[string]*sshClient 383 oslSessionCacheMutex sync.Mutex 384 oslSessionCache *cache.Cache 385 authorizationSessionIDsMutex sync.Mutex 386 authorizationSessionIDs map[string]string 387 obfuscatorSeedHistory *obfuscator.SeedHistory 388 } 389 390 func newSSHServer( 391 support *SupportServices, 392 shutdownBroadcast <-chan struct{}) (*sshServer, error) { 393 394 privateKey, err := ssh.ParseRawPrivateKey([]byte(support.Config.SSHPrivateKey)) 395 if err != nil { 396 return nil, errors.Trace(err) 397 } 398 399 // TODO: use cert (ssh.NewCertSigner) for anti-fingerprint? 
400 signer, err := ssh.NewSignerFromKey(privateKey) 401 if err != nil { 402 return nil, errors.Trace(err) 403 } 404 405 var concurrentSSHHandshakes semaphore.Semaphore 406 if support.Config.MaxConcurrentSSHHandshakes > 0 { 407 concurrentSSHHandshakes = semaphore.New(support.Config.MaxConcurrentSSHHandshakes) 408 } 409 410 // The OSL session cache temporarily retains OSL seed state 411 // progress for disconnected clients. This enables clients 412 // that disconnect and immediately reconnect to the same 413 // server to resume their OSL progress. Cached progress 414 // is referenced by session ID and is retained for 415 // OSL_SESSION_CACHE_TTL after disconnect. 416 // 417 // Note: session IDs are assumed to be unpredictable. If a 418 // rogue client could guess the session ID of another client, 419 // it could resume its OSL progress and, if the OSL config 420 // were known, infer some activity. 421 oslSessionCache := cache.New(OSL_SESSION_CACHE_TTL, 1*time.Minute) 422 423 return &sshServer{ 424 support: support, 425 establishTunnels: 1, 426 concurrentSSHHandshakes: concurrentSSHHandshakes, 427 shutdownBroadcast: shutdownBroadcast, 428 sshHostKey: signer, 429 acceptedClientCounts: make(map[string]map[string]int64), 430 clients: make(map[string]*sshClient), 431 oslSessionCache: oslSessionCache, 432 authorizationSessionIDs: make(map[string]string), 433 obfuscatorSeedHistory: obfuscator.NewSeedHistory(nil), 434 }, nil 435 } 436 437 func (sshServer *sshServer) setEstablishTunnels(establish bool) { 438 439 // Do nothing when the setting is already correct. This avoids 440 // spurious log messages when setEstablishTunnels is called 441 // periodically with the same setting. 
442 if establish == (atomic.LoadInt32(&sshServer.establishTunnels) == 1) { 443 return 444 } 445 446 establishFlag := int32(1) 447 if !establish { 448 establishFlag = 0 449 } 450 atomic.StoreInt32(&sshServer.establishTunnels, establishFlag) 451 452 log.WithTraceFields( 453 LogFields{"establish": establish}).Info("establishing tunnels") 454 } 455 456 func (sshServer *sshServer) checkEstablishTunnels() bool { 457 establishTunnels := atomic.LoadInt32(&sshServer.establishTunnels) == 1 458 if !establishTunnels { 459 atomic.AddInt64(&sshServer.establishLimitedCount, 1) 460 } 461 return establishTunnels 462 } 463 464 func (sshServer *sshServer) getEstablishTunnelsMetrics() (bool, int64) { 465 return atomic.LoadInt32(&sshServer.establishTunnels) == 1, 466 atomic.SwapInt64(&sshServer.establishLimitedCount, 0) 467 } 468 469 // runListener is intended to run an a goroutine; it blocks 470 // running a particular listener. If an unrecoverable error 471 // occurs, it will send the error to the listenerError channel. 472 func (sshServer *sshServer) runListener(sshListener *sshListener, listenerError chan<- error) { 473 474 handleClient := func(clientTunnelProtocol string, clientConn net.Conn) { 475 476 // Note: establish tunnel limiter cannot simply stop TCP 477 // listeners in all cases (e.g., meek) since SSH tunnels can 478 // span multiple TCP connections. 479 480 if !sshServer.checkEstablishTunnels() { 481 log.WithTrace().Debug("not establishing tunnels") 482 clientConn.Close() 483 return 484 } 485 486 // tunnelProtocol is used for stats and traffic rules. In many cases, its 487 // value is unambiguously determined by the listener port. In certain cases, 488 // such as multiple fronted protocols with a single backend listener, the 489 // client's reported tunnel protocol value is used. The caller must validate 490 // clientTunnelProtocol with protocol.IsValidClientTunnelProtocol. 
491 492 tunnelProtocol := sshListener.tunnelProtocol 493 if clientTunnelProtocol != "" { 494 tunnelProtocol = clientTunnelProtocol 495 } 496 497 // sshListener.tunnelProtocol indictes the tunnel protocol run by the 498 // listener. For direct protocols, this is also the client tunnel protocol. 499 // For fronted protocols, the client may use a different protocol to connect 500 // to the front and then only the front-to-Psiphon server will use the 501 // listener protocol. 502 // 503 // A fronted meek client, for example, reports its first hop protocol in 504 // protocol.MeekCookieData.ClientTunnelProtocol. Most metrics record this 505 // value as relay_protocol, since the first hop is the one subject to 506 // adversarial conditions. In some cases, such as irregular tunnels, there 507 // is no ClientTunnelProtocol value available and the listener tunnel 508 // protocol will be logged. 509 // 510 // Similarly, listenerPort indicates the listening port, which is the dialed 511 // port number for direct protocols; while, for fronted protocols, the 512 // client may dial a different port for its first hop. 513 514 // Process each client connection concurrently. 515 go sshServer.handleClient(sshListener, tunnelProtocol, clientConn) 516 } 517 518 // Note: when exiting due to a unrecoverable error, be sure 519 // to try to send the error to listenerError so that the outer 520 // TunnelServer.Run will properly shut down instead of remaining 521 // running. 
522 523 if protocol.TunnelProtocolUsesMeekHTTP(sshListener.tunnelProtocol) || 524 protocol.TunnelProtocolUsesMeekHTTPS(sshListener.tunnelProtocol) { 525 526 meekServer, err := NewMeekServer( 527 sshServer.support, 528 sshListener.Listener, 529 sshListener.tunnelProtocol, 530 sshListener.port, 531 protocol.TunnelProtocolUsesMeekHTTPS(sshListener.tunnelProtocol), 532 protocol.TunnelProtocolUsesFrontedMeek(sshListener.tunnelProtocol), 533 protocol.TunnelProtocolUsesObfuscatedSessionTickets(sshListener.tunnelProtocol), 534 handleClient, 535 sshServer.shutdownBroadcast) 536 537 if err == nil { 538 err = meekServer.Run() 539 } 540 541 if err != nil { 542 select { 543 case listenerError <- errors.Trace(err): 544 default: 545 } 546 return 547 } 548 549 } else { 550 551 for { 552 conn, err := sshListener.Listener.Accept() 553 554 select { 555 case <-sshServer.shutdownBroadcast: 556 if err == nil { 557 conn.Close() 558 } 559 return 560 default: 561 } 562 563 if err != nil { 564 if e, ok := err.(net.Error); ok && e.Temporary() { 565 log.WithTraceFields(LogFields{"error": err}).Error("accept failed") 566 // Temporary error, keep running 567 continue 568 } 569 570 select { 571 case listenerError <- errors.Trace(err): 572 default: 573 } 574 return 575 } 576 577 handleClient("", conn) 578 } 579 } 580 } 581 582 // An accepted client has completed a direct TCP or meek connection and has a net.Conn. Registration 583 // is for tracking the number of connections. 
584 func (sshServer *sshServer) registerAcceptedClient(tunnelProtocol, region string) { 585 586 sshServer.clientsMutex.Lock() 587 defer sshServer.clientsMutex.Unlock() 588 589 if sshServer.acceptedClientCounts[tunnelProtocol] == nil { 590 sshServer.acceptedClientCounts[tunnelProtocol] = make(map[string]int64) 591 } 592 593 sshServer.acceptedClientCounts[tunnelProtocol][region] += 1 594 } 595 596 func (sshServer *sshServer) unregisterAcceptedClient(tunnelProtocol, region string) { 597 598 sshServer.clientsMutex.Lock() 599 defer sshServer.clientsMutex.Unlock() 600 601 sshServer.acceptedClientCounts[tunnelProtocol][region] -= 1 602 } 603 604 // An established client has completed its SSH handshake and has a ssh.Conn. Registration is 605 // for tracking the number of fully established clients and for maintaining a list of running 606 // clients (for stopping at shutdown time). 607 func (sshServer *sshServer) registerEstablishedClient(client *sshClient) bool { 608 609 sshServer.clientsMutex.Lock() 610 611 if sshServer.stoppingClients { 612 sshServer.clientsMutex.Unlock() 613 return false 614 } 615 616 // In the case of a duplicate client sessionID, the previous client is closed. 617 // - Well-behaved clients generate a random sessionID that should be unique (won't 618 // accidentally conflict) and hard to guess (can't be targeted by a malicious 619 // client). 620 // - Clients reuse the same sessionID when a tunnel is unexpectedly disconnected 621 // and reestablished. In this case, when the same server is selected, this logic 622 // will be hit; closing the old, dangling client is desirable. 623 // - Multi-tunnel clients should not normally use one server for multiple tunnels. 624 625 existingClient := sshServer.clients[client.sessionID] 626 627 sshServer.clientsMutex.Unlock() 628 629 if existingClient != nil { 630 631 // This case is expected to be common, and so logged at the lowest severity 632 // level. 
633 log.WithTrace().Debug( 634 "stopping existing client with duplicate session ID") 635 636 existingClient.stop() 637 638 // Block until the existingClient is fully terminated. This is necessary to 639 // avoid this scenario: 640 // - existingClient is invoking handshakeAPIRequestHandler 641 // - sshServer.clients[client.sessionID] is updated to point to new client 642 // - existingClient's handshakeAPIRequestHandler invokes 643 // SetClientHandshakeState but sets the handshake parameters for new 644 // client 645 // - as a result, the new client handshake will fail (only a single handshake 646 // is permitted) and the new client server_tunnel log will contain an 647 // invalid mix of existing/new client fields 648 // 649 // Once existingClient.awaitStopped returns, all existingClient port 650 // forwards and request handlers have terminated, so no API handler, either 651 // tunneled web API or SSH API, will remain and it is safe to point 652 // sshServer.clients[client.sessionID] to the new client. 653 // Limitation: this scenario remains possible with _untunneled_ web API 654 // requests. 655 // 656 // Blocking also ensures existingClient.releaseAuthorizations is invoked before 657 // the new client attempts to submit the same authorizations. 658 // 659 // Perform blocking awaitStopped operation outside the 660 // sshServer.clientsMutex mutex to avoid blocking all other clients for the 661 // duration. We still expect and require that the stop process completes 662 // rapidly, e.g., does not block on network I/O, allowing the new client 663 // connection to proceed without delay. 664 // 665 // In addition, operations triggered by stop, and which must complete before 666 // awaitStopped returns, will attempt to lock sshServer.clientsMutex, 667 // including unregisterEstablishedClient. 
668 669 existingClient.awaitStopped() 670 } 671 672 sshServer.clientsMutex.Lock() 673 defer sshServer.clientsMutex.Unlock() 674 675 // existingClient's stop will have removed it from sshServer.clients via 676 // unregisterEstablishedClient, so sshServer.clients[client.sessionID] should 677 // be nil -- unless yet another client instance using the same sessionID has 678 // connected in the meantime while awaiting existingClient stop. In this 679 // case, it's not clear which is the most recent connection from the client, 680 // so instead of this connection terminating more peers, it aborts. 681 682 if sshServer.clients[client.sessionID] != nil { 683 // As this is expected to be rare case, it's logged at a higher severity 684 // level. 685 log.WithTrace().Warning( 686 "aborting new client with duplicate session ID") 687 return false 688 } 689 690 sshServer.clients[client.sessionID] = client 691 692 return true 693 } 694 695 func (sshServer *sshServer) unregisterEstablishedClient(client *sshClient) { 696 697 sshServer.clientsMutex.Lock() 698 699 registeredClient := sshServer.clients[client.sessionID] 700 701 // registeredClient will differ from client when client is the existingClient 702 // terminated in registerEstablishedClient. In that case, registeredClient 703 // remains connected, and the sshServer.clients entry should be retained. 704 if registeredClient == client { 705 delete(sshServer.clients, client.sessionID) 706 } 707 708 sshServer.clientsMutex.Unlock() 709 710 client.stop() 711 } 712 713 type UpstreamStats map[string]interface{} 714 type ProtocolStats map[string]map[string]interface{} 715 type RegionStats map[string]map[string]map[string]interface{} 716 717 func (sshServer *sshServer) getLoadStats() ( 718 UpstreamStats, ProtocolStats, RegionStats) { 719 720 sshServer.clientsMutex.Lock() 721 defer sshServer.clientsMutex.Unlock() 722 723 // Explicitly populate with zeros to ensure 0 counts in log messages. 
724 725 zeroClientStats := func() map[string]interface{} { 726 stats := make(map[string]interface{}) 727 stats["accepted_clients"] = int64(0) 728 stats["established_clients"] = int64(0) 729 return stats 730 } 731 732 // Due to hot reload and changes to the underlying system configuration, the 733 // set of resolver IPs may change between getLoadStats calls, so this 734 // enumeration for zeroing is a best effort. 735 resolverIPs := sshServer.support.DNSResolver.GetAll() 736 737 // Fields which are primarily concerned with upstream/egress performance. 738 zeroUpstreamStats := func() map[string]interface{} { 739 stats := make(map[string]interface{}) 740 stats["dialing_tcp_port_forwards"] = int64(0) 741 stats["tcp_port_forwards"] = int64(0) 742 stats["total_tcp_port_forwards"] = int64(0) 743 stats["udp_port_forwards"] = int64(0) 744 stats["total_udp_port_forwards"] = int64(0) 745 stats["tcp_port_forward_dialed_count"] = int64(0) 746 stats["tcp_port_forward_dialed_duration"] = int64(0) 747 stats["tcp_port_forward_failed_count"] = int64(0) 748 stats["tcp_port_forward_failed_duration"] = int64(0) 749 stats["tcp_port_forward_rejected_dialing_limit_count"] = int64(0) 750 stats["tcp_port_forward_rejected_disallowed_count"] = int64(0) 751 stats["udp_port_forward_rejected_disallowed_count"] = int64(0) 752 stats["tcp_ipv4_port_forward_dialed_count"] = int64(0) 753 stats["tcp_ipv4_port_forward_dialed_duration"] = int64(0) 754 stats["tcp_ipv4_port_forward_failed_count"] = int64(0) 755 stats["tcp_ipv4_port_forward_failed_duration"] = int64(0) 756 stats["tcp_ipv6_port_forward_dialed_count"] = int64(0) 757 stats["tcp_ipv6_port_forward_dialed_duration"] = int64(0) 758 stats["tcp_ipv6_port_forward_failed_count"] = int64(0) 759 stats["tcp_ipv6_port_forward_failed_duration"] = int64(0) 760 761 zeroDNSStats := func() map[string]int64 { 762 m := map[string]int64{"ALL": 0} 763 for _, resolverIP := range resolverIPs { 764 m[resolverIP.String()] = 0 765 } 766 return m 767 } 768 769 
stats["dns_count"] = zeroDNSStats() 770 stats["dns_duration"] = zeroDNSStats() 771 stats["dns_failed_count"] = zeroDNSStats() 772 stats["dns_failed_duration"] = zeroDNSStats() 773 return stats 774 } 775 776 zeroProtocolStats := func() map[string]map[string]interface{} { 777 stats := make(map[string]map[string]interface{}) 778 stats["ALL"] = zeroClientStats() 779 for tunnelProtocol := range sshServer.support.Config.TunnelProtocolPorts { 780 stats[tunnelProtocol] = zeroClientStats() 781 } 782 return stats 783 } 784 785 addInt64 := func(stats map[string]interface{}, name string, value int64) { 786 stats[name] = stats[name].(int64) + value 787 } 788 789 upstreamStats := zeroUpstreamStats() 790 791 // [<protocol or ALL>][<stat name>] -> count 792 protocolStats := zeroProtocolStats() 793 794 // [<region][<protocol or ALL>][<stat name>] -> count 795 regionStats := make(RegionStats) 796 797 // Note: as currently tracked/counted, each established client is also an accepted client 798 799 for tunnelProtocol, regionAcceptedClientCounts := range sshServer.acceptedClientCounts { 800 for region, acceptedClientCount := range regionAcceptedClientCounts { 801 802 if acceptedClientCount > 0 { 803 if regionStats[region] == nil { 804 regionStats[region] = zeroProtocolStats() 805 } 806 807 addInt64(protocolStats["ALL"], "accepted_clients", acceptedClientCount) 808 addInt64(protocolStats[tunnelProtocol], "accepted_clients", acceptedClientCount) 809 810 addInt64(regionStats[region]["ALL"], "accepted_clients", acceptedClientCount) 811 addInt64(regionStats[region][tunnelProtocol], "accepted_clients", acceptedClientCount) 812 } 813 } 814 } 815 816 for _, client := range sshServer.clients { 817 818 client.Lock() 819 820 tunnelProtocol := client.tunnelProtocol 821 region := client.geoIPData.Country 822 823 if regionStats[region] == nil { 824 regionStats[region] = zeroProtocolStats() 825 } 826 827 for _, stats := range []map[string]interface{}{ 828 protocolStats["ALL"], 829 
protocolStats[tunnelProtocol], 830 regionStats[region]["ALL"], 831 regionStats[region][tunnelProtocol]} { 832 833 addInt64(stats, "established_clients", 1) 834 } 835 836 // Note: 837 // - can't sum trafficState.peakConcurrentPortForwardCount to get a global peak 838 // - client.udpTrafficState.concurrentDialingPortForwardCount isn't meaningful 839 840 addInt64(upstreamStats, "dialing_tcp_port_forwards", 841 client.tcpTrafficState.concurrentDialingPortForwardCount) 842 843 addInt64(upstreamStats, "tcp_port_forwards", 844 client.tcpTrafficState.concurrentPortForwardCount) 845 846 addInt64(upstreamStats, "total_tcp_port_forwards", 847 client.tcpTrafficState.totalPortForwardCount) 848 849 addInt64(upstreamStats, "udp_port_forwards", 850 client.udpTrafficState.concurrentPortForwardCount) 851 852 addInt64(upstreamStats, "total_udp_port_forwards", 853 client.udpTrafficState.totalPortForwardCount) 854 855 addInt64(upstreamStats, "tcp_port_forward_dialed_count", 856 client.qualityMetrics.TCPPortForwardDialedCount) 857 858 addInt64(upstreamStats, "tcp_port_forward_dialed_duration", 859 int64(client.qualityMetrics.TCPPortForwardDialedDuration/time.Millisecond)) 860 861 addInt64(upstreamStats, "tcp_port_forward_failed_count", 862 client.qualityMetrics.TCPPortForwardFailedCount) 863 864 addInt64(upstreamStats, "tcp_port_forward_failed_duration", 865 int64(client.qualityMetrics.TCPPortForwardFailedDuration/time.Millisecond)) 866 867 addInt64(upstreamStats, "tcp_port_forward_rejected_dialing_limit_count", 868 client.qualityMetrics.TCPPortForwardRejectedDialingLimitCount) 869 870 addInt64(upstreamStats, "tcp_port_forward_rejected_disallowed_count", 871 client.qualityMetrics.TCPPortForwardRejectedDisallowedCount) 872 873 addInt64(upstreamStats, "udp_port_forward_rejected_disallowed_count", 874 client.qualityMetrics.UDPPortForwardRejectedDisallowedCount) 875 876 addInt64(upstreamStats, "tcp_ipv4_port_forward_dialed_count", 877 client.qualityMetrics.TCPIPv4PortForwardDialedCount) 878 
879 addInt64(upstreamStats, "tcp_ipv4_port_forward_dialed_duration", 880 int64(client.qualityMetrics.TCPIPv4PortForwardDialedDuration/time.Millisecond)) 881 882 addInt64(upstreamStats, "tcp_ipv4_port_forward_failed_count", 883 client.qualityMetrics.TCPIPv4PortForwardFailedCount) 884 885 addInt64(upstreamStats, "tcp_ipv4_port_forward_failed_duration", 886 int64(client.qualityMetrics.TCPIPv4PortForwardFailedDuration/time.Millisecond)) 887 888 addInt64(upstreamStats, "tcp_ipv6_port_forward_dialed_count", 889 client.qualityMetrics.TCPIPv6PortForwardDialedCount) 890 891 addInt64(upstreamStats, "tcp_ipv6_port_forward_dialed_duration", 892 int64(client.qualityMetrics.TCPIPv6PortForwardDialedDuration/time.Millisecond)) 893 894 addInt64(upstreamStats, "tcp_ipv6_port_forward_failed_count", 895 client.qualityMetrics.TCPIPv6PortForwardFailedCount) 896 897 addInt64(upstreamStats, "tcp_ipv6_port_forward_failed_duration", 898 int64(client.qualityMetrics.TCPIPv6PortForwardFailedDuration/time.Millisecond)) 899 900 // DNS metrics limitations: 901 // - port forwards (sshClient.handleTCPChannel) don't know or log the resolver IP. 902 // - udpgw and packet tunnel transparent DNS use a heuristic to classify success/failure, 903 // and there may be some delay before these code paths report DNS metrics. 904 905 // Every client.qualityMetrics DNS map has an "ALL" entry. 
906 907 totalDNSCount := int64(0) 908 totalDNSFailedCount := int64(0) 909 910 for key, value := range client.qualityMetrics.DNSCount { 911 upstreamStats["dns_count"].(map[string]int64)[key] += value 912 totalDNSCount += value 913 } 914 915 for key, value := range client.qualityMetrics.DNSDuration { 916 upstreamStats["dns_duration"].(map[string]int64)[key] += int64(value / time.Millisecond) 917 } 918 919 for key, value := range client.qualityMetrics.DNSFailedCount { 920 upstreamStats["dns_failed_count"].(map[string]int64)[key] += value 921 totalDNSFailedCount += value 922 } 923 924 for key, value := range client.qualityMetrics.DNSFailedDuration { 925 upstreamStats["dns_failed_duration"].(map[string]int64)[key] += int64(value / time.Millisecond) 926 } 927 928 // Update client peak failure rate metrics, to be recorded in 929 // server_tunnel. 930 // 931 // Limitations: 932 // 933 // - This is a simple data sampling that doesn't require additional 934 // timers or tracking logic. Since the rates are calculated on 935 // getLoadStats events and using accumulated counts, these peaks 936 // only represent the highest failure rate within a 937 // Config.LoadMonitorPeriodSeconds non-sliding window. There is no 938 // sample recorded for short tunnels with no overlapping 939 // getLoadStats event. 940 // 941 // - There is no minimum sample window, as a getLoadStats event may 942 // occur immediately after a client first connects. This may be 943 // compensated for by adjusting 944 // Config.PeakUpstreamFailureRateMinimumSampleSize, so as to only 945 // consider failure rates with a larger number of samples. 946 // 947 // - Non-UDP "failures" are not currently tracked. 
948 949 minimumSampleSize := int64(sshServer.support.Config.peakUpstreamFailureRateMinimumSampleSize) 950 951 sampleSize := client.qualityMetrics.TCPPortForwardDialedCount + 952 client.qualityMetrics.TCPPortForwardFailedCount 953 954 if sampleSize >= minimumSampleSize { 955 956 TCPPortForwardFailureRate := float64(client.qualityMetrics.TCPPortForwardFailedCount) / 957 float64(sampleSize) 958 959 if client.peakMetrics.TCPPortForwardFailureRate == nil { 960 961 client.peakMetrics.TCPPortForwardFailureRate = new(float64) 962 *client.peakMetrics.TCPPortForwardFailureRate = TCPPortForwardFailureRate 963 client.peakMetrics.TCPPortForwardFailureRateSampleSize = new(int64) 964 *client.peakMetrics.TCPPortForwardFailureRateSampleSize = sampleSize 965 966 } else if *client.peakMetrics.TCPPortForwardFailureRate < TCPPortForwardFailureRate { 967 968 *client.peakMetrics.TCPPortForwardFailureRate = TCPPortForwardFailureRate 969 *client.peakMetrics.TCPPortForwardFailureRateSampleSize = sampleSize 970 } 971 } 972 973 sampleSize = totalDNSCount + totalDNSFailedCount 974 975 if sampleSize >= minimumSampleSize { 976 977 DNSFailureRate := float64(totalDNSFailedCount) / float64(sampleSize) 978 979 if client.peakMetrics.DNSFailureRate == nil { 980 981 client.peakMetrics.DNSFailureRate = new(float64) 982 *client.peakMetrics.DNSFailureRate = DNSFailureRate 983 client.peakMetrics.DNSFailureRateSampleSize = new(int64) 984 *client.peakMetrics.DNSFailureRateSampleSize = sampleSize 985 986 } else if *client.peakMetrics.DNSFailureRate < DNSFailureRate { 987 988 *client.peakMetrics.DNSFailureRate = DNSFailureRate 989 *client.peakMetrics.DNSFailureRateSampleSize = sampleSize 990 } 991 } 992 993 // Reset quality metrics counters 994 995 client.qualityMetrics.reset() 996 997 client.Unlock() 998 } 999 1000 for _, client := range sshServer.clients { 1001 1002 client.Lock() 1003 1004 // Update client peak proximate (same region) concurrently connected 1005 // (other clients) client metrics, to be 
recorded in server_tunnel. 1006 // This operation requires a second loop over sshServer.clients since 1007 // established_clients is calculated in the first loop. 1008 // 1009 // Limitations: 1010 // 1011 // - This is an approximation, not a true peak, as it only samples 1012 // data every Config.LoadMonitorPeriodSeconds period. There is no 1013 // sample recorded for short tunnels with no overlapping 1014 // getLoadStats event. 1015 // 1016 // - The "-1" calculation counts all but the current client as other 1017 // clients; it can be the case that the same client has a dangling 1018 // accepted connection that has yet to time-out server side. Due to 1019 // NAT, we can't determine if the client is the same based on 1020 // network address. For established clients, 1021 // registerEstablishedClient ensures that any previous connection 1022 // is first terminated, although this is only for the same 1023 // session_id. Concurrent proximate clients may be considered an 1024 // exact number of other _network connections_, even from the same 1025 // client. 
1026 1027 region := client.geoIPData.Country 1028 stats := regionStats[region]["ALL"] 1029 1030 n := stats["accepted_clients"].(int64) - 1 1031 if n >= 0 { 1032 if client.peakMetrics.concurrentProximateAcceptedClients == nil { 1033 1034 client.peakMetrics.concurrentProximateAcceptedClients = new(int64) 1035 *client.peakMetrics.concurrentProximateAcceptedClients = n 1036 1037 } else if *client.peakMetrics.concurrentProximateAcceptedClients < n { 1038 1039 *client.peakMetrics.concurrentProximateAcceptedClients = n 1040 } 1041 } 1042 1043 n = stats["established_clients"].(int64) - 1 1044 if n >= 0 { 1045 if client.peakMetrics.concurrentProximateEstablishedClients == nil { 1046 1047 client.peakMetrics.concurrentProximateEstablishedClients = new(int64) 1048 *client.peakMetrics.concurrentProximateEstablishedClients = n 1049 1050 } else if *client.peakMetrics.concurrentProximateEstablishedClients < n { 1051 1052 *client.peakMetrics.concurrentProximateEstablishedClients = n 1053 } 1054 } 1055 1056 client.Unlock() 1057 } 1058 1059 return upstreamStats, protocolStats, regionStats 1060 } 1061 1062 func (sshServer *sshServer) getEstablishedClientCount() int { 1063 sshServer.clientsMutex.Lock() 1064 defer sshServer.clientsMutex.Unlock() 1065 establishedClients := len(sshServer.clients) 1066 return establishedClients 1067 } 1068 1069 func (sshServer *sshServer) resetAllClientTrafficRules() { 1070 1071 sshServer.clientsMutex.Lock() 1072 clients := make(map[string]*sshClient) 1073 for sessionID, client := range sshServer.clients { 1074 clients[sessionID] = client 1075 } 1076 sshServer.clientsMutex.Unlock() 1077 1078 for _, client := range clients { 1079 client.setTrafficRules() 1080 } 1081 } 1082 1083 func (sshServer *sshServer) resetAllClientOSLConfigs() { 1084 1085 // Flush cached seed state. This has the same effect 1086 // and same limitations as calling setOSLConfig for 1087 // currently connected clients -- all progress is lost. 
1088 sshServer.oslSessionCacheMutex.Lock() 1089 sshServer.oslSessionCache.Flush() 1090 sshServer.oslSessionCacheMutex.Unlock() 1091 1092 sshServer.clientsMutex.Lock() 1093 clients := make(map[string]*sshClient) 1094 for sessionID, client := range sshServer.clients { 1095 clients[sessionID] = client 1096 } 1097 sshServer.clientsMutex.Unlock() 1098 1099 for _, client := range clients { 1100 client.setOSLConfig() 1101 } 1102 } 1103 1104 func (sshServer *sshServer) setClientHandshakeState( 1105 sessionID string, 1106 state handshakeState, 1107 authorizations []string) (*handshakeStateInfo, error) { 1108 1109 sshServer.clientsMutex.Lock() 1110 client := sshServer.clients[sessionID] 1111 sshServer.clientsMutex.Unlock() 1112 1113 if client == nil { 1114 return nil, errors.TraceNew("unknown session ID") 1115 } 1116 1117 handshakeStateInfo, err := client.setHandshakeState( 1118 state, authorizations) 1119 if err != nil { 1120 return nil, errors.Trace(err) 1121 } 1122 1123 return handshakeStateInfo, nil 1124 } 1125 1126 func (sshServer *sshServer) getClientHandshaked( 1127 sessionID string) (bool, bool, error) { 1128 1129 sshServer.clientsMutex.Lock() 1130 client := sshServer.clients[sessionID] 1131 sshServer.clientsMutex.Unlock() 1132 1133 if client == nil { 1134 return false, false, errors.TraceNew("unknown session ID") 1135 } 1136 1137 completed, exhausted := client.getHandshaked() 1138 1139 return completed, exhausted, nil 1140 } 1141 1142 func (sshServer *sshServer) getClientDisableDiscovery( 1143 sessionID string) (bool, error) { 1144 1145 sshServer.clientsMutex.Lock() 1146 client := sshServer.clients[sessionID] 1147 sshServer.clientsMutex.Unlock() 1148 1149 if client == nil { 1150 return false, errors.TraceNew("unknown session ID") 1151 } 1152 1153 return client.getDisableDiscovery(), nil 1154 } 1155 1156 func (sshServer *sshServer) updateClientAPIParameters( 1157 sessionID string, 1158 apiParams common.APIParameters) error { 1159 1160 sshServer.clientsMutex.Lock() 
1161 client := sshServer.clients[sessionID] 1162 sshServer.clientsMutex.Unlock() 1163 1164 if client == nil { 1165 return errors.TraceNew("unknown session ID") 1166 } 1167 1168 client.updateAPIParameters(apiParams) 1169 1170 return nil 1171 } 1172 1173 func (sshServer *sshServer) revokeClientAuthorizations(sessionID string) { 1174 sshServer.clientsMutex.Lock() 1175 client := sshServer.clients[sessionID] 1176 sshServer.clientsMutex.Unlock() 1177 1178 if client == nil { 1179 return 1180 } 1181 1182 // sshClient.handshakeState.authorizedAccessTypes is not cleared. Clearing 1183 // authorizedAccessTypes may cause sshClient.logTunnel to fail to log 1184 // access types. As the revocation may be due to legitimate use of an 1185 // authorization in multiple sessions by a single client, useful metrics 1186 // would be lost. 1187 1188 client.Lock() 1189 client.handshakeState.authorizationsRevoked = true 1190 client.Unlock() 1191 1192 // Select and apply new traffic rules, as filtered by the client's new 1193 // authorization state. 
1194 1195 client.setTrafficRules() 1196 } 1197 1198 func (sshServer *sshServer) acceptClientDomainBytes( 1199 sessionID string) (bool, error) { 1200 1201 sshServer.clientsMutex.Lock() 1202 client := sshServer.clients[sessionID] 1203 sshServer.clientsMutex.Unlock() 1204 1205 if client == nil { 1206 return false, errors.TraceNew("unknown session ID") 1207 } 1208 1209 return client.acceptDomainBytes(), nil 1210 } 1211 1212 func (sshServer *sshServer) stopClients() { 1213 1214 sshServer.clientsMutex.Lock() 1215 sshServer.stoppingClients = true 1216 clients := sshServer.clients 1217 sshServer.clients = make(map[string]*sshClient) 1218 sshServer.clientsMutex.Unlock() 1219 1220 for _, client := range clients { 1221 client.stop() 1222 } 1223 } 1224 1225 func (sshServer *sshServer) handleClient( 1226 sshListener *sshListener, tunnelProtocol string, clientConn net.Conn) { 1227 1228 // Calling clientConn.RemoteAddr at this point, before any Read calls, 1229 // satisfies the constraint documented in tapdance.Listen. 1230 1231 clientAddr := clientConn.RemoteAddr() 1232 1233 // Check if there were irregularities during the network connection 1234 // establishment. When present, log and then behave as Obfuscated SSH does 1235 // when the client fails to provide a valid seed message. 1236 // 1237 // One concrete irregular case is failure to send a PROXY protocol header for 1238 // TAPDANCE-OSSH. 
1239 1240 if indicator, ok := clientConn.(common.IrregularIndicator); ok { 1241 1242 tunnelErr := indicator.IrregularTunnelError() 1243 1244 if tunnelErr != nil { 1245 1246 logIrregularTunnel( 1247 sshServer.support, 1248 sshListener.tunnelProtocol, 1249 sshListener.port, 1250 common.IPAddressFromAddr(clientAddr), 1251 errors.Trace(tunnelErr), 1252 nil) 1253 1254 var afterFunc *time.Timer 1255 if sshServer.support.Config.sshHandshakeTimeout > 0 { 1256 afterFunc = time.AfterFunc(sshServer.support.Config.sshHandshakeTimeout, func() { 1257 clientConn.Close() 1258 }) 1259 } 1260 io.Copy(ioutil.Discard, clientConn) 1261 clientConn.Close() 1262 afterFunc.Stop() 1263 1264 return 1265 } 1266 } 1267 1268 // Get any packet manipulation values from GetAppliedSpecName as soon as 1269 // possible due to the expiring TTL. 1270 1271 serverPacketManipulation := "" 1272 replayedServerPacketManipulation := false 1273 1274 if sshServer.support.Config.RunPacketManipulator && 1275 protocol.TunnelProtocolMayUseServerPacketManipulation(tunnelProtocol) { 1276 1277 // A meekConn has synthetic address values, including the original client 1278 // address in cases where the client uses an upstream proxy to connect to 1279 // Psiphon. For meekConn, and any other conn implementing 1280 // UnderlyingTCPAddrSource, get the underlying TCP connection addresses. 1281 // 1282 // Limitation: a meek tunnel may consist of several TCP connections. The 1283 // server_packet_manipulation metric will reflect the packet manipulation 1284 // applied to the _first_ TCP connection only. 
1285 1286 var localAddr, remoteAddr *net.TCPAddr 1287 var ok bool 1288 underlying, ok := clientConn.(common.UnderlyingTCPAddrSource) 1289 if ok { 1290 localAddr, remoteAddr, ok = underlying.GetUnderlyingTCPAddrs() 1291 } else { 1292 localAddr, ok = clientConn.LocalAddr().(*net.TCPAddr) 1293 if ok { 1294 remoteAddr, ok = clientConn.RemoteAddr().(*net.TCPAddr) 1295 } 1296 } 1297 1298 if ok { 1299 specName, extraData, err := sshServer.support.PacketManipulator. 1300 GetAppliedSpecName(localAddr, remoteAddr) 1301 if err == nil { 1302 serverPacketManipulation = specName 1303 replayedServerPacketManipulation, _ = extraData.(bool) 1304 } 1305 } 1306 } 1307 1308 geoIPData := sshServer.support.GeoIPService.Lookup( 1309 common.IPAddressFromAddr(clientAddr)) 1310 1311 sshServer.registerAcceptedClient(tunnelProtocol, geoIPData.Country) 1312 defer sshServer.unregisterAcceptedClient(tunnelProtocol, geoIPData.Country) 1313 1314 // When configured, enforce a cap on the number of concurrent SSH 1315 // handshakes. This limits load spikes on busy servers when many clients 1316 // attempt to connect at once. Wait a short time, SSH_BEGIN_HANDSHAKE_TIMEOUT, 1317 // to acquire; waiting will avoid immediately creating more load on another 1318 // server in the network when the client tries a new candidate. Disconnect the 1319 // client when that wait time is exceeded. 1320 // 1321 // This mechanism limits memory allocations and CPU usage associated with the 1322 // SSH handshake. At this point, new direct TCP connections or new meek 1323 // connections, with associated resource usage, are already established. Those 1324 // connections are expected to be rate or load limited using other mechanisms. 1325 // 1326 // TODO: 1327 // 1328 // - deduct time spent acquiring the semaphore from SSH_HANDSHAKE_TIMEOUT in 1329 // sshClient.run, since the client is also applying an SSH handshake timeout 1330 // and won't exclude time spent waiting. 
1331 // - each call to sshServer.handleClient (in sshServer.runListener) is invoked 1332 // in its own goroutine, but shutdown doesn't synchronously await these 1333 // goroutnes. Once this is synchronizes, the following context.WithTimeout 1334 // should use an sshServer parent context to ensure blocking acquires 1335 // interrupt immediately upon shutdown. 1336 1337 var onSSHHandshakeFinished func() 1338 if sshServer.support.Config.MaxConcurrentSSHHandshakes > 0 { 1339 1340 ctx, cancelFunc := context.WithTimeout( 1341 context.Background(), 1342 sshServer.support.Config.sshBeginHandshakeTimeout) 1343 defer cancelFunc() 1344 1345 err := sshServer.concurrentSSHHandshakes.Acquire(ctx, 1) 1346 if err != nil { 1347 clientConn.Close() 1348 // This is a debug log as the only possible error is context timeout. 1349 log.WithTraceFields(LogFields{"error": err}).Debug( 1350 "acquire SSH handshake semaphore failed") 1351 return 1352 } 1353 1354 onSSHHandshakeFinished = func() { 1355 sshServer.concurrentSSHHandshakes.Release(1) 1356 } 1357 } 1358 1359 sshClient := newSshClient( 1360 sshServer, 1361 sshListener, 1362 tunnelProtocol, 1363 serverPacketManipulation, 1364 replayedServerPacketManipulation, 1365 clientAddr, 1366 geoIPData) 1367 1368 // sshClient.run _must_ call onSSHHandshakeFinished to release the semaphore: 1369 // in any error case; or, as soon as the SSH handshake phase has successfully 1370 // completed. 1371 1372 sshClient.run(clientConn, onSSHHandshakeFinished) 1373 } 1374 1375 func (sshServer *sshServer) monitorPortForwardDialError(err error) { 1376 1377 // "err" is the error returned from a failed TCP or UDP port 1378 // forward dial. Certain system error codes indicate low resource 1379 // conditions: insufficient file descriptors, ephemeral ports, or 1380 // memory. For these cases, log an alert. 
1381 1382 // TODO: also temporarily suspend new clients 1383 1384 // Note: don't log net.OpError.Error() as the full error string 1385 // may contain client destination addresses. 1386 1387 opErr, ok := err.(*net.OpError) 1388 if ok { 1389 if opErr.Err == syscall.EADDRNOTAVAIL || 1390 opErr.Err == syscall.EAGAIN || 1391 opErr.Err == syscall.ENOMEM || 1392 opErr.Err == syscall.EMFILE || 1393 opErr.Err == syscall.ENFILE { 1394 1395 log.WithTraceFields( 1396 LogFields{"error": opErr.Err}).Error( 1397 "port forward dial failed due to unavailable resource") 1398 } 1399 } 1400 } 1401 1402 type sshClient struct { 1403 sync.Mutex 1404 sshServer *sshServer 1405 sshListener *sshListener 1406 tunnelProtocol string 1407 sshConn ssh.Conn 1408 throttledConn *common.ThrottledConn 1409 serverPacketManipulation string 1410 replayedServerPacketManipulation bool 1411 clientAddr net.Addr 1412 geoIPData GeoIPData 1413 sessionID string 1414 isFirstTunnelInSession bool 1415 supportsServerRequests bool 1416 handshakeState handshakeState 1417 udpgwChannelHandler *udpgwPortForwardMultiplexer 1418 totalUdpgwChannelCount int 1419 packetTunnelChannel ssh.Channel 1420 totalPacketTunnelChannelCount int 1421 trafficRules TrafficRules 1422 tcpTrafficState trafficState 1423 udpTrafficState trafficState 1424 qualityMetrics *qualityMetrics 1425 tcpPortForwardLRU *common.LRUConns 1426 oslClientSeedState *osl.ClientSeedState 1427 signalIssueSLOKs chan struct{} 1428 runCtx context.Context 1429 stopRunning context.CancelFunc 1430 stopped chan struct{} 1431 tcpPortForwardDialingAvailableSignal context.CancelFunc 1432 releaseAuthorizations func() 1433 stopTimer *time.Timer 1434 preHandshakeRandomStreamMetrics randomStreamMetrics 1435 postHandshakeRandomStreamMetrics randomStreamMetrics 1436 sendAlertRequests chan protocol.AlertRequest 1437 sentAlertRequests map[string]bool 1438 peakMetrics peakMetrics 1439 destinationBytesMetricsASN string 1440 tcpDestinationBytesMetrics destinationBytesMetrics 1441 
udpDestinationBytesMetrics destinationBytesMetrics 1442 } 1443 1444 type trafficState struct { 1445 bytesUp int64 1446 bytesDown int64 1447 concurrentDialingPortForwardCount int64 1448 peakConcurrentDialingPortForwardCount int64 1449 concurrentPortForwardCount int64 1450 peakConcurrentPortForwardCount int64 1451 totalPortForwardCount int64 1452 availablePortForwardCond *sync.Cond 1453 } 1454 1455 type randomStreamMetrics struct { 1456 count int64 1457 upstreamBytes int64 1458 receivedUpstreamBytes int64 1459 downstreamBytes int64 1460 sentDownstreamBytes int64 1461 } 1462 1463 type peakMetrics struct { 1464 concurrentProximateAcceptedClients *int64 1465 concurrentProximateEstablishedClients *int64 1466 TCPPortForwardFailureRate *float64 1467 TCPPortForwardFailureRateSampleSize *int64 1468 DNSFailureRate *float64 1469 DNSFailureRateSampleSize *int64 1470 } 1471 1472 // qualityMetrics records upstream TCP dial attempts and 1473 // elapsed time. Elapsed time includes the full TCP handshake 1474 // and, in aggregate, is a measure of the quality of the 1475 // upstream link. These stats are recorded by each sshClient 1476 // and then reported and reset in sshServer.getLoadStats(). 
1477 type qualityMetrics struct { 1478 TCPPortForwardDialedCount int64 1479 TCPPortForwardDialedDuration time.Duration 1480 TCPPortForwardFailedCount int64 1481 TCPPortForwardFailedDuration time.Duration 1482 TCPPortForwardRejectedDialingLimitCount int64 1483 TCPPortForwardRejectedDisallowedCount int64 1484 UDPPortForwardRejectedDisallowedCount int64 1485 TCPIPv4PortForwardDialedCount int64 1486 TCPIPv4PortForwardDialedDuration time.Duration 1487 TCPIPv4PortForwardFailedCount int64 1488 TCPIPv4PortForwardFailedDuration time.Duration 1489 TCPIPv6PortForwardDialedCount int64 1490 TCPIPv6PortForwardDialedDuration time.Duration 1491 TCPIPv6PortForwardFailedCount int64 1492 TCPIPv6PortForwardFailedDuration time.Duration 1493 DNSCount map[string]int64 1494 DNSDuration map[string]time.Duration 1495 DNSFailedCount map[string]int64 1496 DNSFailedDuration map[string]time.Duration 1497 } 1498 1499 func newQualityMetrics() *qualityMetrics { 1500 return &qualityMetrics{ 1501 DNSCount: make(map[string]int64), 1502 DNSDuration: make(map[string]time.Duration), 1503 DNSFailedCount: make(map[string]int64), 1504 DNSFailedDuration: make(map[string]time.Duration), 1505 } 1506 } 1507 1508 func (q *qualityMetrics) reset() { 1509 1510 q.TCPPortForwardDialedCount = 0 1511 q.TCPPortForwardDialedDuration = 0 1512 q.TCPPortForwardFailedCount = 0 1513 q.TCPPortForwardFailedDuration = 0 1514 q.TCPPortForwardRejectedDialingLimitCount = 0 1515 q.TCPPortForwardRejectedDisallowedCount = 0 1516 1517 q.UDPPortForwardRejectedDisallowedCount = 0 1518 1519 q.TCPIPv4PortForwardDialedCount = 0 1520 q.TCPIPv4PortForwardDialedDuration = 0 1521 q.TCPIPv4PortForwardFailedCount = 0 1522 q.TCPIPv4PortForwardFailedDuration = 0 1523 1524 q.TCPIPv6PortForwardDialedCount = 0 1525 q.TCPIPv6PortForwardDialedDuration = 0 1526 q.TCPIPv6PortForwardFailedCount = 0 1527 q.TCPIPv6PortForwardFailedDuration = 0 1528 1529 // Retain existing maps to avoid memory churn. 
The Go compiler optimizes map 1530 // clearing operations of the following form. 1531 1532 for k := range q.DNSCount { 1533 delete(q.DNSCount, k) 1534 } 1535 for k := range q.DNSDuration { 1536 delete(q.DNSDuration, k) 1537 } 1538 for k := range q.DNSFailedCount { 1539 delete(q.DNSFailedCount, k) 1540 } 1541 for k := range q.DNSFailedDuration { 1542 delete(q.DNSFailedDuration, k) 1543 } 1544 } 1545 1546 type handshakeStateInfo struct { 1547 activeAuthorizationIDs []string 1548 authorizedAccessTypes []string 1549 upstreamBytesPerSecond int64 1550 downstreamBytesPerSecond int64 1551 } 1552 1553 type handshakeState struct { 1554 completed bool 1555 apiProtocol string 1556 apiParams common.APIParameters 1557 activeAuthorizationIDs []string 1558 authorizedAccessTypes []string 1559 authorizationsRevoked bool 1560 domainBytesChecksum []byte 1561 establishedTunnelsCount int 1562 splitTunnelLookup *splitTunnelLookup 1563 } 1564 1565 type destinationBytesMetrics struct { 1566 bytesUp int64 1567 bytesDown int64 1568 } 1569 1570 func (d *destinationBytesMetrics) UpdateProgress( 1571 downstreamBytes, upstreamBytes, _ int64) { 1572 1573 // Concurrency: UpdateProgress may be called without holding the sshClient 1574 // lock; all accesses to bytesUp/bytesDown must use atomic operations. 
1575 1576 atomic.AddInt64(&d.bytesUp, upstreamBytes) 1577 atomic.AddInt64(&d.bytesDown, downstreamBytes) 1578 } 1579 1580 func (d *destinationBytesMetrics) getBytesUp() int64 { 1581 return atomic.LoadInt64(&d.bytesUp) 1582 } 1583 1584 func (d *destinationBytesMetrics) getBytesDown() int64 { 1585 return atomic.LoadInt64(&d.bytesDown) 1586 } 1587 1588 type splitTunnelLookup struct { 1589 regions []string 1590 regionsLookup map[string]bool 1591 } 1592 1593 func newSplitTunnelLookup( 1594 ownRegion string, 1595 otherRegions []string) (*splitTunnelLookup, error) { 1596 1597 length := len(otherRegions) 1598 if ownRegion != "" { 1599 length += 1 1600 } 1601 1602 // This length check is a sanity check and prevents clients shipping 1603 // excessively long lists which could impact performance. 1604 if length > 250 { 1605 return nil, errors.Tracef("too many regions: %d", length) 1606 } 1607 1608 // Create map lookups for lists where the number of values to compare 1609 // against exceeds a threshold where benchmarks show maps are faster than 1610 // looping through a slice. Otherwise use a slice for lookups. In both 1611 // cases, the input slice is no longer referenced. 1612 1613 if length >= stringLookupThreshold { 1614 regionsLookup := make(map[string]bool) 1615 if ownRegion != "" { 1616 regionsLookup[ownRegion] = true 1617 } 1618 for _, region := range otherRegions { 1619 regionsLookup[region] = true 1620 } 1621 return &splitTunnelLookup{ 1622 regionsLookup: regionsLookup, 1623 }, nil 1624 } else { 1625 regions := []string{} 1626 if ownRegion != "" && !common.Contains(otherRegions, ownRegion) { 1627 regions = append(regions, ownRegion) 1628 } 1629 // TODO: check for other duplicate regions? 1630 regions = append(regions, otherRegions...) 
1631 return &splitTunnelLookup{ 1632 regions: regions, 1633 }, nil 1634 } 1635 } 1636 1637 func (lookup *splitTunnelLookup) lookup(region string) bool { 1638 if lookup.regionsLookup != nil { 1639 return lookup.regionsLookup[region] 1640 } else { 1641 return common.Contains(lookup.regions, region) 1642 } 1643 } 1644 1645 func newSshClient( 1646 sshServer *sshServer, 1647 sshListener *sshListener, 1648 tunnelProtocol string, 1649 serverPacketManipulation string, 1650 replayedServerPacketManipulation bool, 1651 clientAddr net.Addr, 1652 geoIPData GeoIPData) *sshClient { 1653 1654 runCtx, stopRunning := context.WithCancel(context.Background()) 1655 1656 // isFirstTunnelInSession is defaulted to true so that the pre-handshake 1657 // traffic rules won't apply UnthrottleFirstTunnelOnly and negate any 1658 // unthrottled bytes during the initial protocol negotiation. 1659 1660 client := &sshClient{ 1661 sshServer: sshServer, 1662 sshListener: sshListener, 1663 tunnelProtocol: tunnelProtocol, 1664 serverPacketManipulation: serverPacketManipulation, 1665 replayedServerPacketManipulation: replayedServerPacketManipulation, 1666 clientAddr: clientAddr, 1667 geoIPData: geoIPData, 1668 isFirstTunnelInSession: true, 1669 qualityMetrics: newQualityMetrics(), 1670 tcpPortForwardLRU: common.NewLRUConns(), 1671 signalIssueSLOKs: make(chan struct{}, 1), 1672 runCtx: runCtx, 1673 stopRunning: stopRunning, 1674 stopped: make(chan struct{}), 1675 sendAlertRequests: make(chan protocol.AlertRequest, ALERT_REQUEST_QUEUE_BUFFER_SIZE), 1676 sentAlertRequests: make(map[string]bool), 1677 } 1678 1679 client.tcpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex)) 1680 client.udpTrafficState.availablePortForwardCond = sync.NewCond(new(sync.Mutex)) 1681 1682 return client 1683 } 1684 1685 func (sshClient *sshClient) run( 1686 baseConn net.Conn, onSSHHandshakeFinished func()) { 1687 1688 // When run returns, the client has fully stopped, with all SSH state torn 1689 // down and no 
port forwards or API requests in progress. 1690 defer close(sshClient.stopped) 1691 1692 // onSSHHandshakeFinished must be called even if the SSH handshake is aborted. 1693 defer func() { 1694 if onSSHHandshakeFinished != nil { 1695 onSSHHandshakeFinished() 1696 } 1697 }() 1698 1699 // Set initial traffic rules, pre-handshake, based on currently known info. 1700 sshClient.setTrafficRules() 1701 1702 conn := baseConn 1703 1704 // Wrap the base client connection with an ActivityMonitoredConn which will 1705 // terminate the connection if no data is received before the deadline. This 1706 // timeout is in effect for the entire duration of the SSH connection. Clients 1707 // must actively use the connection or send SSH keep alive requests to keep 1708 // the connection active. Writes are not considered reliable activity indicators 1709 // due to buffering. 1710 1711 activityConn, err := common.NewActivityMonitoredConn( 1712 conn, 1713 SSH_CONNECTION_READ_DEADLINE, 1714 false, 1715 nil) 1716 if err != nil { 1717 conn.Close() 1718 if !isExpectedTunnelIOError(err) { 1719 log.WithTraceFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed") 1720 } 1721 return 1722 } 1723 conn = activityConn 1724 1725 // Further wrap the connection with burst monitoring, when enabled. 1726 // 1727 // Limitation: burst parameters are fixed for the duration of the tunnel 1728 // and do not change after a tactics hot reload. 
1729 1730 var burstConn *common.BurstMonitoredConn 1731 1732 p, err := sshClient.sshServer.support.ServerTacticsParametersCache.Get(sshClient.geoIPData) 1733 if err != nil { 1734 log.WithTraceFields(LogFields{"error": errors.Trace(err)}).Warning( 1735 "ServerTacticsParametersCache.Get failed") 1736 return 1737 } 1738 1739 if !p.IsNil() { 1740 upstreamTargetBytes := int64(p.Int(parameters.ServerBurstUpstreamTargetBytes)) 1741 upstreamDeadline := p.Duration(parameters.ServerBurstUpstreamDeadline) 1742 downstreamTargetBytes := int64(p.Int(parameters.ServerBurstDownstreamTargetBytes)) 1743 downstreamDeadline := p.Duration(parameters.ServerBurstDownstreamDeadline) 1744 1745 if (upstreamDeadline != 0 && upstreamTargetBytes != 0) || 1746 (downstreamDeadline != 0 && downstreamTargetBytes != 0) { 1747 1748 burstConn = common.NewBurstMonitoredConn( 1749 conn, 1750 true, 1751 upstreamTargetBytes, upstreamDeadline, 1752 downstreamTargetBytes, downstreamDeadline) 1753 conn = burstConn 1754 } 1755 } 1756 1757 // Allow garbage collection. 1758 p.Close() 1759 1760 // Further wrap the connection in a rate limiting ThrottledConn. 1761 1762 throttledConn := common.NewThrottledConn(conn, sshClient.rateLimits()) 1763 conn = throttledConn 1764 1765 // Replay of server-side parameters is set or extended after a new tunnel 1766 // meets duration and bytes transferred targets. Set a timer now that expires 1767 // shortly after the target duration. When the timer fires, check the time of 1768 // last byte read (a read indicating a live connection with the client), 1769 // along with total bytes transferred and set or extend replay if the targets 1770 // are met. 1771 // 1772 // Both target checks are conservative: the tunnel may be healthy, but a byte 1773 // may not have been read in the last second when the timer fires. Or bytes 1774 // may be transferring, but not at the target level. 
Only clients that meet 1775 // the strict targets at the single check time will trigger replay; however, 1776 // this replay will impact all clients with similar GeoIP data. 1777 // 1778 // A deferred function cancels the timer and also increments the replay 1779 // failure counter, which will ultimately clear replay parameters, when the 1780 // tunnel fails before the API handshake is completed (this includes any 1781 // liveness test). 1782 // 1783 // A tunnel which fails to meet the targets but successfully completes any 1784 // liveness test and the API handshake is ignored in terms of replay scoring. 1785 1786 isReplayCandidate, replayWaitDuration, replayTargetDuration := 1787 sshClient.sshServer.support.ReplayCache.GetReplayTargetDuration(sshClient.geoIPData) 1788 1789 if isReplayCandidate { 1790 1791 getFragmentorSeed := func() *prng.Seed { 1792 fragmentor, ok := baseConn.(common.FragmentorReplayAccessor) 1793 if ok { 1794 fragmentorSeed, _ := fragmentor.GetReplay() 1795 return fragmentorSeed 1796 } 1797 return nil 1798 } 1799 1800 setReplayAfterFunc := time.AfterFunc( 1801 replayWaitDuration, 1802 func() { 1803 if activityConn.GetActiveDuration() >= replayTargetDuration { 1804 1805 sshClient.Lock() 1806 bytesUp := sshClient.tcpTrafficState.bytesUp + sshClient.udpTrafficState.bytesUp 1807 bytesDown := sshClient.tcpTrafficState.bytesDown + sshClient.udpTrafficState.bytesDown 1808 sshClient.Unlock() 1809 1810 sshClient.sshServer.support.ReplayCache.SetReplayParameters( 1811 sshClient.tunnelProtocol, 1812 sshClient.geoIPData, 1813 sshClient.serverPacketManipulation, 1814 getFragmentorSeed(), 1815 bytesUp, 1816 bytesDown) 1817 } 1818 }) 1819 1820 defer func() { 1821 setReplayAfterFunc.Stop() 1822 completed, _ := sshClient.getHandshaked() 1823 if !completed { 1824 1825 // Count a replay failure case when a tunnel used replay parameters 1826 // (excluding OSSH fragmentation, which doesn't use the ReplayCache) and 1827 // failed to complete the API handshake. 
1828 1829 replayedFragmentation := false 1830 if sshClient.tunnelProtocol != protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH { 1831 fragmentor, ok := baseConn.(common.FragmentorReplayAccessor) 1832 if ok { 1833 _, replayedFragmentation = fragmentor.GetReplay() 1834 } 1835 } 1836 usedReplay := replayedFragmentation || sshClient.replayedServerPacketManipulation 1837 1838 if usedReplay { 1839 sshClient.sshServer.support.ReplayCache.FailedReplayParameters( 1840 sshClient.tunnelProtocol, 1841 sshClient.geoIPData, 1842 sshClient.serverPacketManipulation, 1843 getFragmentorSeed()) 1844 } 1845 } 1846 }() 1847 } 1848 1849 // Run the initial [obfuscated] SSH handshake in a goroutine so we can both 1850 // respect shutdownBroadcast and implement a specific handshake timeout. 1851 // The timeout is to reclaim network resources in case the handshake takes 1852 // too long. 1853 1854 type sshNewServerConnResult struct { 1855 obfuscatedSSHConn *obfuscator.ObfuscatedSSHConn 1856 sshConn *ssh.ServerConn 1857 channels <-chan ssh.NewChannel 1858 requests <-chan *ssh.Request 1859 err error 1860 } 1861 1862 resultChannel := make(chan *sshNewServerConnResult, 2) 1863 1864 var sshHandshakeAfterFunc *time.Timer 1865 if sshClient.sshServer.support.Config.sshHandshakeTimeout > 0 { 1866 sshHandshakeAfterFunc = time.AfterFunc(sshClient.sshServer.support.Config.sshHandshakeTimeout, func() { 1867 resultChannel <- &sshNewServerConnResult{err: std_errors.New("ssh handshake timeout")} 1868 }) 1869 } 1870 1871 go func(baseConn, conn net.Conn) { 1872 sshServerConfig := &ssh.ServerConfig{ 1873 PasswordCallback: sshClient.passwordCallback, 1874 AuthLogCallback: sshClient.authLogCallback, 1875 ServerVersion: sshClient.sshServer.support.Config.SSHServerVersion, 1876 } 1877 sshServerConfig.AddHostKey(sshClient.sshServer.sshHostKey) 1878 1879 var err error 1880 1881 if protocol.TunnelProtocolUsesObfuscatedSSH(sshClient.tunnelProtocol) { 1882 // With Encrypt-then-MAC hash algorithms, packet length is 1883 // 
transmitted in plaintext, which aids in traffic analysis; 1884 // clients may still send Encrypt-then-MAC algorithms in their 1885 // KEX_INIT message, but do not select these algorithms. 1886 // 1887 // The exception is TUNNEL_PROTOCOL_SSH, which is intended to appear 1888 // like SSH on the wire. 1889 sshServerConfig.NoEncryptThenMACHash = true 1890 1891 } else { 1892 // For TUNNEL_PROTOCOL_SSH only, randomize KEX. 1893 if sshClient.sshServer.support.Config.ObfuscatedSSHKey != "" { 1894 sshServerConfig.KEXPRNGSeed, err = protocol.DeriveSSHServerKEXPRNGSeed( 1895 sshClient.sshServer.support.Config.ObfuscatedSSHKey) 1896 if err != nil { 1897 err = errors.Trace(err) 1898 } 1899 } 1900 } 1901 1902 result := &sshNewServerConnResult{} 1903 1904 // Wrap the connection in an SSH deobfuscator when required. 1905 1906 if err == nil && protocol.TunnelProtocolUsesObfuscatedSSH(sshClient.tunnelProtocol) { 1907 1908 // Note: NewServerObfuscatedSSHConn blocks on network I/O 1909 // TODO: ensure this won't block shutdown 1910 result.obfuscatedSSHConn, err = obfuscator.NewServerObfuscatedSSHConn( 1911 conn, 1912 sshClient.sshServer.support.Config.ObfuscatedSSHKey, 1913 sshClient.sshServer.obfuscatorSeedHistory, 1914 func(clientIP string, err error, logFields common.LogFields) { 1915 logIrregularTunnel( 1916 sshClient.sshServer.support, 1917 sshClient.sshListener.tunnelProtocol, 1918 sshClient.sshListener.port, 1919 clientIP, 1920 errors.Trace(err), 1921 LogFields(logFields)) 1922 }) 1923 1924 if err != nil { 1925 err = errors.Trace(err) 1926 } else { 1927 conn = result.obfuscatedSSHConn 1928 } 1929 1930 // Seed the fragmentor, when present, with seed derived from initial 1931 // obfuscator message. See tactics.Listener.Accept. This must preceed 1932 // ssh.NewServerConn to ensure fragmentor is seeded before downstream bytes 1933 // are written. 
1934 if err == nil && sshClient.tunnelProtocol == protocol.TUNNEL_PROTOCOL_OBFUSCATED_SSH { 1935 fragmentor, ok := baseConn.(common.FragmentorReplayAccessor) 1936 if ok { 1937 var fragmentorPRNG *prng.PRNG 1938 fragmentorPRNG, err = result.obfuscatedSSHConn.GetDerivedPRNG("server-side-fragmentor") 1939 if err != nil { 1940 err = errors.Trace(err) 1941 } else { 1942 fragmentor.SetReplay(fragmentorPRNG) 1943 } 1944 } 1945 } 1946 } 1947 1948 if err == nil { 1949 result.sshConn, result.channels, result.requests, err = 1950 ssh.NewServerConn(conn, sshServerConfig) 1951 if err != nil { 1952 err = errors.Trace(err) 1953 } 1954 } 1955 1956 result.err = err 1957 1958 resultChannel <- result 1959 1960 }(baseConn, conn) 1961 1962 var result *sshNewServerConnResult 1963 select { 1964 case result = <-resultChannel: 1965 case <-sshClient.sshServer.shutdownBroadcast: 1966 // Close() will interrupt an ongoing handshake 1967 // TODO: wait for SSH handshake goroutines to exit before returning? 1968 conn.Close() 1969 return 1970 } 1971 1972 if sshHandshakeAfterFunc != nil { 1973 sshHandshakeAfterFunc.Stop() 1974 } 1975 1976 if result.err != nil { 1977 conn.Close() 1978 // This is a Debug log due to noise. The handshake often fails due to I/O 1979 // errors as clients frequently interrupt connections in progress when 1980 // client-side load balancing completes a connection to a different server. 1981 log.WithTraceFields(LogFields{"error": result.err}).Debug("SSH handshake failed") 1982 return 1983 } 1984 1985 // The SSH handshake has finished successfully; notify now to allow other 1986 // blocked SSH handshakes to proceed. 
1987 if onSSHHandshakeFinished != nil { 1988 onSSHHandshakeFinished() 1989 } 1990 onSSHHandshakeFinished = nil 1991 1992 sshClient.Lock() 1993 sshClient.sshConn = result.sshConn 1994 sshClient.throttledConn = throttledConn 1995 sshClient.Unlock() 1996 1997 if !sshClient.sshServer.registerEstablishedClient(sshClient) { 1998 conn.Close() 1999 log.WithTrace().Warning("register failed") 2000 return 2001 } 2002 2003 sshClient.runTunnel(result.channels, result.requests) 2004 2005 // Note: sshServer.unregisterEstablishedClient calls sshClient.stop(), 2006 // which also closes underlying transport Conn. 2007 2008 sshClient.sshServer.unregisterEstablishedClient(sshClient) 2009 2010 // Log tunnel metrics. 2011 2012 var additionalMetrics []LogFields 2013 2014 // Add activity and burst metrics. 2015 // 2016 // The reported duration is based on last confirmed data transfer, which for 2017 // sshClient.activityConn.GetActiveDuration() is time of last read byte and 2018 // not conn close time. This is important for protocols such as meek. For 2019 // meek, the connection remains open until the HTTP session expires, which 2020 // may be some time after the tunnel has closed. (The meek protocol has no 2021 // allowance for signalling payload EOF, and even if it did the client may 2022 // not have the opportunity to send a final request with an EOF flag set.) 2023 2024 activityMetrics := make(LogFields) 2025 activityMetrics["start_time"] = activityConn.GetStartTime() 2026 activityMetrics["duration"] = int64(activityConn.GetActiveDuration() / time.Millisecond) 2027 additionalMetrics = append(additionalMetrics, activityMetrics) 2028 2029 if burstConn != nil { 2030 // Any outstanding burst should be recorded by burstConn.Close which should 2031 // be called by unregisterEstablishedClient. 2032 additionalMetrics = append( 2033 additionalMetrics, LogFields(burstConn.GetMetrics(activityConn.GetStartTime()))) 2034 } 2035 2036 // Some conns report additional metrics. 
Meek conns report resiliency 2037 // metrics and fragmentor.Conns report fragmentor configs. 2038 2039 if metricsSource, ok := baseConn.(common.MetricsSource); ok { 2040 additionalMetrics = append( 2041 additionalMetrics, LogFields(metricsSource.GetMetrics())) 2042 } 2043 if result.obfuscatedSSHConn != nil { 2044 additionalMetrics = append( 2045 additionalMetrics, LogFields(result.obfuscatedSSHConn.GetMetrics())) 2046 } 2047 2048 // Add server-replay metrics. 2049 2050 replayMetrics := make(LogFields) 2051 replayedFragmentation := false 2052 fragmentor, ok := baseConn.(common.FragmentorReplayAccessor) 2053 if ok { 2054 _, replayedFragmentation = fragmentor.GetReplay() 2055 } 2056 replayMetrics["server_replay_fragmentation"] = replayedFragmentation 2057 replayMetrics["server_replay_packet_manipulation"] = sshClient.replayedServerPacketManipulation 2058 additionalMetrics = append(additionalMetrics, replayMetrics) 2059 2060 // Limitation: there's only one log per tunnel with bytes transferred 2061 // metrics, so the byte count can't be attributed to a certain day for 2062 // tunnels that remain connected for well over 24h. In practise, most 2063 // tunnels are short-lived, especially on mobile devices. 2064 2065 sshClient.logTunnel(additionalMetrics) 2066 2067 // Transfer OSL seed state -- the OSL progress -- from the closing 2068 // client to the session cache so the client can resume its progress 2069 // if it reconnects to this same server. 2070 // Note: following setOSLConfig order of locking. 

	sshClient.Lock()
	if sshClient.oslClientSeedState != nil {
		sshClient.sshServer.oslSessionCacheMutex.Lock()
		sshClient.oslClientSeedState.Hibernate()
		sshClient.sshServer.oslSessionCache.Set(
			sshClient.sessionID, sshClient.oslClientSeedState, cache.DefaultExpiration)
		sshClient.sshServer.oslSessionCacheMutex.Unlock()
		sshClient.oslClientSeedState = nil
	}
	sshClient.Unlock()

	// Initiate cleanup of the GeoIP session cache. To allow for post-tunnel
	// final status requests, the lifetime of cached GeoIP records exceeds the
	// lifetime of the sshClient.
	sshClient.sshServer.support.GeoIPService.MarkSessionCacheToExpire(sshClient.sessionID)
}

// passwordCallback is the ssh.ServerConfig PasswordCallback for this client.
//
// The password is either a JSON-encoded protocol.SSHPasswordPayload or, for
// backwards compatibility with older clients, the session ID immediately
// followed by the SSH password. The session ID must pass isHexDigits and have
// length 2*PSIPHON_API_CLIENT_SESSION_ID_LENGTH. The user name and password are
// checked against the configured SSHUserName and SSHPassword.
//
// On success, the session ID, isFirstTunnelInSession and supportsServerRequests
// are recorded on sshClient, the client's GeoIP data is stored in the GeoIP
// session cache, and (nil, nil) is returned. On failure, a traced error naming
// the offending user is returned.
func (sshClient *sshClient) passwordCallback(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {

	// The factor of 2 presumably reflects hex encoding (2 characters per byte);
	// the session ID is validated with isHexDigits below.
	expectedSessionIDLength := 2 * protocol.PSIPHON_API_CLIENT_SESSION_ID_LENGTH
	expectedSSHPasswordLength := 2 * SSH_PASSWORD_BYTE_LENGTH

	var sshPasswordPayload protocol.SSHPasswordPayload
	err := json.Unmarshal(password, &sshPasswordPayload)
	if err != nil {

		// Backwards compatibility case: instead of a JSON payload, older clients
		// send the hex encoded session ID prepended to the SSH password.
		// Note: there's an even older case where clients don't send any session ID,
		// but that's no longer supported.
		if len(password) == expectedSessionIDLength+expectedSSHPasswordLength {
			sshPasswordPayload.SessionId = string(password[0:expectedSessionIDLength])
			sshPasswordPayload.SshPassword = string(password[expectedSessionIDLength:])
		} else {
			return nil, errors.Tracef("invalid password payload for %q", conn.User())
		}
	}

	if !isHexDigits(sshClient.sshServer.support.Config, sshPasswordPayload.SessionId) ||
		len(sshPasswordPayload.SessionId) != expectedSessionIDLength {
		return nil, errors.Tracef("invalid session ID for %q", conn.User())
	}

	// subtle.ConstantTimeCompare is used so that comparison time does not depend
	// on the contents of the supplied credentials (a length mismatch still
	// returns early).
	userOk := (subtle.ConstantTimeCompare(
		[]byte(conn.User()), []byte(sshClient.sshServer.support.Config.SSHUserName)) == 1)

	passwordOk := (subtle.ConstantTimeCompare(
		[]byte(sshPasswordPayload.SshPassword), []byte(sshClient.sshServer.support.Config.SSHPassword)) == 1)

	if !userOk || !passwordOk {
		return nil, errors.Tracef("invalid password for %q", conn.User())
	}

	sessionID := sshPasswordPayload.SessionId

	// The GeoIP session cache will be populated if there was a previous tunnel
	// with this session ID. This will be true up to GEOIP_SESSION_CACHE_TTL, which
	// is currently much longer than the OSL session cache, another option to use if
	// the GeoIP session cache is retired (the GeoIP session cache currently only
	// supports legacy use cases).
	isFirstTunnelInSession := !sshClient.sshServer.support.GeoIPService.InSessionCache(sessionID)

	supportsServerRequests := common.Contains(
		sshPasswordPayload.ClientCapabilities, protocol.CLIENT_CAPABILITY_SERVER_REQUESTS)

	sshClient.Lock()

	// After this point, these values are read-only as they are read
	// without obtaining sshClient.Lock.
	sshClient.sessionID = sessionID
	sshClient.isFirstTunnelInSession = isFirstTunnelInSession
	sshClient.supportsServerRequests = supportsServerRequests

	geoIPData := sshClient.geoIPData

	sshClient.Unlock()

	// Store the GeoIP data associated with the session ID. This makes
	// the GeoIP data available to the web server for web API requests.
	// A cache that's distinct from the sshClient record is used to allow
	// for post-tunnel final status requests.
	// If the client is reconnecting with the same session ID, this call
	// will undo the expiry set by MarkSessionCacheToExpire.
	sshClient.sshServer.support.GeoIPService.SetSessionCache(sessionID, geoIPData)

	// No ssh.Permissions are attached to the authenticated connection.
	return nil, nil
}

// authLogCallback is the ssh.ServerConfig AuthLogCallback for this client. It
// is called with the method and outcome of each authentication attempt.
//
// Failures increment sshServer.authFailedCount. At most once per
// SSH_AUTH_LOG_PERIOD, a summary warning with the accumulated failure count is
// emitted and the count is reset. Every attempt is also logged at Debug level.
// The exception is the "none" method's "no auth passed yet" failure, which is
// ignored as negotiation noise.
func (sshClient *sshClient) authLogCallback(conn ssh.ConnMetadata, method string, err error) {

	if err != nil {

		if method == "none" && err.Error() == "ssh: no auth passed yet" {
			// In this case, the callback invocation is noise from auth negotiation
			return
		}

		// Note: here we previously logged messages for fail2ban to act on. This is no longer
		// done as the complexity outweighs the benefits.
		//
		// - The SSH credential is not secret -- it's in the server entry. Attackers targeting
		//   the server likely already have the credential. On the other hand, random scanning and
		//   brute forcing is mitigated with high entropy random passwords, rate limiting
		//   (implemented on the host via iptables), and limited capabilities (the SSH session can
		//   only port forward).
		//
		// - fail2ban coverage was inconsistent; in the case of an unfronted meek protocol through
		//   an upstream proxy, the remote address is the upstream proxy, which should not be blocked.
		//   The X-Forwarded-For header can't be used instead as it may be forged and used to get IPs
		//   deliberately blocked; and in any case fail2ban adds iptables rules which can only block
		//   by direct remote IP, not by original client IP. Fronted meek has the same iptables issue.
		//
		// Random scanning and brute forcing of port 22 will result in log noise. To mitigate this,
		// not every authentication failure is logged. A summary log is emitted periodically to
		// retain some record of this activity in case this is relevant to, e.g., a performance
		// investigation.

		atomic.AddInt64(&sshClient.sshServer.authFailedCount, 1)

		// Once the period has elapsed, the compare-and-swap on lastAuthLog lets
		// exactly one concurrent caller emit the summary and reset the count.
		lastAuthLog := monotime.Time(atomic.LoadInt64(&sshClient.sshServer.lastAuthLog))
		if monotime.Since(lastAuthLog) > SSH_AUTH_LOG_PERIOD {
			now := int64(monotime.Now())
			if atomic.CompareAndSwapInt64(&sshClient.sshServer.lastAuthLog, int64(lastAuthLog), now) {
				count := atomic.SwapInt64(&sshClient.sshServer.authFailedCount, 0)
				log.WithTraceFields(
					LogFields{"lastError": err, "failedCount": count}).Warning("authentication failures")
			}
		}

		log.WithTraceFields(LogFields{"error": err, "method": method}).Debug("authentication failed")

	} else {

		log.WithTraceFields(LogFields{"error": err, "method": method}).Debug("authentication success")
	}
}

// stop signals the ssh connection to shut down and blocks until sshConn.Wait
// returns. After sshConn.Wait returns, the SSH connection has terminated but
// sshClient.run may still be running and in the process of exiting.
//
// The shutdown process must complete rapidly and not, e.g., block on network
// I/O, as newly connecting clients need to await stop completion of any
// existing connection that shares the same session ID.
2216 func (sshClient *sshClient) stop() { 2217 sshClient.sshConn.Close() 2218 sshClient.sshConn.Wait() 2219 } 2220 2221 // awaitStopped will block until sshClient.run has exited, at which point all 2222 // worker goroutines associated with the sshClient, including any in-flight 2223 // API handlers, will have exited. 2224 func (sshClient *sshClient) awaitStopped() { 2225 <-sshClient.stopped 2226 } 2227 2228 // runTunnel handles/dispatches new channels and new requests from the client. 2229 // When the SSH client connection closes, both the channels and requests channels 2230 // will close and runTunnel will exit. 2231 func (sshClient *sshClient) runTunnel( 2232 channels <-chan ssh.NewChannel, 2233 requests <-chan *ssh.Request) { 2234 2235 waitGroup := new(sync.WaitGroup) 2236 2237 // Start client SSH API request handler 2238 2239 waitGroup.Add(1) 2240 go func() { 2241 defer waitGroup.Done() 2242 sshClient.handleSSHRequests(requests) 2243 }() 2244 2245 // Start request senders 2246 2247 if sshClient.supportsServerRequests { 2248 2249 waitGroup.Add(1) 2250 go func() { 2251 defer waitGroup.Done() 2252 sshClient.runOSLSender() 2253 }() 2254 2255 waitGroup.Add(1) 2256 go func() { 2257 defer waitGroup.Done() 2258 sshClient.runAlertSender() 2259 }() 2260 } 2261 2262 // Start the TCP port forward manager 2263 2264 // The queue size is set to the traffic rules (MaxTCPPortForwardCount + 2265 // MaxTCPDialingPortForwardCount), which is a reasonable indication of resource 2266 // limits per client; when that value is not set, a default is used. 2267 // A limitation: this queue size is set once and doesn't change, for this client, 2268 // when traffic rules are reloaded. 
2269 queueSize := sshClient.getTCPPortForwardQueueSize() 2270 if queueSize == 0 { 2271 queueSize = SSH_TCP_PORT_FORWARD_QUEUE_SIZE 2272 } 2273 newTCPPortForwards := make(chan *newTCPPortForward, queueSize) 2274 2275 waitGroup.Add(1) 2276 go func() { 2277 defer waitGroup.Done() 2278 sshClient.handleTCPPortForwards(waitGroup, newTCPPortForwards) 2279 }() 2280 2281 // Handle new channel (port forward) requests from the client. 2282 2283 for newChannel := range channels { 2284 switch newChannel.ChannelType() { 2285 case protocol.RANDOM_STREAM_CHANNEL_TYPE: 2286 sshClient.handleNewRandomStreamChannel(waitGroup, newChannel) 2287 case protocol.PACKET_TUNNEL_CHANNEL_TYPE: 2288 sshClient.handleNewPacketTunnelChannel(waitGroup, newChannel) 2289 case protocol.TCP_PORT_FORWARD_NO_SPLIT_TUNNEL_TYPE: 2290 // The protocol.TCP_PORT_FORWARD_NO_SPLIT_TUNNEL_TYPE is the same as 2291 // "direct-tcpip", except split tunnel channel rejections are disallowed 2292 // even if the client has enabled split tunnel. This channel type allows 2293 // the client to ensure tunneling for certain cases while split tunnel is 2294 // enabled. 2295 sshClient.handleNewTCPPortForwardChannel(waitGroup, newChannel, false, newTCPPortForwards) 2296 case "direct-tcpip": 2297 sshClient.handleNewTCPPortForwardChannel(waitGroup, newChannel, true, newTCPPortForwards) 2298 default: 2299 sshClient.rejectNewChannel(newChannel, 2300 fmt.Sprintf("unknown or unsupported channel type: %s", newChannel.ChannelType())) 2301 } 2302 } 2303 2304 // The channel loop is interrupted by a client 2305 // disconnect or by calling sshClient.stop(). 2306 2307 // Stop the TCP port forward manager 2308 close(newTCPPortForwards) 2309 2310 // Stop all other worker goroutines 2311 sshClient.stopRunning() 2312 2313 if sshClient.sshServer.support.Config.RunPacketTunnel { 2314 // PacketTunnelServer.ClientDisconnected stops packet tunnel workers. 
2315 sshClient.sshServer.support.PacketTunnelServer.ClientDisconnected( 2316 sshClient.sessionID) 2317 } 2318 2319 waitGroup.Wait() 2320 2321 sshClient.cleanupAuthorizations() 2322 } 2323 2324 func (sshClient *sshClient) handleSSHRequests(requests <-chan *ssh.Request) { 2325 2326 for request := range requests { 2327 2328 // Requests are processed serially; API responses must be sent in request order. 2329 2330 var responsePayload []byte 2331 var err error 2332 2333 if request.Type == "keepalive@openssh.com" { 2334 2335 // SSH keep alive round trips are used as speed test samples. 2336 responsePayload, err = tactics.MakeSpeedTestResponse( 2337 SSH_KEEP_ALIVE_PAYLOAD_MIN_BYTES, SSH_KEEP_ALIVE_PAYLOAD_MAX_BYTES) 2338 2339 } else { 2340 2341 // All other requests are assumed to be API requests. 2342 2343 sshClient.Lock() 2344 authorizedAccessTypes := sshClient.handshakeState.authorizedAccessTypes 2345 sshClient.Unlock() 2346 2347 // Note: unlock before use is only safe as long as referenced sshClient data, 2348 // such as slices in handshakeState, is read-only after initially set. 
2349 2350 clientAddr := "" 2351 if sshClient.clientAddr != nil { 2352 clientAddr = sshClient.clientAddr.String() 2353 } 2354 2355 responsePayload, err = sshAPIRequestHandler( 2356 sshClient.sshServer.support, 2357 clientAddr, 2358 sshClient.geoIPData, 2359 authorizedAccessTypes, 2360 request.Type, 2361 request.Payload) 2362 } 2363 2364 if err == nil { 2365 err = request.Reply(true, responsePayload) 2366 } else { 2367 log.WithTraceFields(LogFields{"error": err}).Warning("request failed") 2368 err = request.Reply(false, nil) 2369 } 2370 if err != nil { 2371 if !isExpectedTunnelIOError(err) { 2372 log.WithTraceFields(LogFields{"error": err}).Warning("response failed") 2373 } 2374 } 2375 2376 } 2377 2378 } 2379 2380 type newTCPPortForward struct { 2381 enqueueTime time.Time 2382 hostToConnect string 2383 portToConnect int 2384 doSplitTunnel bool 2385 newChannel ssh.NewChannel 2386 } 2387 2388 func (sshClient *sshClient) handleTCPPortForwards( 2389 waitGroup *sync.WaitGroup, 2390 newTCPPortForwards chan *newTCPPortForward) { 2391 2392 // Lifecycle of a TCP port forward: 2393 // 2394 // 1. A "direct-tcpip" SSH request is received from the client. 2395 // 2396 // A new TCP port forward request is enqueued. The queue delivers TCP port 2397 // forward requests to the TCP port forward manager, which enforces the TCP 2398 // port forward dial limit. 2399 // 2400 // Enqueuing new requests allows for reading further SSH requests from the 2401 // client without blocking when the dial limit is hit; this is to permit new 2402 // UDP/udpgw port forwards to be restablished without delay. The maximum size 2403 // of the queue enforces a hard cap on resources consumed by a client in the 2404 // pre-dial phase. When the queue is full, new TCP port forwards are 2405 // immediately rejected. 2406 // 2407 // 2. The TCP port forward manager dequeues the request. 
2408 // 2409 // The manager calls dialingTCPPortForward(), which increments 2410 // concurrentDialingPortForwardCount, and calls 2411 // isTCPDialingPortForwardLimitExceeded() to check the concurrent dialing 2412 // count. 2413 // 2414 // The manager enforces the concurrent TCP dial limit: when at the limit, the 2415 // manager blocks waiting for the number of dials to drop below the limit before 2416 // dispatching the request to handleTCPPortForward(), which will run in its own 2417 // goroutine and will dial and relay the port forward. 2418 // 2419 // The block delays the current request and also halts dequeuing of subsequent 2420 // requests and could ultimately cause requests to be immediately rejected if 2421 // the queue fills. These actions are intended to apply back pressure when 2422 // upstream network resources are impaired. 2423 // 2424 // The time spent in the queue is deducted from the port forward's dial timeout. 2425 // The time spent blocking while at the dial limit is similarly deducted from 2426 // the dial timeout. If the dial timeout has expired before the dial begins, the 2427 // port forward is rejected and a stat is recorded. 2428 // 2429 // 3. handleTCPPortForward() performs the port forward dial and relaying. 2430 // 2431 // a. Dial the target, using the dial timeout remaining after queue and blocking 2432 // time is deducted. 2433 // 2434 // b. If the dial fails, call abortedTCPPortForward() to decrement 2435 // concurrentDialingPortForwardCount, freeing up a dial slot. 2436 // 2437 // c. If the dial succeeds, call establishedPortForward(), which decrements 2438 // concurrentDialingPortForwardCount and increments concurrentPortForwardCount, 2439 // the "established" port forward count. 2440 // 2441 // d. Check isPortForwardLimitExceeded(), which enforces the configurable limit on 2442 // concurrentPortForwardCount, the number of _established_ TCP port forwards. 
2443 // If the limit is exceeded, the LRU established TCP port forward is closed and 2444 // the newly established TCP port forward proceeds. This LRU logic allows some 2445 // dangling resource consumption (e.g., TIME_WAIT) while providing a better 2446 // experience for clients. 2447 // 2448 // e. Relay data. 2449 // 2450 // f. Call closedPortForward() which decrements concurrentPortForwardCount and 2451 // records bytes transferred. 2452 2453 for newPortForward := range newTCPPortForwards { 2454 2455 remainingDialTimeout := 2456 time.Duration(sshClient.getDialTCPPortForwardTimeoutMilliseconds())*time.Millisecond - 2457 time.Since(newPortForward.enqueueTime) 2458 2459 if remainingDialTimeout <= 0 { 2460 sshClient.updateQualityMetricsWithRejectedDialingLimit() 2461 sshClient.rejectNewChannel( 2462 newPortForward.newChannel, "TCP port forward timed out in queue") 2463 continue 2464 } 2465 2466 // Reserve a TCP dialing slot. 2467 // 2468 // TOCTOU note: important to increment counts _before_ checking limits; otherwise, 2469 // the client could potentially consume excess resources by initiating many port 2470 // forwards concurrently. 2471 2472 sshClient.dialingTCPPortForward() 2473 2474 // When max dials are in progress, wait up to remainingDialTimeout for dialing 2475 // to become available. This blocks all dequeing. 2476 2477 if sshClient.isTCPDialingPortForwardLimitExceeded() { 2478 blockStartTime := time.Now() 2479 ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout) 2480 sshClient.setTCPPortForwardDialingAvailableSignal(cancelCtx) 2481 <-ctx.Done() 2482 sshClient.setTCPPortForwardDialingAvailableSignal(nil) 2483 cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled" 2484 remainingDialTimeout -= time.Since(blockStartTime) 2485 } 2486 2487 if remainingDialTimeout <= 0 { 2488 2489 // Release the dialing slot here since handleTCPChannel() won't be called. 
2490 sshClient.abortedTCPPortForward() 2491 2492 sshClient.updateQualityMetricsWithRejectedDialingLimit() 2493 sshClient.rejectNewChannel( 2494 newPortForward.newChannel, "TCP port forward timed out before dialing") 2495 continue 2496 } 2497 2498 // Dial and relay the TCP port forward. handleTCPChannel is run in its own worker goroutine. 2499 // handleTCPChannel will release the dialing slot reserved by dialingTCPPortForward(); and 2500 // will deal with remainingDialTimeout <= 0. 2501 2502 waitGroup.Add(1) 2503 go func(remainingDialTimeout time.Duration, newPortForward *newTCPPortForward) { 2504 defer waitGroup.Done() 2505 sshClient.handleTCPChannel( 2506 remainingDialTimeout, 2507 newPortForward.hostToConnect, 2508 newPortForward.portToConnect, 2509 newPortForward.doSplitTunnel, 2510 newPortForward.newChannel) 2511 }(remainingDialTimeout, newPortForward) 2512 } 2513 } 2514 2515 func (sshClient *sshClient) handleNewRandomStreamChannel( 2516 waitGroup *sync.WaitGroup, newChannel ssh.NewChannel) { 2517 2518 // A random stream channel returns the requested number of bytes -- random 2519 // bytes -- to the client while also consuming and discarding bytes sent 2520 // by the client. 2521 // 2522 // One use case for the random stream channel is a liveness test that the 2523 // client performs to confirm that the tunnel is live. As the liveness 2524 // test is performed in the concurrent establishment phase, before 2525 // selecting a single candidate for handshake, the random stream channel 2526 // is available pre-handshake, albeit with additional restrictions. 2527 // 2528 // The random stream is subject to throttling in traffic rules; for 2529 // unthrottled liveness tests, set EstablishmentRead/WriteBytesPerSecond as 2530 // required. The random stream maximum count and response size cap mitigate 2531 // clients abusing the facility to waste server resources. 
2532 // 2533 // Like all other channels, this channel type is handled asynchronously, 2534 // so it's possible to run at any point in the tunnel lifecycle. 2535 // 2536 // Up/downstream byte counts don't include SSH packet and request 2537 // marshalling overhead. 2538 2539 var request protocol.RandomStreamRequest 2540 err := json.Unmarshal(newChannel.ExtraData(), &request) 2541 if err != nil { 2542 sshClient.rejectNewChannel(newChannel, fmt.Sprintf("invalid request: %s", err)) 2543 return 2544 } 2545 2546 if request.UpstreamBytes > RANDOM_STREAM_MAX_BYTES { 2547 sshClient.rejectNewChannel(newChannel, 2548 fmt.Sprintf("invalid upstream bytes: %d", request.UpstreamBytes)) 2549 return 2550 } 2551 2552 if request.DownstreamBytes > RANDOM_STREAM_MAX_BYTES { 2553 sshClient.rejectNewChannel(newChannel, 2554 fmt.Sprintf("invalid downstream bytes: %d", request.DownstreamBytes)) 2555 return 2556 } 2557 2558 var metrics *randomStreamMetrics 2559 2560 sshClient.Lock() 2561 2562 if !sshClient.handshakeState.completed { 2563 metrics = &sshClient.preHandshakeRandomStreamMetrics 2564 } else { 2565 metrics = &sshClient.postHandshakeRandomStreamMetrics 2566 } 2567 2568 countOk := true 2569 if !sshClient.handshakeState.completed && 2570 metrics.count >= PRE_HANDSHAKE_RANDOM_STREAM_MAX_COUNT { 2571 countOk = false 2572 } else { 2573 metrics.count++ 2574 } 2575 2576 sshClient.Unlock() 2577 2578 if !countOk { 2579 sshClient.rejectNewChannel(newChannel, "max count exceeded") 2580 return 2581 } 2582 2583 channel, requests, err := newChannel.Accept() 2584 if err != nil { 2585 if !isExpectedTunnelIOError(err) { 2586 log.WithTraceFields(LogFields{"error": err}).Warning("accept new channel failed") 2587 } 2588 return 2589 } 2590 go ssh.DiscardRequests(requests) 2591 2592 waitGroup.Add(1) 2593 go func() { 2594 defer waitGroup.Done() 2595 2596 upstream := new(sync.WaitGroup) 2597 received := 0 2598 sent := 0 2599 2600 if request.UpstreamBytes > 0 { 2601 2602 // Process streams concurrently to 
minimize elapsed time. This also 2603 // avoids a unidirectional flow burst early in the tunnel lifecycle. 2604 2605 upstream.Add(1) 2606 go func() { 2607 defer upstream.Done() 2608 n, err := io.CopyN(ioutil.Discard, channel, int64(request.UpstreamBytes)) 2609 received = int(n) 2610 if err != nil { 2611 if !isExpectedTunnelIOError(err) { 2612 log.WithTraceFields(LogFields{"error": err}).Warning("receive failed") 2613 } 2614 } 2615 }() 2616 } 2617 2618 if request.DownstreamBytes > 0 { 2619 n, err := io.CopyN(channel, rand.Reader, int64(request.DownstreamBytes)) 2620 sent = int(n) 2621 if err != nil { 2622 if !isExpectedTunnelIOError(err) { 2623 log.WithTraceFields(LogFields{"error": err}).Warning("send failed") 2624 } 2625 } 2626 } 2627 2628 upstream.Wait() 2629 2630 sshClient.Lock() 2631 metrics.upstreamBytes += int64(request.UpstreamBytes) 2632 metrics.receivedUpstreamBytes += int64(received) 2633 metrics.downstreamBytes += int64(request.DownstreamBytes) 2634 metrics.sentDownstreamBytes += int64(sent) 2635 sshClient.Unlock() 2636 2637 channel.Close() 2638 }() 2639 } 2640 2641 func (sshClient *sshClient) handleNewPacketTunnelChannel( 2642 waitGroup *sync.WaitGroup, newChannel ssh.NewChannel) { 2643 2644 // packet tunnel channels are handled by the packet tunnel server 2645 // component. Each client may have at most one packet tunnel channel. 2646 2647 if !sshClient.sshServer.support.Config.RunPacketTunnel { 2648 sshClient.rejectNewChannel(newChannel, "unsupported packet tunnel channel type") 2649 return 2650 } 2651 2652 // Accept this channel immediately. This channel will replace any 2653 // previously existing packet tunnel channel for this client. 
2654 2655 packetTunnelChannel, requests, err := newChannel.Accept() 2656 if err != nil { 2657 if !isExpectedTunnelIOError(err) { 2658 log.WithTraceFields(LogFields{"error": err}).Warning("accept new channel failed") 2659 } 2660 return 2661 } 2662 go ssh.DiscardRequests(requests) 2663 2664 sshClient.setPacketTunnelChannel(packetTunnelChannel) 2665 2666 // PacketTunnelServer will run the client's packet tunnel. If necessary, ClientConnected 2667 // will stop packet tunnel workers for any previous packet tunnel channel. 2668 2669 checkAllowedTCPPortFunc := func(upstreamIPAddress net.IP, port int) bool { 2670 return sshClient.isPortForwardPermitted(portForwardTypeTCP, upstreamIPAddress, port) 2671 } 2672 2673 checkAllowedUDPPortFunc := func(upstreamIPAddress net.IP, port int) bool { 2674 return sshClient.isPortForwardPermitted(portForwardTypeUDP, upstreamIPAddress, port) 2675 } 2676 2677 checkAllowedDomainFunc := func(domain string) bool { 2678 ok, _ := sshClient.isDomainPermitted(domain) 2679 return ok 2680 } 2681 2682 flowActivityUpdaterMaker := func( 2683 isTCP bool, upstreamHostname string, upstreamIPAddress net.IP) []tun.FlowActivityUpdater { 2684 2685 trafficType := portForwardTypeTCP 2686 if !isTCP { 2687 trafficType = portForwardTypeUDP 2688 } 2689 2690 activityUpdaters := sshClient.getActivityUpdaters(trafficType, upstreamIPAddress) 2691 2692 flowUpdaters := make([]tun.FlowActivityUpdater, len(activityUpdaters)) 2693 for i, activityUpdater := range activityUpdaters { 2694 flowUpdaters[i] = activityUpdater 2695 } 2696 2697 return flowUpdaters 2698 } 2699 2700 metricUpdater := func( 2701 TCPApplicationBytesDown, TCPApplicationBytesUp, 2702 UDPApplicationBytesDown, UDPApplicationBytesUp int64) { 2703 2704 sshClient.Lock() 2705 sshClient.tcpTrafficState.bytesDown += TCPApplicationBytesDown 2706 sshClient.tcpTrafficState.bytesUp += TCPApplicationBytesUp 2707 sshClient.udpTrafficState.bytesDown += UDPApplicationBytesDown 2708 sshClient.udpTrafficState.bytesUp += 
UDPApplicationBytesUp 2709 sshClient.Unlock() 2710 } 2711 2712 dnsQualityReporter := sshClient.updateQualityMetricsWithDNSResult 2713 2714 err = sshClient.sshServer.support.PacketTunnelServer.ClientConnected( 2715 sshClient.sessionID, 2716 packetTunnelChannel, 2717 checkAllowedTCPPortFunc, 2718 checkAllowedUDPPortFunc, 2719 checkAllowedDomainFunc, 2720 flowActivityUpdaterMaker, 2721 metricUpdater, 2722 dnsQualityReporter) 2723 if err != nil { 2724 log.WithTraceFields(LogFields{"error": err}).Warning("start packet tunnel client failed") 2725 sshClient.setPacketTunnelChannel(nil) 2726 } 2727 } 2728 2729 func (sshClient *sshClient) handleNewTCPPortForwardChannel( 2730 waitGroup *sync.WaitGroup, 2731 newChannel ssh.NewChannel, 2732 allowSplitTunnel bool, 2733 newTCPPortForwards chan *newTCPPortForward) { 2734 2735 // udpgw client connections are dispatched immediately (clients use this for 2736 // DNS, so it's essential to not block; and only one udpgw connection is 2737 // retained at a time). 2738 // 2739 // All other TCP port forwards are dispatched via the TCP port forward 2740 // manager queue. 2741 2742 // http://tools.ietf.org/html/rfc4254#section-7.2 2743 var directTcpipExtraData struct { 2744 HostToConnect string 2745 PortToConnect uint32 2746 OriginatorIPAddress string 2747 OriginatorPort uint32 2748 } 2749 2750 err := ssh.Unmarshal(newChannel.ExtraData(), &directTcpipExtraData) 2751 if err != nil { 2752 sshClient.rejectNewChannel(newChannel, "invalid extra data") 2753 return 2754 } 2755 2756 // Intercept TCP port forwards to a specified udpgw server and handle directly. 2757 // TODO: also support UDP explicitly, e.g. with a custom "direct-udp" channel type? 
2758 isUdpgwChannel := sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress != "" && 2759 sshClient.sshServer.support.Config.UDPInterceptUdpgwServerAddress == 2760 net.JoinHostPort(directTcpipExtraData.HostToConnect, strconv.Itoa(int(directTcpipExtraData.PortToConnect))) 2761 2762 if isUdpgwChannel { 2763 2764 // Dispatch immediately. handleUDPChannel runs the udpgw protocol in its 2765 // own worker goroutine. 2766 2767 waitGroup.Add(1) 2768 go func(channel ssh.NewChannel) { 2769 defer waitGroup.Done() 2770 sshClient.handleUdpgwChannel(channel) 2771 }(newChannel) 2772 2773 } else { 2774 2775 // Dispatch via TCP port forward manager. When the queue is full, the channel 2776 // is immediately rejected. 2777 // 2778 // Split tunnel logic is enabled for this TCP port forward when the client 2779 // has enabled split tunnel mode and the channel type allows it. 2780 2781 doSplitTunnel := sshClient.handshakeState.splitTunnelLookup != nil && allowSplitTunnel 2782 2783 tcpPortForward := &newTCPPortForward{ 2784 enqueueTime: time.Now(), 2785 hostToConnect: directTcpipExtraData.HostToConnect, 2786 portToConnect: int(directTcpipExtraData.PortToConnect), 2787 doSplitTunnel: doSplitTunnel, 2788 newChannel: newChannel, 2789 } 2790 2791 select { 2792 case newTCPPortForwards <- tcpPortForward: 2793 default: 2794 sshClient.updateQualityMetricsWithRejectedDialingLimit() 2795 sshClient.rejectNewChannel(newChannel, "TCP port forward dial queue full") 2796 } 2797 } 2798 } 2799 2800 func (sshClient *sshClient) cleanupAuthorizations() { 2801 sshClient.Lock() 2802 2803 if sshClient.releaseAuthorizations != nil { 2804 sshClient.releaseAuthorizations() 2805 } 2806 2807 if sshClient.stopTimer != nil { 2808 sshClient.stopTimer.Stop() 2809 } 2810 2811 sshClient.Unlock() 2812 } 2813 2814 // setPacketTunnelChannel sets the single packet tunnel channel 2815 // for this sshClient. Any existing packet tunnel channel is 2816 // closed. 
2817 func (sshClient *sshClient) setPacketTunnelChannel(channel ssh.Channel) { 2818 sshClient.Lock() 2819 if sshClient.packetTunnelChannel != nil { 2820 sshClient.packetTunnelChannel.Close() 2821 } 2822 sshClient.packetTunnelChannel = channel 2823 sshClient.totalPacketTunnelChannelCount += 1 2824 sshClient.Unlock() 2825 } 2826 2827 // setUdpgwChannelHandler sets the single udpgw channel handler for this 2828 // sshClient. Each sshClient may have only one concurrent udpgw 2829 // channel/handler. Each udpgw channel multiplexes many UDP port forwards via 2830 // the udpgw protocol. Any existing udpgw channel/handler is closed. 2831 func (sshClient *sshClient) setUdpgwChannelHandler(udpgwChannelHandler *udpgwPortForwardMultiplexer) bool { 2832 sshClient.Lock() 2833 if sshClient.udpgwChannelHandler != nil { 2834 previousHandler := sshClient.udpgwChannelHandler 2835 sshClient.udpgwChannelHandler = nil 2836 2837 // stop must be run without holding the sshClient mutex lock, as the 2838 // udpgw goroutines may attempt to lock the same mutex. For example, 2839 // udpgwPortForwardMultiplexer.run calls sshClient.establishedPortForward 2840 // which calls sshClient.allocatePortForward. 2841 sshClient.Unlock() 2842 previousHandler.stop() 2843 sshClient.Lock() 2844 2845 // In case some other channel has set the sshClient.udpgwChannelHandler 2846 // in the meantime, fail. The caller should discard this channel/handler. 2847 if sshClient.udpgwChannelHandler != nil { 2848 sshClient.Unlock() 2849 return false 2850 } 2851 } 2852 sshClient.udpgwChannelHandler = udpgwChannelHandler 2853 sshClient.totalUdpgwChannelCount += 1 2854 sshClient.Unlock() 2855 return true 2856 } 2857 2858 var serverTunnelStatParams = append( 2859 []requestParamSpec{ 2860 {"last_connected", isLastConnected, requestParamOptional}, 2861 {"establishment_duration", isIntString, requestParamOptional}}, 2862 baseSessionAndDialParams...) 
2863 2864 func (sshClient *sshClient) logTunnel(additionalMetrics []LogFields) { 2865 2866 sshClient.Lock() 2867 2868 logFields := getRequestLogFields( 2869 "server_tunnel", 2870 sshClient.geoIPData, 2871 sshClient.handshakeState.authorizedAccessTypes, 2872 sshClient.handshakeState.apiParams, 2873 serverTunnelStatParams) 2874 2875 // "relay_protocol" is sent with handshake API parameters. In pre- 2876 // handshake logTunnel cases, this value is not yet known. As 2877 // sshClient.tunnelProtocol is authoritative, set this value 2878 // unconditionally, overwriting any value from handshake. 2879 logFields["relay_protocol"] = sshClient.tunnelProtocol 2880 2881 if sshClient.serverPacketManipulation != "" { 2882 logFields["server_packet_manipulation"] = sshClient.serverPacketManipulation 2883 } 2884 if sshClient.sshListener.BPFProgramName != "" { 2885 logFields["server_bpf"] = sshClient.sshListener.BPFProgramName 2886 } 2887 logFields["session_id"] = sshClient.sessionID 2888 logFields["is_first_tunnel_in_session"] = sshClient.isFirstTunnelInSession 2889 logFields["handshake_completed"] = sshClient.handshakeState.completed 2890 logFields["bytes_up_tcp"] = sshClient.tcpTrafficState.bytesUp 2891 logFields["bytes_down_tcp"] = sshClient.tcpTrafficState.bytesDown 2892 logFields["peak_concurrent_dialing_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentDialingPortForwardCount 2893 logFields["peak_concurrent_port_forward_count_tcp"] = sshClient.tcpTrafficState.peakConcurrentPortForwardCount 2894 logFields["total_port_forward_count_tcp"] = sshClient.tcpTrafficState.totalPortForwardCount 2895 logFields["bytes_up_udp"] = sshClient.udpTrafficState.bytesUp 2896 logFields["bytes_down_udp"] = sshClient.udpTrafficState.bytesDown 2897 // sshClient.udpTrafficState.peakConcurrentDialingPortForwardCount isn't meaningful 2898 logFields["peak_concurrent_port_forward_count_udp"] = sshClient.udpTrafficState.peakConcurrentPortForwardCount 2899 
logFields["total_port_forward_count_udp"] = sshClient.udpTrafficState.totalPortForwardCount 2900 logFields["total_udpgw_channel_count"] = sshClient.totalUdpgwChannelCount 2901 logFields["total_packet_tunnel_channel_count"] = sshClient.totalPacketTunnelChannelCount 2902 2903 logFields["pre_handshake_random_stream_count"] = sshClient.preHandshakeRandomStreamMetrics.count 2904 logFields["pre_handshake_random_stream_upstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.upstreamBytes 2905 logFields["pre_handshake_random_stream_received_upstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.receivedUpstreamBytes 2906 logFields["pre_handshake_random_stream_downstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.downstreamBytes 2907 logFields["pre_handshake_random_stream_sent_downstream_bytes"] = sshClient.preHandshakeRandomStreamMetrics.sentDownstreamBytes 2908 logFields["random_stream_count"] = sshClient.postHandshakeRandomStreamMetrics.count 2909 logFields["random_stream_upstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.upstreamBytes 2910 logFields["random_stream_received_upstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.receivedUpstreamBytes 2911 logFields["random_stream_downstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.downstreamBytes 2912 logFields["random_stream_sent_downstream_bytes"] = sshClient.postHandshakeRandomStreamMetrics.sentDownstreamBytes 2913 2914 if sshClient.destinationBytesMetricsASN != "" { 2915 2916 // Check if the configured DestinationBytesMetricsASN has changed 2917 // (or been cleared). If so, don't log and discard the accumulated 2918 // bytes to ensure we don't continue to record stats as previously 2919 // configured. 2920 // 2921 // Any counts accumulated before the DestinationBytesMetricsASN change 2922 // are lost. 
At this time we can't change 2923 // sshClient.destinationBytesMetricsASN dynamically, after a tactics 2924 // hot reload, as there may be destination bytes port forwards that 2925 // were in place before the change, which will continue to count. 2926 2927 logDestBytes := true 2928 if sshClient.sshServer.support.ServerTacticsParametersCache != nil { 2929 p, err := sshClient.sshServer.support.ServerTacticsParametersCache.Get(sshClient.geoIPData) 2930 if err != nil || p.IsNil() || 2931 sshClient.destinationBytesMetricsASN != p.String(parameters.DestinationBytesMetricsASN) { 2932 logDestBytes = false 2933 } 2934 } 2935 2936 if logDestBytes { 2937 bytesUpTCP := sshClient.tcpDestinationBytesMetrics.getBytesUp() 2938 bytesDownTCP := sshClient.tcpDestinationBytesMetrics.getBytesDown() 2939 bytesUpUDP := sshClient.udpDestinationBytesMetrics.getBytesUp() 2940 bytesDownUDP := sshClient.udpDestinationBytesMetrics.getBytesDown() 2941 2942 logFields["dest_bytes_asn"] = sshClient.destinationBytesMetricsASN 2943 logFields["dest_bytes_up_tcp"] = bytesUpTCP 2944 logFields["dest_bytes_down_tcp"] = bytesDownTCP 2945 logFields["dest_bytes_up_udp"] = bytesUpUDP 2946 logFields["dest_bytes_down_udp"] = bytesDownUDP 2947 logFields["dest_bytes"] = bytesUpTCP + bytesDownTCP + bytesUpUDP + bytesDownUDP 2948 } 2949 } 2950 2951 // Only log fields for peakMetrics when there is data recorded, otherwise 2952 // omit the field. 
2953 if sshClient.peakMetrics.concurrentProximateAcceptedClients != nil { 2954 logFields["peak_concurrent_proximate_accepted_clients"] = *sshClient.peakMetrics.concurrentProximateAcceptedClients 2955 } 2956 if sshClient.peakMetrics.concurrentProximateEstablishedClients != nil { 2957 logFields["peak_concurrent_proximate_established_clients"] = *sshClient.peakMetrics.concurrentProximateEstablishedClients 2958 } 2959 if sshClient.peakMetrics.TCPPortForwardFailureRate != nil && sshClient.peakMetrics.TCPPortForwardFailureRateSampleSize != nil { 2960 logFields["peak_tcp_port_forward_failure_rate"] = *sshClient.peakMetrics.TCPPortForwardFailureRate 2961 logFields["peak_tcp_port_forward_failure_rate_sample_size"] = *sshClient.peakMetrics.TCPPortForwardFailureRateSampleSize 2962 } 2963 if sshClient.peakMetrics.DNSFailureRate != nil && sshClient.peakMetrics.DNSFailureRateSampleSize != nil { 2964 logFields["peak_dns_failure_rate"] = *sshClient.peakMetrics.DNSFailureRate 2965 logFields["peak_dns_failure_rate_sample_size"] = *sshClient.peakMetrics.DNSFailureRateSampleSize 2966 } 2967 2968 // Pre-calculate a total-tunneled-bytes field. This total is used 2969 // extensively in analytics and is more performant when pre-calculated. 2970 logFields["bytes"] = sshClient.tcpTrafficState.bytesUp + 2971 sshClient.tcpTrafficState.bytesDown + 2972 sshClient.udpTrafficState.bytesUp + 2973 sshClient.udpTrafficState.bytesDown 2974 2975 // Merge in additional metrics from the optional metrics source 2976 for _, metrics := range additionalMetrics { 2977 for name, value := range metrics { 2978 // Don't overwrite any basic fields 2979 if logFields[name] == nil { 2980 logFields[name] = value 2981 } 2982 } 2983 } 2984 2985 // Retain lock when invoking LogRawFieldsWithTimestamp to block any 2986 // concurrent writes to variables referenced by logFields. 
2987 log.LogRawFieldsWithTimestamp(logFields) 2988 2989 sshClient.Unlock() 2990 } 2991 2992 var blocklistHitsStatParams = []requestParamSpec{ 2993 {"propagation_channel_id", isHexDigits, 0}, 2994 {"sponsor_id", isHexDigits, 0}, 2995 {"client_version", isIntString, requestParamLogStringAsInt}, 2996 {"client_platform", isClientPlatform, 0}, 2997 {"client_features", isAnyString, requestParamOptional | requestParamArray}, 2998 {"client_build_rev", isHexDigits, requestParamOptional}, 2999 {"device_region", isAnyString, requestParamOptional}, 3000 {"egress_region", isRegionCode, requestParamOptional}, 3001 {"session_id", isHexDigits, 0}, 3002 {"last_connected", isLastConnected, requestParamOptional}, 3003 } 3004 3005 func (sshClient *sshClient) logBlocklistHits(IP net.IP, domain string, tags []BlocklistTag) { 3006 3007 sshClient.Lock() 3008 3009 logFields := getRequestLogFields( 3010 "server_blocklist_hit", 3011 sshClient.geoIPData, 3012 sshClient.handshakeState.authorizedAccessTypes, 3013 sshClient.handshakeState.apiParams, 3014 blocklistHitsStatParams) 3015 3016 logFields["session_id"] = sshClient.sessionID 3017 3018 // Note: see comment in logTunnel regarding unlock and concurrent access. 3019 3020 sshClient.Unlock() 3021 3022 for _, tag := range tags { 3023 if IP != nil { 3024 logFields["blocklist_ip_address"] = IP.String() 3025 } 3026 if domain != "" { 3027 logFields["blocklist_domain"] = domain 3028 } 3029 logFields["blocklist_source"] = tag.Source 3030 logFields["blocklist_subject"] = tag.Subject 3031 3032 log.LogRawFieldsWithTimestamp(logFields) 3033 } 3034 } 3035 3036 func (sshClient *sshClient) runOSLSender() { 3037 3038 for { 3039 // Await a signal that there are SLOKs to send 3040 // TODO: use reflect.SelectCase, and optionally await timer here? 
3041 select { 3042 case <-sshClient.signalIssueSLOKs: 3043 case <-sshClient.runCtx.Done(): 3044 return 3045 } 3046 3047 retryDelay := SSH_SEND_OSL_INITIAL_RETRY_DELAY 3048 for { 3049 err := sshClient.sendOSLRequest() 3050 if err == nil { 3051 break 3052 } 3053 if !isExpectedTunnelIOError(err) { 3054 log.WithTraceFields(LogFields{"error": err}).Warning("sendOSLRequest failed") 3055 } 3056 3057 // If the request failed, retry after a delay (with exponential backoff) 3058 // or when signaled that there are additional SLOKs to send 3059 retryTimer := time.NewTimer(retryDelay) 3060 select { 3061 case <-retryTimer.C: 3062 case <-sshClient.signalIssueSLOKs: 3063 case <-sshClient.runCtx.Done(): 3064 retryTimer.Stop() 3065 return 3066 } 3067 retryTimer.Stop() 3068 retryDelay *= SSH_SEND_OSL_RETRY_FACTOR 3069 } 3070 } 3071 } 3072 3073 // sendOSLRequest will invoke osl.GetSeedPayload to issue SLOKs and 3074 // generate a payload, and send an OSL request to the client when 3075 // there are new SLOKs in the payload. 3076 func (sshClient *sshClient) sendOSLRequest() error { 3077 3078 seedPayload := sshClient.getOSLSeedPayload() 3079 3080 // Don't send when no SLOKs. This will happen when signalIssueSLOKs 3081 // is received but no new SLOKs are issued. 3082 if len(seedPayload.SLOKs) == 0 { 3083 return nil 3084 } 3085 3086 oslRequest := protocol.OSLRequest{ 3087 SeedPayload: seedPayload, 3088 } 3089 requestPayload, err := json.Marshal(oslRequest) 3090 if err != nil { 3091 return errors.Trace(err) 3092 } 3093 3094 ok, _, err := sshClient.sshConn.SendRequest( 3095 protocol.PSIPHON_API_OSL_REQUEST_NAME, 3096 true, 3097 requestPayload) 3098 if err != nil { 3099 return errors.Trace(err) 3100 } 3101 if !ok { 3102 return errors.TraceNew("client rejected request") 3103 } 3104 3105 sshClient.clearOSLSeedPayload() 3106 3107 return nil 3108 } 3109 3110 // runAlertSender dequeues and sends alert requests to the client. 
As these 3111 // alerts are informational, there is no retry logic and no SSH client 3112 // acknowledgement (wantReply) is requested. This worker scheme allows 3113 // nonconcurrent components including udpgw and packet tunnel to enqueue 3114 // alerts without blocking their traffic processing. 3115 func (sshClient *sshClient) runAlertSender() { 3116 for { 3117 select { 3118 case <-sshClient.runCtx.Done(): 3119 return 3120 3121 case request := <-sshClient.sendAlertRequests: 3122 payload, err := json.Marshal(request) 3123 if err != nil { 3124 log.WithTraceFields(LogFields{"error": err}).Warning("Marshal failed") 3125 break 3126 } 3127 _, _, err = sshClient.sshConn.SendRequest( 3128 protocol.PSIPHON_API_ALERT_REQUEST_NAME, 3129 false, 3130 payload) 3131 if err != nil && !isExpectedTunnelIOError(err) { 3132 log.WithTraceFields(LogFields{"error": err}).Warning("SendRequest failed") 3133 break 3134 } 3135 sshClient.Lock() 3136 sshClient.sentAlertRequests[fmt.Sprintf("%+v", request)] = true 3137 sshClient.Unlock() 3138 } 3139 } 3140 } 3141 3142 // enqueueAlertRequest enqueues an alert request to be sent to the client. 3143 // Only one request is sent per tunnel per protocol.AlertRequest value; 3144 // subsequent alerts with the same value are dropped. enqueueAlertRequest will 3145 // not block until the queue exceeds ALERT_REQUEST_QUEUE_BUFFER_SIZE. 
3146 func (sshClient *sshClient) enqueueAlertRequest(request protocol.AlertRequest) { 3147 sshClient.Lock() 3148 if sshClient.sentAlertRequests[fmt.Sprintf("%+v", request)] { 3149 sshClient.Unlock() 3150 return 3151 } 3152 sshClient.Unlock() 3153 select { 3154 case <-sshClient.runCtx.Done(): 3155 case sshClient.sendAlertRequests <- request: 3156 } 3157 } 3158 3159 func (sshClient *sshClient) enqueueDisallowedTrafficAlertRequest() { 3160 3161 reason := protocol.PSIPHON_API_ALERT_DISALLOWED_TRAFFIC 3162 actionURLs := sshClient.getAlertActionURLs(reason) 3163 3164 sshClient.enqueueAlertRequest( 3165 protocol.AlertRequest{ 3166 Reason: reason, 3167 ActionURLs: actionURLs, 3168 }) 3169 } 3170 3171 func (sshClient *sshClient) enqueueUnsafeTrafficAlertRequest(tags []BlocklistTag) { 3172 3173 reason := protocol.PSIPHON_API_ALERT_UNSAFE_TRAFFIC 3174 actionURLs := sshClient.getAlertActionURLs(reason) 3175 3176 for _, tag := range tags { 3177 sshClient.enqueueAlertRequest( 3178 protocol.AlertRequest{ 3179 Reason: reason, 3180 Subject: tag.Subject, 3181 ActionURLs: actionURLs, 3182 }) 3183 } 3184 } 3185 3186 func (sshClient *sshClient) getAlertActionURLs(alertReason string) []string { 3187 3188 sshClient.Lock() 3189 sponsorID, _ := getStringRequestParam( 3190 sshClient.handshakeState.apiParams, "sponsor_id") 3191 sshClient.Unlock() 3192 3193 return sshClient.sshServer.support.PsinetDatabase.GetAlertActionURLs( 3194 alertReason, 3195 sponsorID, 3196 sshClient.geoIPData.Country, 3197 sshClient.geoIPData.ASN) 3198 } 3199 3200 func (sshClient *sshClient) rejectNewChannel(newChannel ssh.NewChannel, logMessage string) { 3201 3202 // We always return the reject reason "Prohibited": 3203 // - Traffic rules and connection limits may prohibit the connection. 3204 // - External firewall rules may prohibit the connection, and this is not currently 3205 // distinguishable from other failure modes. 3206 // - We limit the failure information revealed to the client. 
3207 reason := ssh.Prohibited 3208 3209 // Note: Debug level, as logMessage may contain user traffic destination address information 3210 log.WithTraceFields( 3211 LogFields{ 3212 "channelType": newChannel.ChannelType(), 3213 "logMessage": logMessage, 3214 "rejectReason": reason.String(), 3215 }).Debug("reject new channel") 3216 3217 // Note: logMessage is internal, for logging only; just the reject reason is sent to the client. 3218 newChannel.Reject(reason, reason.String()) 3219 } 3220 3221 // setHandshakeState records that a client has completed a handshake API request. 3222 // Some parameters from the handshake request may be used in future traffic rule 3223 // selection. Port forwards are disallowed until a handshake is complete. The 3224 // handshake parameters are included in the session summary log recorded in 3225 // sshClient.stop(). 3226 func (sshClient *sshClient) setHandshakeState( 3227 state handshakeState, 3228 authorizations []string) (*handshakeStateInfo, error) { 3229 3230 sshClient.Lock() 3231 completed := sshClient.handshakeState.completed 3232 if !completed { 3233 sshClient.handshakeState = state 3234 } 3235 sshClient.Unlock() 3236 3237 // Client must only perform one handshake 3238 if completed { 3239 return nil, errors.TraceNew("handshake already completed") 3240 } 3241 3242 // Verify the authorizations submitted by the client. Verified, active 3243 // (non-expired) access types will be available for traffic rules 3244 // filtering. 3245 // 3246 // When an authorization is active but expires while the client is 3247 // connected, the client is disconnected to ensure the access is reset. 3248 // This is implemented by setting a timer to perform the disconnect at the 3249 // expiry time of the soonest expiring authorization. 
3250 // 3251 // sshServer.authorizationSessionIDs tracks the unique mapping of active 3252 // authorization IDs to client session IDs and is used to detect and 3253 // prevent multiple malicious clients from reusing a single authorization 3254 // (within the scope of this server). 3255 3256 // authorizationIDs and authorizedAccessTypes are returned to the client 3257 // and logged, respectively; initialize to empty lists so the 3258 // protocol/logs don't need to handle 'null' values. 3259 authorizationIDs := make([]string, 0) 3260 authorizedAccessTypes := make([]string, 0) 3261 var stopTime time.Time 3262 3263 for i, authorization := range authorizations { 3264 3265 // This sanity check mitigates malicious clients causing excess CPU use. 3266 if i >= MAX_AUTHORIZATIONS { 3267 log.WithTrace().Warning("too many authorizations") 3268 break 3269 } 3270 3271 verifiedAuthorization, err := accesscontrol.VerifyAuthorization( 3272 &sshClient.sshServer.support.Config.AccessControlVerificationKeyRing, 3273 authorization) 3274 3275 if err != nil { 3276 log.WithTraceFields( 3277 LogFields{"error": err}).Warning("verify authorization failed") 3278 continue 3279 } 3280 3281 authorizationID := base64.StdEncoding.EncodeToString(verifiedAuthorization.ID) 3282 3283 if common.Contains(authorizedAccessTypes, verifiedAuthorization.AccessType) { 3284 log.WithTraceFields( 3285 LogFields{"accessType": verifiedAuthorization.AccessType}).Warning("duplicate authorization access type") 3286 continue 3287 } 3288 3289 authorizationIDs = append(authorizationIDs, authorizationID) 3290 authorizedAccessTypes = append(authorizedAccessTypes, verifiedAuthorization.AccessType) 3291 3292 if stopTime.IsZero() || stopTime.After(verifiedAuthorization.Expires) { 3293 stopTime = verifiedAuthorization.Expires 3294 } 3295 } 3296 3297 // Associate all verified authorizationIDs with this client's session ID. 
3298 // Handle cases where previous associations exist: 3299 // 3300 // - Multiple malicious clients reusing a single authorization. In this 3301 // case, authorizations are revoked from the previous client. 3302 // 3303 // - The client reconnected with a new session ID due to user toggling. 3304 // This case is expected due to server affinity. This cannot be 3305 // distinguished from the previous case and the same action is taken; 3306 // this will have no impact on a legitimate client as the previous 3307 // session is dangling. 3308 // 3309 // - The client automatically reconnected with the same session ID. This 3310 // case is not expected as sshServer.registerEstablishedClient 3311 // synchronously calls sshClient.releaseAuthorizations; as a safe guard, 3312 // this case is distinguished and no revocation action is taken. 3313 3314 sshClient.sshServer.authorizationSessionIDsMutex.Lock() 3315 for _, authorizationID := range authorizationIDs { 3316 sessionID, ok := sshClient.sshServer.authorizationSessionIDs[authorizationID] 3317 if ok && sessionID != sshClient.sessionID { 3318 3319 logFields := LogFields{ 3320 "event_name": "irregular_tunnel", 3321 "tunnel_error": "duplicate active authorization", 3322 "duplicate_authorization_id": authorizationID, 3323 } 3324 sshClient.geoIPData.SetLogFields(logFields) 3325 duplicateGeoIPData := sshClient.sshServer.support.GeoIPService.GetSessionCache(sessionID) 3326 if duplicateGeoIPData != sshClient.geoIPData { 3327 duplicateGeoIPData.SetLogFieldsWithPrefix("duplicate_authorization_", logFields) 3328 } 3329 log.LogRawFieldsWithTimestamp(logFields) 3330 3331 // Invoke asynchronously to avoid deadlocks. 3332 // TODO: invoke only once for each distinct sessionID? 
3333 go sshClient.sshServer.revokeClientAuthorizations(sessionID) 3334 } 3335 sshClient.sshServer.authorizationSessionIDs[authorizationID] = sshClient.sessionID 3336 } 3337 sshClient.sshServer.authorizationSessionIDsMutex.Unlock() 3338 3339 if len(authorizationIDs) > 0 { 3340 3341 sshClient.Lock() 3342 3343 // Make the authorizedAccessTypes available for traffic rules filtering. 3344 3345 sshClient.handshakeState.activeAuthorizationIDs = authorizationIDs 3346 sshClient.handshakeState.authorizedAccessTypes = authorizedAccessTypes 3347 3348 // On exit, sshClient.runTunnel will call releaseAuthorizations, which 3349 // will release the authorization IDs so the client can reconnect and 3350 // present the same authorizations again. sshClient.runTunnel will 3351 // also cancel the stopTimer in case it has not yet fired. 3352 // Note: termination of the stopTimer goroutine is not synchronized. 3353 3354 sshClient.releaseAuthorizations = func() { 3355 sshClient.sshServer.authorizationSessionIDsMutex.Lock() 3356 for _, authorizationID := range authorizationIDs { 3357 sessionID, ok := sshClient.sshServer.authorizationSessionIDs[authorizationID] 3358 if ok && sessionID == sshClient.sessionID { 3359 delete(sshClient.sshServer.authorizationSessionIDs, authorizationID) 3360 } 3361 } 3362 sshClient.sshServer.authorizationSessionIDsMutex.Unlock() 3363 } 3364 3365 sshClient.stopTimer = time.AfterFunc( 3366 time.Until(stopTime), 3367 func() { 3368 sshClient.stop() 3369 }) 3370 3371 sshClient.Unlock() 3372 } 3373 3374 upstreamBytesPerSecond, downstreamBytesPerSecond := sshClient.setTrafficRules() 3375 3376 sshClient.setOSLConfig() 3377 3378 // Set destination bytes metrics. 3379 // 3380 // Limitation: this is a one-time operation and doesn't get reset when 3381 // tactics are hot-reloaded. This allows us to simply retain any 3382 // destination byte counts accumulated and eventually log in 3383 // server_tunnel, without having to deal with a destination change 3384 // mid-tunnel. 
As typical tunnels are short, and destination changes can 3385 // be applied gradually, handling mid-tunnel changes is not a priority. 3386 sshClient.setDestinationBytesMetrics() 3387 3388 return &handshakeStateInfo{ 3389 activeAuthorizationIDs: authorizationIDs, 3390 authorizedAccessTypes: authorizedAccessTypes, 3391 upstreamBytesPerSecond: upstreamBytesPerSecond, 3392 downstreamBytesPerSecond: downstreamBytesPerSecond, 3393 }, nil 3394 } 3395 3396 // getHandshaked returns whether the client has completed a handshake API 3397 // request and whether the traffic rules that were selected after the 3398 // handshake immediately exhaust the client. 3399 // 3400 // When the client is immediately exhausted it will be closed; but this 3401 // takes effect asynchronously. The "exhausted" return value is used to 3402 // prevent API requests by clients that will close. 3403 func (sshClient *sshClient) getHandshaked() (bool, bool) { 3404 sshClient.Lock() 3405 defer sshClient.Unlock() 3406 3407 completed := sshClient.handshakeState.completed 3408 3409 exhausted := false 3410 3411 // Notes: 3412 // - "Immediately exhausted" is when CloseAfterExhausted is set and 3413 // either ReadUnthrottledBytes or WriteUnthrottledBytes starts from 3414 // 0, so no bytes would be read or written. This check does not 3415 // examine whether 0 bytes _remain_ in the ThrottledConn. 3416 // - This check is made against the current traffic rules, which 3417 // could have changed in a hot reload since the handshake. 
3418 3419 if completed && 3420 *sshClient.trafficRules.RateLimits.CloseAfterExhausted && 3421 (*sshClient.trafficRules.RateLimits.ReadUnthrottledBytes == 0 || 3422 *sshClient.trafficRules.RateLimits.WriteUnthrottledBytes == 0) { 3423 3424 exhausted = true 3425 } 3426 3427 return completed, exhausted 3428 } 3429 3430 func (sshClient *sshClient) getDisableDiscovery() bool { 3431 sshClient.Lock() 3432 defer sshClient.Unlock() 3433 3434 return *sshClient.trafficRules.DisableDiscovery 3435 } 3436 3437 func (sshClient *sshClient) updateAPIParameters( 3438 apiParams common.APIParameters) { 3439 3440 sshClient.Lock() 3441 defer sshClient.Unlock() 3442 3443 // Only update after handshake has initialized API params. 3444 if !sshClient.handshakeState.completed { 3445 return 3446 } 3447 3448 for name, value := range apiParams { 3449 sshClient.handshakeState.apiParams[name] = value 3450 } 3451 } 3452 3453 func (sshClient *sshClient) acceptDomainBytes() bool { 3454 sshClient.Lock() 3455 defer sshClient.Unlock() 3456 3457 // When the domain bytes checksum differs from the checksum sent to the 3458 // client in the handshake response, the psinet regex configuration has 3459 // changed. In this case, drop the stats so we don't continue to record 3460 // stats as previously configured. 3461 // 3462 // Limitations: 3463 // - The checksum comparison may result in dropping some stats for a 3464 // domain that remains in the new configuration. 3465 // - We don't push new regexs to the clients, so clients that remain 3466 // connected will continue to send stats that will be dropped; and 3467 // those clients will not send stats as newly configured until after 3468 // reconnecting. 3469 // - Due to the design of 3470 // transferstats.ReportRecentBytesTransferredForServer in the client, 3471 // the client may accumulate stats, reconnect before its next status 3472 // request, get a new regex configuration, and then send the previously 3473 // accumulated stats in its next status request. 
The checksum scheme 3474 // won't prevent the reporting of those stats. 3475 3476 sponsorID, _ := getStringRequestParam(sshClient.handshakeState.apiParams, "sponsor_id") 3477 3478 domainBytesChecksum := sshClient.sshServer.support.PsinetDatabase.GetDomainBytesChecksum(sponsorID) 3479 3480 return bytes.Equal(sshClient.handshakeState.domainBytesChecksum, domainBytesChecksum) 3481 } 3482 3483 // setOSLConfig resets the client's OSL seed state based on the latest OSL config 3484 // As sshClient.oslClientSeedState may be reset by a concurrent goroutine, 3485 // oslClientSeedState must only be accessed within the sshClient mutex. 3486 func (sshClient *sshClient) setOSLConfig() { 3487 sshClient.Lock() 3488 defer sshClient.Unlock() 3489 3490 propagationChannelID, err := getStringRequestParam( 3491 sshClient.handshakeState.apiParams, "propagation_channel_id") 3492 if err != nil { 3493 // This should not fail as long as client has sent valid handshake 3494 return 3495 } 3496 3497 // Use a cached seed state if one is found for the client's 3498 // session ID. This enables resuming progress made in a previous 3499 // tunnel. 3500 // Note: go-cache is already concurency safe; the additional mutex 3501 // is necessary to guarantee that Get/Delete is atomic; although in 3502 // practice no two concurrent clients should ever supply the same 3503 // session ID. 
3504 3505 sshClient.sshServer.oslSessionCacheMutex.Lock() 3506 oslClientSeedState, found := sshClient.sshServer.oslSessionCache.Get(sshClient.sessionID) 3507 if found { 3508 sshClient.sshServer.oslSessionCache.Delete(sshClient.sessionID) 3509 sshClient.sshServer.oslSessionCacheMutex.Unlock() 3510 sshClient.oslClientSeedState = oslClientSeedState.(*osl.ClientSeedState) 3511 sshClient.oslClientSeedState.Resume(sshClient.signalIssueSLOKs) 3512 return 3513 } 3514 sshClient.sshServer.oslSessionCacheMutex.Unlock() 3515 3516 // Two limitations when setOSLConfig() is invoked due to an 3517 // OSL config hot reload: 3518 // 3519 // 1. any partial progress towards SLOKs is lost. 3520 // 3521 // 2. all existing osl.ClientSeedPortForwards for existing 3522 // port forwards will not send progress to the new client 3523 // seed state. 3524 3525 sshClient.oslClientSeedState = sshClient.sshServer.support.OSLConfig.NewClientSeedState( 3526 sshClient.geoIPData.Country, 3527 propagationChannelID, 3528 sshClient.signalIssueSLOKs) 3529 } 3530 3531 // newClientSeedPortForward will return nil when no seeding is 3532 // associated with the specified ipAddress. 3533 func (sshClient *sshClient) newClientSeedPortForward(IPAddress net.IP) *osl.ClientSeedPortForward { 3534 sshClient.Lock() 3535 defer sshClient.Unlock() 3536 3537 // Will not be initialized before handshake. 3538 if sshClient.oslClientSeedState == nil { 3539 return nil 3540 } 3541 3542 return sshClient.oslClientSeedState.NewClientSeedPortForward(IPAddress) 3543 } 3544 3545 // getOSLSeedPayload returns a payload containing all seeded SLOKs for 3546 // this client's session. 3547 func (sshClient *sshClient) getOSLSeedPayload() *osl.SeedPayload { 3548 sshClient.Lock() 3549 defer sshClient.Unlock() 3550 3551 // Will not be initialized before handshake. 
3552 if sshClient.oslClientSeedState == nil { 3553 return &osl.SeedPayload{SLOKs: make([]*osl.SLOK, 0)} 3554 } 3555 3556 return sshClient.oslClientSeedState.GetSeedPayload() 3557 } 3558 3559 func (sshClient *sshClient) clearOSLSeedPayload() { 3560 sshClient.Lock() 3561 defer sshClient.Unlock() 3562 3563 sshClient.oslClientSeedState.ClearSeedPayload() 3564 } 3565 3566 func (sshClient *sshClient) setDestinationBytesMetrics() { 3567 sshClient.Lock() 3568 defer sshClient.Unlock() 3569 3570 // Limitation: the server-side tactics cache is used to avoid the overhead 3571 // of an additional tactics filtering per tunnel. As this cache is 3572 // designed for GeoIP filtering only, handshake API parameters are not 3573 // applied to tactics filtering in this case. 3574 3575 tacticsCache := sshClient.sshServer.support.ServerTacticsParametersCache 3576 if tacticsCache == nil { 3577 return 3578 } 3579 3580 p, err := tacticsCache.Get(sshClient.geoIPData) 3581 if err != nil { 3582 log.WithTraceFields(LogFields{"error": err}).Warning("get tactics failed") 3583 return 3584 } 3585 if p.IsNil() { 3586 return 3587 } 3588 3589 sshClient.destinationBytesMetricsASN = p.String(parameters.DestinationBytesMetricsASN) 3590 } 3591 3592 func (sshClient *sshClient) newDestinationBytesMetricsUpdater(portForwardType int, IPAddress net.IP) *destinationBytesMetrics { 3593 sshClient.Lock() 3594 defer sshClient.Unlock() 3595 3596 if sshClient.destinationBytesMetricsASN == "" { 3597 return nil 3598 } 3599 3600 if sshClient.sshServer.support.GeoIPService.LookupISPForIP(IPAddress).ASN != sshClient.destinationBytesMetricsASN { 3601 return nil 3602 } 3603 3604 if portForwardType == portForwardTypeTCP { 3605 return &sshClient.tcpDestinationBytesMetrics 3606 } 3607 3608 return &sshClient.udpDestinationBytesMetrics 3609 } 3610 3611 func (sshClient *sshClient) getActivityUpdaters(portForwardType int, IPAddress net.IP) []common.ActivityUpdater { 3612 var updaters []common.ActivityUpdater 3613 3614 
clientSeedPortForward := sshClient.newClientSeedPortForward(IPAddress) 3615 if clientSeedPortForward != nil { 3616 updaters = append(updaters, clientSeedPortForward) 3617 } 3618 3619 destinationBytesMetrics := sshClient.newDestinationBytesMetricsUpdater(portForwardType, IPAddress) 3620 if destinationBytesMetrics != nil { 3621 updaters = append(updaters, destinationBytesMetrics) 3622 } 3623 3624 return updaters 3625 } 3626 3627 // setTrafficRules resets the client's traffic rules based on the latest server config 3628 // and client properties. As sshClient.trafficRules may be reset by a concurrent 3629 // goroutine, trafficRules must only be accessed within the sshClient mutex. 3630 func (sshClient *sshClient) setTrafficRules() (int64, int64) { 3631 sshClient.Lock() 3632 defer sshClient.Unlock() 3633 3634 isFirstTunnelInSession := sshClient.isFirstTunnelInSession && 3635 sshClient.handshakeState.establishedTunnelsCount == 0 3636 3637 sshClient.trafficRules = sshClient.sshServer.support.TrafficRulesSet.GetTrafficRules( 3638 isFirstTunnelInSession, 3639 sshClient.tunnelProtocol, 3640 sshClient.geoIPData, 3641 sshClient.handshakeState) 3642 3643 if sshClient.throttledConn != nil { 3644 // Any existing throttling state is reset. 
3645 sshClient.throttledConn.SetLimits( 3646 sshClient.trafficRules.RateLimits.CommonRateLimits( 3647 sshClient.handshakeState.completed)) 3648 } 3649 3650 return *sshClient.trafficRules.RateLimits.ReadBytesPerSecond, 3651 *sshClient.trafficRules.RateLimits.WriteBytesPerSecond 3652 } 3653 3654 func (sshClient *sshClient) rateLimits() common.RateLimits { 3655 sshClient.Lock() 3656 defer sshClient.Unlock() 3657 3658 return sshClient.trafficRules.RateLimits.CommonRateLimits( 3659 sshClient.handshakeState.completed) 3660 } 3661 3662 func (sshClient *sshClient) idleTCPPortForwardTimeout() time.Duration { 3663 sshClient.Lock() 3664 defer sshClient.Unlock() 3665 3666 return time.Duration(*sshClient.trafficRules.IdleTCPPortForwardTimeoutMilliseconds) * time.Millisecond 3667 } 3668 3669 func (sshClient *sshClient) idleUDPPortForwardTimeout() time.Duration { 3670 sshClient.Lock() 3671 defer sshClient.Unlock() 3672 3673 return time.Duration(*sshClient.trafficRules.IdleUDPPortForwardTimeoutMilliseconds) * time.Millisecond 3674 } 3675 3676 func (sshClient *sshClient) setTCPPortForwardDialingAvailableSignal(signal context.CancelFunc) { 3677 sshClient.Lock() 3678 defer sshClient.Unlock() 3679 3680 sshClient.tcpPortForwardDialingAvailableSignal = signal 3681 } 3682 3683 const ( 3684 portForwardTypeTCP = iota 3685 portForwardTypeUDP 3686 ) 3687 3688 func (sshClient *sshClient) isPortForwardPermitted( 3689 portForwardType int, 3690 remoteIP net.IP, 3691 port int) bool { 3692 3693 // Disallow connection to bogons. 3694 // 3695 // As a security measure, this is a failsafe. The server should be run on a 3696 // host with correctly configured firewall rules. 3697 // 3698 // This check also avoids spurious disallowed traffic alerts for destinations 3699 // that are impossible to reach. 3700 3701 if !sshClient.sshServer.support.Config.AllowBogons && common.IsBogon(remoteIP) { 3702 return false 3703 } 3704 3705 // Blocklist check. 
3706 // 3707 // Limitation: isPortForwardPermitted is not called in transparent DNS 3708 // forwarding cases. As the destination IP address is rewritten in these 3709 // cases, a blocklist entry won't be dialed in any case. However, no logs 3710 // will be recorded. 3711 3712 if !sshClient.isIPPermitted(remoteIP) { 3713 return false 3714 } 3715 3716 // Don't lock before calling logBlocklistHits. 3717 // Unlock before calling enqueueDisallowedTrafficAlertRequest/log. 3718 3719 sshClient.Lock() 3720 3721 allowed := true 3722 3723 // Client must complete handshake before port forwards are permitted. 3724 if !sshClient.handshakeState.completed { 3725 allowed = false 3726 } 3727 3728 if allowed { 3729 // Traffic rules checks. 3730 switch portForwardType { 3731 case portForwardTypeTCP: 3732 if !sshClient.trafficRules.AllowTCPPort(remoteIP, port) { 3733 allowed = false 3734 } 3735 case portForwardTypeUDP: 3736 if !sshClient.trafficRules.AllowUDPPort(remoteIP, port) { 3737 allowed = false 3738 } 3739 } 3740 } 3741 3742 sshClient.Unlock() 3743 3744 if allowed { 3745 return true 3746 } 3747 3748 switch portForwardType { 3749 case portForwardTypeTCP: 3750 sshClient.updateQualityMetricsWithTCPRejectedDisallowed() 3751 case portForwardTypeUDP: 3752 sshClient.updateQualityMetricsWithUDPRejectedDisallowed() 3753 } 3754 3755 sshClient.enqueueDisallowedTrafficAlertRequest() 3756 3757 log.WithTraceFields( 3758 LogFields{ 3759 "type": portForwardType, 3760 "port": port, 3761 }).Debug("port forward denied by traffic rules") 3762 3763 return false 3764 } 3765 3766 // isDomainPermitted returns true when the specified domain may be resolved 3767 // and returns false and a reject reason otherwise. 3768 func (sshClient *sshClient) isDomainPermitted(domain string) (bool, string) { 3769 3770 // We're not doing comprehensive validation, to avoid overhead per port 3771 // forward. This is a simple sanity check to ensure we don't process 3772 // blantantly invalid input. 
3773 // 3774 // TODO: validate with dns.IsDomainName? 3775 if len(domain) > 255 { 3776 return false, "invalid domain name" 3777 } 3778 3779 tags := sshClient.sshServer.support.Blocklist.LookupDomain(domain) 3780 if len(tags) > 0 { 3781 3782 sshClient.logBlocklistHits(nil, domain, tags) 3783 3784 if sshClient.sshServer.support.Config.BlocklistActive { 3785 // Actively alert and block 3786 sshClient.enqueueUnsafeTrafficAlertRequest(tags) 3787 return false, "port forward not permitted" 3788 } 3789 } 3790 3791 return true, "" 3792 } 3793 3794 func (sshClient *sshClient) isIPPermitted(remoteIP net.IP) bool { 3795 3796 tags := sshClient.sshServer.support.Blocklist.LookupIP(remoteIP) 3797 if len(tags) > 0 { 3798 3799 sshClient.logBlocklistHits(remoteIP, "", tags) 3800 3801 if sshClient.sshServer.support.Config.BlocklistActive { 3802 // Actively alert and block 3803 sshClient.enqueueUnsafeTrafficAlertRequest(tags) 3804 return false 3805 } 3806 } 3807 3808 return true 3809 } 3810 3811 func (sshClient *sshClient) isTCPDialingPortForwardLimitExceeded() bool { 3812 3813 sshClient.Lock() 3814 defer sshClient.Unlock() 3815 3816 state := &sshClient.tcpTrafficState 3817 max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount 3818 3819 if max > 0 && state.concurrentDialingPortForwardCount >= int64(max) { 3820 return true 3821 } 3822 return false 3823 } 3824 3825 func (sshClient *sshClient) getTCPPortForwardQueueSize() int { 3826 3827 sshClient.Lock() 3828 defer sshClient.Unlock() 3829 3830 return *sshClient.trafficRules.MaxTCPPortForwardCount + 3831 *sshClient.trafficRules.MaxTCPDialingPortForwardCount 3832 } 3833 3834 func (sshClient *sshClient) getDialTCPPortForwardTimeoutMilliseconds() int { 3835 3836 sshClient.Lock() 3837 defer sshClient.Unlock() 3838 3839 return *sshClient.trafficRules.DialTCPPortForwardTimeoutMilliseconds 3840 } 3841 3842 func (sshClient *sshClient) dialingTCPPortForward() { 3843 3844 sshClient.Lock() 3845 defer sshClient.Unlock() 3846 3847 state := 
&sshClient.tcpTrafficState 3848 3849 state.concurrentDialingPortForwardCount += 1 3850 if state.concurrentDialingPortForwardCount > state.peakConcurrentDialingPortForwardCount { 3851 state.peakConcurrentDialingPortForwardCount = state.concurrentDialingPortForwardCount 3852 } 3853 } 3854 3855 func (sshClient *sshClient) abortedTCPPortForward() { 3856 3857 sshClient.Lock() 3858 defer sshClient.Unlock() 3859 3860 sshClient.tcpTrafficState.concurrentDialingPortForwardCount -= 1 3861 } 3862 3863 func (sshClient *sshClient) allocatePortForward(portForwardType int) bool { 3864 3865 sshClient.Lock() 3866 defer sshClient.Unlock() 3867 3868 // Check if at port forward limit. The subsequent counter 3869 // changes must be atomic with the limit check to ensure 3870 // the counter never exceeds the limit in the case of 3871 // concurrent allocations. 3872 3873 var max int 3874 var state *trafficState 3875 if portForwardType == portForwardTypeTCP { 3876 max = *sshClient.trafficRules.MaxTCPPortForwardCount 3877 state = &sshClient.tcpTrafficState 3878 } else { 3879 max = *sshClient.trafficRules.MaxUDPPortForwardCount 3880 state = &sshClient.udpTrafficState 3881 } 3882 3883 if max > 0 && state.concurrentPortForwardCount >= int64(max) { 3884 return false 3885 } 3886 3887 // Update port forward counters. 
3888 3889 if portForwardType == portForwardTypeTCP { 3890 3891 // Assumes TCP port forwards called dialingTCPPortForward 3892 state.concurrentDialingPortForwardCount -= 1 3893 3894 if sshClient.tcpPortForwardDialingAvailableSignal != nil { 3895 3896 max := *sshClient.trafficRules.MaxTCPDialingPortForwardCount 3897 if max <= 0 || state.concurrentDialingPortForwardCount < int64(max) { 3898 sshClient.tcpPortForwardDialingAvailableSignal() 3899 } 3900 } 3901 } 3902 3903 state.concurrentPortForwardCount += 1 3904 if state.concurrentPortForwardCount > state.peakConcurrentPortForwardCount { 3905 state.peakConcurrentPortForwardCount = state.concurrentPortForwardCount 3906 } 3907 state.totalPortForwardCount += 1 3908 3909 return true 3910 } 3911 3912 // establishedPortForward increments the concurrent port 3913 // forward counter. closedPortForward decrements it, so it 3914 // must always be called for each establishedPortForward 3915 // call. 3916 // 3917 // When at the limit of established port forwards, the LRU 3918 // existing port forward is closed to make way for the newly 3919 // established one. There can be a minor delay as, in addition 3920 // to calling Close() on the port forward net.Conn, 3921 // establishedPortForward waits for the LRU's closedPortForward() 3922 // call which will decrement the concurrent counter. This 3923 // ensures all resources associated with the LRU (socket, 3924 // goroutine) are released or will very soon be released before 3925 // proceeding. 3926 func (sshClient *sshClient) establishedPortForward( 3927 portForwardType int, portForwardLRU *common.LRUConns) { 3928 3929 // Do not lock sshClient here. 3930 3931 var state *trafficState 3932 if portForwardType == portForwardTypeTCP { 3933 state = &sshClient.tcpTrafficState 3934 } else { 3935 state = &sshClient.udpTrafficState 3936 } 3937 3938 // When the maximum number of port forwards is already 3939 // established, close the LRU. 
CloseOldest will call 3940 // Close on the port forward net.Conn. Both TCP and 3941 // UDP port forwards have handler goroutines that may 3942 // be blocked calling Read on the net.Conn. Close will 3943 // eventually interrupt the Read and cause the handlers 3944 // to exit, but not immediately. So the following logic 3945 // waits for a LRU handler to be interrupted and signal 3946 // availability. 3947 // 3948 // Notes: 3949 // 3950 // - the port forward limit can change via a traffic 3951 // rules hot reload; the condition variable handles 3952 // this case whereas a channel-based semaphore would 3953 // not. 3954 // 3955 // - if a number of goroutines exceeding the total limit 3956 // arrive here all concurrently, some CloseOldest() calls 3957 // will have no effect as there can be less existing port 3958 // forwards than new ones. In this case, the new port 3959 // forward will be delayed. This is highly unlikely in 3960 // practise since UDP calls to establishedPortForward are 3961 // serialized and TCP calls are limited by the dial 3962 // queue/count. 
3963 3964 if !sshClient.allocatePortForward(portForwardType) { 3965 3966 portForwardLRU.CloseOldest() 3967 log.WithTrace().Debug("closed LRU port forward") 3968 3969 state.availablePortForwardCond.L.Lock() 3970 for !sshClient.allocatePortForward(portForwardType) { 3971 state.availablePortForwardCond.Wait() 3972 } 3973 state.availablePortForwardCond.L.Unlock() 3974 } 3975 } 3976 3977 func (sshClient *sshClient) closedPortForward( 3978 portForwardType int, bytesUp, bytesDown int64) { 3979 3980 sshClient.Lock() 3981 3982 var state *trafficState 3983 if portForwardType == portForwardTypeTCP { 3984 state = &sshClient.tcpTrafficState 3985 } else { 3986 state = &sshClient.udpTrafficState 3987 } 3988 3989 state.concurrentPortForwardCount -= 1 3990 state.bytesUp += bytesUp 3991 state.bytesDown += bytesDown 3992 3993 sshClient.Unlock() 3994 3995 // Signal any goroutine waiting in establishedPortForward 3996 // that an established port forward slot is available. 3997 state.availablePortForwardCond.Signal() 3998 } 3999 4000 func (sshClient *sshClient) updateQualityMetricsWithDialResult( 4001 tcpPortForwardDialSuccess bool, dialDuration time.Duration, IP net.IP) { 4002 4003 sshClient.Lock() 4004 defer sshClient.Unlock() 4005 4006 if tcpPortForwardDialSuccess { 4007 sshClient.qualityMetrics.TCPPortForwardDialedCount += 1 4008 sshClient.qualityMetrics.TCPPortForwardDialedDuration += dialDuration 4009 if IP.To4() != nil { 4010 sshClient.qualityMetrics.TCPIPv4PortForwardDialedCount += 1 4011 sshClient.qualityMetrics.TCPIPv4PortForwardDialedDuration += dialDuration 4012 } else if IP != nil { 4013 sshClient.qualityMetrics.TCPIPv6PortForwardDialedCount += 1 4014 sshClient.qualityMetrics.TCPIPv6PortForwardDialedDuration += dialDuration 4015 } 4016 } else { 4017 sshClient.qualityMetrics.TCPPortForwardFailedCount += 1 4018 sshClient.qualityMetrics.TCPPortForwardFailedDuration += dialDuration 4019 if IP.To4() != nil { 4020 sshClient.qualityMetrics.TCPIPv4PortForwardFailedCount += 1 4021 
sshClient.qualityMetrics.TCPIPv4PortForwardFailedDuration += dialDuration 4022 } else if IP != nil { 4023 sshClient.qualityMetrics.TCPIPv6PortForwardFailedCount += 1 4024 sshClient.qualityMetrics.TCPIPv6PortForwardFailedDuration += dialDuration 4025 } 4026 } 4027 } 4028 4029 func (sshClient *sshClient) updateQualityMetricsWithRejectedDialingLimit() { 4030 4031 sshClient.Lock() 4032 defer sshClient.Unlock() 4033 4034 sshClient.qualityMetrics.TCPPortForwardRejectedDialingLimitCount += 1 4035 } 4036 4037 func (sshClient *sshClient) updateQualityMetricsWithTCPRejectedDisallowed() { 4038 4039 sshClient.Lock() 4040 defer sshClient.Unlock() 4041 4042 sshClient.qualityMetrics.TCPPortForwardRejectedDisallowedCount += 1 4043 } 4044 4045 func (sshClient *sshClient) updateQualityMetricsWithUDPRejectedDisallowed() { 4046 4047 sshClient.Lock() 4048 defer sshClient.Unlock() 4049 4050 sshClient.qualityMetrics.UDPPortForwardRejectedDisallowedCount += 1 4051 } 4052 4053 func (sshClient *sshClient) updateQualityMetricsWithDNSResult( 4054 success bool, duration time.Duration, resolverIP net.IP) { 4055 4056 sshClient.Lock() 4057 defer sshClient.Unlock() 4058 4059 resolver := "" 4060 if resolverIP != nil { 4061 resolver = resolverIP.String() 4062 } 4063 if success { 4064 sshClient.qualityMetrics.DNSCount["ALL"] += 1 4065 sshClient.qualityMetrics.DNSDuration["ALL"] += duration 4066 if resolver != "" { 4067 sshClient.qualityMetrics.DNSCount[resolver] += 1 4068 sshClient.qualityMetrics.DNSDuration[resolver] += duration 4069 } 4070 } else { 4071 sshClient.qualityMetrics.DNSFailedCount["ALL"] += 1 4072 sshClient.qualityMetrics.DNSFailedDuration["ALL"] += duration 4073 if resolver != "" { 4074 sshClient.qualityMetrics.DNSFailedCount[resolver] += 1 4075 sshClient.qualityMetrics.DNSFailedDuration[resolver] += duration 4076 } 4077 } 4078 } 4079 4080 func (sshClient *sshClient) handleTCPChannel( 4081 remainingDialTimeout time.Duration, 4082 hostToConnect string, 4083 portToConnect int, 4084 
doSplitTunnel bool, 4085 newChannel ssh.NewChannel) { 4086 4087 // Assumptions: 4088 // - sshClient.dialingTCPPortForward() has been called 4089 // - remainingDialTimeout > 0 4090 4091 established := false 4092 defer func() { 4093 if !established { 4094 sshClient.abortedTCPPortForward() 4095 } 4096 }() 4097 4098 // Transparently redirect web API request connections. 4099 4100 isWebServerPortForward := false 4101 config := sshClient.sshServer.support.Config 4102 if config.WebServerPortForwardAddress != "" { 4103 destination := net.JoinHostPort(hostToConnect, strconv.Itoa(portToConnect)) 4104 if destination == config.WebServerPortForwardAddress { 4105 isWebServerPortForward = true 4106 if config.WebServerPortForwardRedirectAddress != "" { 4107 // Note: redirect format is validated when config is loaded 4108 host, portStr, _ := net.SplitHostPort(config.WebServerPortForwardRedirectAddress) 4109 port, _ := strconv.Atoi(portStr) 4110 hostToConnect = host 4111 portToConnect = port 4112 } 4113 } 4114 } 4115 4116 // Validate the domain name and check the domain blocklist before dialing. 4117 // 4118 // The IP blocklist is checked in isPortForwardPermitted, which also provides 4119 // IP blocklist checking for the packet tunnel code path. When hostToConnect 4120 // is an IP address, the following hostname resolution step effectively 4121 // performs no actions and next immediate step is the isPortForwardPermitted 4122 // check. 4123 // 4124 // Limitation: this case handles port forwards where the client sends the 4125 // destination domain in the SSH port forward request but does not currently 4126 // handle DNS-over-TCP; in the DNS-over-TCP case, a client may bypass the 4127 // block list check. 
4128 4129 if !isWebServerPortForward && 4130 net.ParseIP(hostToConnect) == nil { 4131 4132 ok, rejectMessage := sshClient.isDomainPermitted(hostToConnect) 4133 if !ok { 4134 // Note: not recording a port forward failure in this case 4135 sshClient.rejectNewChannel(newChannel, rejectMessage) 4136 return 4137 } 4138 } 4139 4140 // Dial the remote address. 4141 // 4142 // Hostname resolution is performed explicitly, as a separate step, as the 4143 // target IP address is used for traffic rules (AllowSubnets), OSL seed 4144 // progress, and IP address blocklists. 4145 // 4146 // Contexts are used for cancellation (via sshClient.runCtx, which is 4147 // cancelled when the client is stopping) and timeouts. 4148 4149 dialStartTime := time.Now() 4150 4151 IP := net.ParseIP(hostToConnect) 4152 4153 if IP == nil { 4154 4155 // Resolve the hostname 4156 4157 log.WithTraceFields(LogFields{"hostToConnect": hostToConnect}).Debug("resolving") 4158 4159 ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout) 4160 IPs, err := (&net.Resolver{}).LookupIPAddr(ctx, hostToConnect) 4161 cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled" 4162 4163 resolveElapsedTime := time.Since(dialStartTime) 4164 4165 // Record DNS metrics. If LookupIPAddr returns net.DNSError.IsNotFound, this 4166 // is "no such host" and not a DNS failure. Limitation: the resolver IP is 4167 // not known. 4168 4169 dnsErr, ok := err.(*net.DNSError) 4170 dnsNotFound := ok && dnsErr.IsNotFound 4171 dnsSuccess := err == nil || dnsNotFound 4172 sshClient.updateQualityMetricsWithDNSResult(dnsSuccess, resolveElapsedTime, nil) 4173 4174 // IPv4 is preferred in case the host has limited IPv6 routing. IPv6 is 4175 // selected and attempted only when there's no IPv4 option. 4176 // TODO: shuffle list to try other IPs? 
4177 4178 for _, ip := range IPs { 4179 if ip.IP.To4() != nil { 4180 IP = ip.IP 4181 break 4182 } 4183 } 4184 if IP == nil && len(IPs) > 0 { 4185 // If there are no IPv4 IPs, the first IP is IPv6. 4186 IP = IPs[0].IP 4187 } 4188 4189 if err == nil && IP == nil { 4190 err = std_errors.New("no IP address") 4191 } 4192 4193 if err != nil { 4194 4195 // Record a port forward failure 4196 sshClient.updateQualityMetricsWithDialResult(false, resolveElapsedTime, IP) 4197 4198 sshClient.rejectNewChannel(newChannel, fmt.Sprintf("LookupIP failed: %s", err)) 4199 return 4200 } 4201 4202 remainingDialTimeout -= resolveElapsedTime 4203 } 4204 4205 if remainingDialTimeout <= 0 { 4206 sshClient.rejectNewChannel(newChannel, "TCP port forward timed out resolving") 4207 return 4208 } 4209 4210 // When the client has indicated split tunnel mode and when the channel is 4211 // not of type protocol.TCP_PORT_FORWARD_NO_SPLIT_TUNNEL_TYPE, check if the 4212 // client and the port forward destination are in the same GeoIP country. If 4213 // so, reject the port forward with a distinct response code that indicates 4214 // to the client that this port forward should be performed locally, direct 4215 // and untunneled. 4216 // 4217 // Clients are expected to cache untunneled responses to avoid this round 4218 // trip in the immediate future and reduce server load. 4219 // 4220 // When the countries differ, immediately proceed with the standard port 4221 // forward. No additional round trip is required. 4222 // 4223 // If either GeoIP country is "None", one or both countries are unknown 4224 // and there is no match. 4225 // 4226 // Traffic rules, such as allowed ports, are not enforced for port forward 4227 // destinations classified as untunneled. 4228 // 4229 // Domain and IP blocklists still apply to port forward destinations 4230 // classified as untunneled. 4231 // 4232 // The client's use of split tunnel mode is logged in server_tunnel metrics 4233 // as the boolean value split_tunnel. 
As they may indicate some information 4234 // about browsing activity, no other split tunnel metrics are logged. 4235 4236 if doSplitTunnel { 4237 4238 destinationGeoIPData := sshClient.sshServer.support.GeoIPService.LookupIP(IP) 4239 4240 if sshClient.geoIPData.Country != GEOIP_UNKNOWN_VALUE && 4241 sshClient.handshakeState.splitTunnelLookup.lookup( 4242 destinationGeoIPData.Country) { 4243 4244 // Since isPortForwardPermitted is not called in this case, explicitly call 4245 // ipBlocklistCheck. The domain blocklist case is handled above. 4246 if !sshClient.isIPPermitted(IP) { 4247 // Note: not recording a port forward failure in this case 4248 sshClient.rejectNewChannel(newChannel, "port forward not permitted") 4249 return 4250 } 4251 4252 newChannel.Reject(protocol.CHANNEL_REJECT_REASON_SPLIT_TUNNEL, "") 4253 return 4254 } 4255 } 4256 4257 // Enforce traffic rules, using the resolved IP address. 4258 4259 if !isWebServerPortForward && 4260 !sshClient.isPortForwardPermitted( 4261 portForwardTypeTCP, 4262 IP, 4263 portToConnect) { 4264 // Note: not recording a port forward failure in this case 4265 sshClient.rejectNewChannel(newChannel, "port forward not permitted") 4266 return 4267 } 4268 4269 // TCP dial. 
4270 4271 remoteAddr := net.JoinHostPort(IP.String(), strconv.Itoa(portToConnect)) 4272 4273 log.WithTraceFields(LogFields{"remoteAddr": remoteAddr}).Debug("dialing") 4274 4275 ctx, cancelCtx := context.WithTimeout(sshClient.runCtx, remainingDialTimeout) 4276 fwdConn, err := (&net.Dialer{}).DialContext(ctx, "tcp", remoteAddr) 4277 cancelCtx() // "must be called or the new context will remain live until its parent context is cancelled" 4278 4279 // Record port forward success or failure 4280 sshClient.updateQualityMetricsWithDialResult(err == nil, time.Since(dialStartTime), IP) 4281 4282 if err != nil { 4283 4284 // Monitor for low resource error conditions 4285 sshClient.sshServer.monitorPortForwardDialError(err) 4286 4287 sshClient.rejectNewChannel(newChannel, fmt.Sprintf("DialTimeout failed: %s", err)) 4288 return 4289 } 4290 4291 // The upstream TCP port forward connection has been established. Schedule 4292 // some cleanup and notify the SSH client that the channel is accepted. 4293 4294 defer fwdConn.Close() 4295 4296 fwdChannel, requests, err := newChannel.Accept() 4297 if err != nil { 4298 if !isExpectedTunnelIOError(err) { 4299 log.WithTraceFields(LogFields{"error": err}).Warning("accept new channel failed") 4300 } 4301 return 4302 } 4303 go ssh.DiscardRequests(requests) 4304 defer fwdChannel.Close() 4305 4306 // Release the dialing slot and acquire an established slot. 4307 // 4308 // establishedPortForward increments the concurrent TCP port 4309 // forward counter and closes the LRU existing TCP port forward 4310 // when already at the limit. 4311 // 4312 // Known limitations: 4313 // 4314 // - Closed LRU TCP sockets will enter the TIME_WAIT state, 4315 // continuing to consume some resources. 4316 4317 sshClient.establishedPortForward(portForwardTypeTCP, sshClient.tcpPortForwardLRU) 4318 4319 // "established = true" cancels the deferred abortedTCPPortForward() 4320 established = true 4321 4322 // TODO: 64-bit alignment? 
https://golang.org/pkg/sync/atomic/#pkg-note-BUG 4323 var bytesUp, bytesDown int64 4324 defer func() { 4325 sshClient.closedPortForward( 4326 portForwardTypeTCP, atomic.LoadInt64(&bytesUp), atomic.LoadInt64(&bytesDown)) 4327 }() 4328 4329 lruEntry := sshClient.tcpPortForwardLRU.Add(fwdConn) 4330 defer lruEntry.Remove() 4331 4332 // ActivityMonitoredConn monitors the TCP port forward I/O and updates 4333 // its LRU status. ActivityMonitoredConn also times out I/O on the port 4334 // forward if both reads and writes have been idle for the specified 4335 // duration. 4336 4337 fwdConn, err = common.NewActivityMonitoredConn( 4338 fwdConn, 4339 sshClient.idleTCPPortForwardTimeout(), 4340 true, 4341 lruEntry, 4342 sshClient.getActivityUpdaters(portForwardTypeTCP, IP)...) 4343 if err != nil { 4344 log.WithTraceFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed") 4345 return 4346 } 4347 4348 // Relay channel to forwarded connection. 4349 4350 log.WithTraceFields(LogFields{"remoteAddr": remoteAddr}).Debug("relaying") 4351 4352 // TODO: relay errors to fwdChannel.Stderr()? 4353 relayWaitGroup := new(sync.WaitGroup) 4354 relayWaitGroup.Add(1) 4355 go func() { 4356 defer relayWaitGroup.Done() 4357 // io.Copy allocates a 32K temporary buffer, and each port forward relay 4358 // uses two of these buffers; using common.CopyBuffer with a smaller buffer 4359 // reduces the overall memory footprint. 4360 bytes, err := common.CopyBuffer( 4361 fwdChannel, fwdConn, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE)) 4362 atomic.AddInt64(&bytesDown, bytes) 4363 if err != nil && err != io.EOF { 4364 // Debug since errors such as "connection reset by peer" occur during normal operation 4365 log.WithTraceFields(LogFields{"error": err}).Debug("downstream TCP relay failed") 4366 } 4367 // Interrupt upstream io.Copy when downstream is shutting down. 
4368 // TODO: this is done to quickly cleanup the port forward when 4369 // fwdConn has a read timeout, but is it clean -- upstream may still 4370 // be flowing? 4371 fwdChannel.Close() 4372 }() 4373 bytes, err := common.CopyBuffer( 4374 fwdConn, fwdChannel, make([]byte, SSH_TCP_PORT_FORWARD_COPY_BUFFER_SIZE)) 4375 atomic.AddInt64(&bytesUp, bytes) 4376 if err != nil && err != io.EOF { 4377 log.WithTraceFields(LogFields{"error": err}).Debug("upstream TCP relay failed") 4378 } 4379 // Shutdown special case: fwdChannel will be closed and return EOF when 4380 // the SSH connection is closed, but we need to explicitly close fwdConn 4381 // to interrupt the downstream io.Copy, which may be blocked on a 4382 // fwdConn.Read(). 4383 fwdConn.Close() 4384 4385 relayWaitGroup.Wait() 4386 4387 log.WithTraceFields( 4388 LogFields{ 4389 "remoteAddr": remoteAddr, 4390 "bytesUp": atomic.LoadInt64(&bytesUp), 4391 "bytesDown": atomic.LoadInt64(&bytesDown)}).Debug("exiting") 4392 }