github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/server/udp.go (about) 1 /* 2 * Copyright (c) 2016, Psiphon Inc. 3 * All rights reserved. 4 * 5 * This program is free software: you can redistribute it and/or modify 6 * it under the terms of the GNU General Public License as published by 7 * the Free Software Foundation, either version 3 of the License, or 8 * (at your option) any later version. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 package server 21 22 import ( 23 "bytes" 24 "encoding/binary" 25 "fmt" 26 "io" 27 "net" 28 "strconv" 29 "sync" 30 "sync/atomic" 31 32 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common" 33 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh" 34 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors" 35 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/monotime" 36 ) 37 38 // handleUdpgwChannel implements UDP port forwarding. A single UDP 39 // SSH channel follows the udpgw protocol, which multiplexes many 40 // UDP port forwards. 41 // 42 // The udpgw protocol and original server implementation: 43 // Copyright (c) 2009, Ambroz Bizjak <ambrop7@gmail.com> 44 // https://github.com/ambrop72/badvpn 45 func (sshClient *sshClient) handleUdpgwChannel(newChannel ssh.NewChannel) { 46 47 // Accept this channel immediately. This channel will replace any 48 // previously existing udpgw channel for this client. 49 50 sshChannel, requests, err := newChannel.Accept() 51 if err != nil { 52 if !isExpectedTunnelIOError(err) { 53 log.WithTraceFields(LogFields{"error": err}).Warning("accept new channel failed") 54 } 55 return 56 } 57 go ssh.DiscardRequests(requests) 58 defer sshChannel.Close() 59 60 multiplexer := &udpgwPortForwardMultiplexer{ 61 sshClient: sshClient, 62 sshChannel: sshChannel, 63 portForwards: make(map[uint16]*udpgwPortForward), 64 portForwardLRU: common.NewLRUConns(), 65 relayWaitGroup: new(sync.WaitGroup), 66 runWaitGroup: new(sync.WaitGroup), 67 } 68 69 multiplexer.runWaitGroup.Add(1) 70 71 // setUdpgwChannelHandler will close any existing 72 // udpgwPortForwardMultiplexer, waiting for all run/relayDownstream 73 // goroutines to first terminate and all UDP socket resources to be 74 // cleaned up. 75 // 76 // This synchronous shutdown also ensures that the 77 // concurrentPortForwardCount is reduced to 0 before installing the new 78 // udpgwPortForwardMultiplexer and its LRU object. If the older handler 79 // were to dangle with open port forwards, and concurrentPortForwardCount 80 // were to hit the max, the wrong LRU, the new one, would be used to 81 // close the LRU port forward. 82 // 83 // Call setUdpgwHandler only after runWaitGroup is initialized, to ensure 84 // runWaitGroup.Wait() cannot be invoked (by some subsequent new udpgw 85 // channel) before initialized. 86 87 if !sshClient.setUdpgwChannelHandler(multiplexer) { 88 // setUdpgwChannelHandler returns false if some other SSH channel 89 // calls setUdpgwChannelHandler in the middle of this call. In that 90 // case, discard this channel: the client's latest udpgw channel is 91 // retained. 92 return 93 } 94 95 multiplexer.run() 96 multiplexer.runWaitGroup.Done() 97 } 98 99 type udpgwPortForwardMultiplexer struct { 100 sshClient *sshClient 101 sshChannelWriteMutex sync.Mutex 102 sshChannel ssh.Channel 103 portForwardsMutex sync.Mutex 104 portForwards map[uint16]*udpgwPortForward 105 portForwardLRU *common.LRUConns 106 relayWaitGroup *sync.WaitGroup 107 runWaitGroup *sync.WaitGroup 108 } 109 110 func (mux *udpgwPortForwardMultiplexer) stop() { 111 112 // udpgwPortForwardMultiplexer must be initialized by handleUdpgwChannel. 113 // 114 // stop closes the udpgw SSH channel, which will cause the run goroutine 115 // to exit its message read loop and await closure of all relayDownstream 116 // goroutines. Closing all port forward UDP conns will cause all 117 // relayDownstream to exit. 118 119 _ = mux.sshChannel.Close() 120 121 mux.portForwardsMutex.Lock() 122 for _, portForward := range mux.portForwards { 123 _ = portForward.conn.Close() 124 } 125 mux.portForwardsMutex.Unlock() 126 127 mux.runWaitGroup.Wait() 128 } 129 130 func (mux *udpgwPortForwardMultiplexer) run() { 131 132 // In a loop, read udpgw messages from the client to this channel. Each 133 // message contains a UDP packet to send upstream either via a new port 134 // forward, or on an existing port forward. 135 // 136 // A goroutine is run to read downstream packets for each UDP port forward. All read 137 // packets are encapsulated in udpgw protocol and sent down the channel to the client. 138 // 139 // When the client disconnects or the server shuts down, the channel will close and 140 // readUdpgwMessage will exit with EOF. 141 142 buffer := make([]byte, udpgwProtocolMaxMessageSize) 143 for { 144 // Note: message.packet points to the reusable memory in "buffer". 145 // Each readUdpgwMessage call will overwrite the last message.packet. 146 message, err := readUdpgwMessage(mux.sshChannel, buffer) 147 if err != nil { 148 if err != io.EOF { 149 // Debug since I/O errors occur during normal operation 150 log.WithTraceFields(LogFields{"error": err}).Debug("readUdpgwMessage failed") 151 } 152 break 153 } 154 155 mux.portForwardsMutex.Lock() 156 portForward := mux.portForwards[message.connID] 157 mux.portForwardsMutex.Unlock() 158 159 // In the udpgw protocol, an existing port forward is closed when 160 // either the discard flag is set or the remote address has changed. 161 162 if portForward != nil && 163 (message.discardExistingConn || 164 !bytes.Equal(portForward.remoteIP, message.remoteIP) || 165 portForward.remotePort != message.remotePort) { 166 167 // The port forward's goroutine will complete cleanup, including 168 // tallying stats and calling sshClient.closedPortForward. 169 // portForward.conn.Close() will signal this shutdown. 170 portForward.conn.Close() 171 172 // Synchronously await the termination of the relayDownstream 173 // goroutine. This ensures that the previous goroutine won't 174 // invoke removePortForward, with the connID that will be reused 175 // for the new port forward, after this point. 176 // 177 // Limitation: this synchronous shutdown cannot prevent a "wrong 178 // remote address" error on the badvpn udpgw client, which occurs 179 // when the client recycles a port forward (setting discard) but 180 // receives, from the server, a udpgw message containing the old 181 // remote address for the previous port forward with the same 182 // conn ID. That downstream message from the server may be in 183 // flight in the SSH channel when the client discard message arrives. 184 portForward.relayWaitGroup.Wait() 185 186 portForward = nil 187 } 188 189 if portForward == nil { 190 191 // Create a new port forward 192 193 dialIP := net.IP(message.remoteIP) 194 dialPort := int(message.remotePort) 195 196 // Validate DNS packets and check the domain blocklist both when the client 197 // indicates DNS or when DNS is _not_ indicated and the destination port is 198 // 53. 199 if message.forwardDNS || message.remotePort == 53 { 200 201 domain, err := common.ParseDNSQuestion(message.packet) 202 if err != nil { 203 log.WithTraceFields(LogFields{"error": err}).Debug("ParseDNSQuestion failed") 204 // Drop packet 205 continue 206 } 207 208 if domain != "" { 209 ok, _ := mux.sshClient.isDomainPermitted(domain) 210 if !ok { 211 // Drop packet 212 continue 213 } 214 } 215 } 216 217 if message.forwardDNS { 218 // Transparent DNS forwarding. In this case, isPortForwardPermitted 219 // traffic rules checks are bypassed, since DNS is essential. 220 dialIP = mux.sshClient.sshServer.support.DNSResolver.Get() 221 dialPort = DNS_RESOLVER_PORT 222 223 } else if !mux.sshClient.isPortForwardPermitted( 224 portForwardTypeUDP, dialIP, int(message.remotePort)) { 225 // The udpgw protocol has no error response, so 226 // we just discard the message and read another. 227 continue 228 } 229 230 // Note: UDP port forward counting has no dialing phase 231 232 // establishedPortForward increments the concurrent UDP port 233 // forward counter and closes the LRU existing UDP port forward 234 // when already at the limit. 235 236 mux.sshClient.establishedPortForward(portForwardTypeUDP, mux.portForwardLRU) 237 // Can't defer sshClient.closedPortForward() here; 238 // relayDownstream will call sshClient.closedPortForward() 239 240 // Pre-check log level to avoid overhead of rendering log for 241 // every DNS query and other UDP port forward. 242 if IsLogLevelDebug() { 243 log.WithTraceFields( 244 LogFields{ 245 "remoteAddr": net.JoinHostPort(dialIP.String(), strconv.Itoa(dialPort)), 246 "connID": message.connID}).Debug("dialing") 247 } 248 249 udpConn, err := net.DialUDP( 250 "udp", nil, &net.UDPAddr{IP: dialIP, Port: dialPort}) 251 if err != nil { 252 mux.sshClient.closedPortForward(portForwardTypeUDP, 0, 0) 253 254 // Monitor for low resource error conditions 255 mux.sshClient.sshServer.monitorPortForwardDialError(err) 256 257 // Note: Debug level, as logMessage may contain user traffic destination address information 258 log.WithTraceFields(LogFields{"error": err}).Debug("DialUDP failed") 259 continue 260 } 261 262 lruEntry := mux.portForwardLRU.Add(udpConn) 263 // Can't defer lruEntry.Remove() here; 264 // relayDownstream will call lruEntry.Remove() 265 266 // ActivityMonitoredConn monitors the UDP port forward I/O and updates 267 // its LRU status. ActivityMonitoredConn also times out I/O on the port 268 // forward if both reads and writes have been idle for the specified 269 // duration. 270 271 var activityUpdaters []common.ActivityUpdater 272 // Don't incur activity monitor overhead for DNS requests 273 if !message.forwardDNS { 274 activityUpdaters = mux.sshClient.getActivityUpdaters(portForwardTypeUDP, dialIP) 275 } 276 277 conn, err := common.NewActivityMonitoredConn( 278 udpConn, 279 mux.sshClient.idleUDPPortForwardTimeout(), 280 true, 281 lruEntry, 282 activityUpdaters...) 283 if err != nil { 284 lruEntry.Remove() 285 mux.sshClient.closedPortForward(portForwardTypeUDP, 0, 0) 286 log.WithTraceFields(LogFields{"error": err}).Error("NewActivityMonitoredConn failed") 287 continue 288 } 289 290 portForward = &udpgwPortForward{ 291 connID: message.connID, 292 preambleSize: message.preambleSize, 293 remoteIP: message.remoteIP, 294 remotePort: message.remotePort, 295 dialIP: dialIP, 296 conn: conn, 297 lruEntry: lruEntry, 298 bytesUp: 0, 299 bytesDown: 0, 300 relayWaitGroup: new(sync.WaitGroup), 301 mux: mux, 302 } 303 304 if message.forwardDNS { 305 portForward.dnsFirstWriteTime = int64(monotime.Now()) 306 } 307 308 mux.portForwardsMutex.Lock() 309 mux.portForwards[portForward.connID] = portForward 310 mux.portForwardsMutex.Unlock() 311 312 portForward.relayWaitGroup.Add(1) 313 mux.relayWaitGroup.Add(1) 314 go portForward.relayDownstream() 315 } 316 317 // Note: assumes UDP writes won't block (https://golang.org/pkg/net/#UDPConn.WriteToUDP) 318 _, err = portForward.conn.Write(message.packet) 319 if err != nil { 320 // Debug since errors such as "write: operation not permitted" occur during normal operation 321 log.WithTraceFields(LogFields{"error": err}).Debug("upstream UDP relay failed") 322 // The port forward's goroutine will complete cleanup 323 portForward.conn.Close() 324 } 325 326 portForward.lruEntry.Touch() 327 328 atomic.AddInt64(&portForward.bytesUp, int64(len(message.packet))) 329 } 330 331 // Cleanup all udpgw port forward workers when exiting 332 333 mux.portForwardsMutex.Lock() 334 for _, portForward := range mux.portForwards { 335 // The port forward's goroutine will complete cleanup 336 portForward.conn.Close() 337 } 338 mux.portForwardsMutex.Unlock() 339 340 mux.relayWaitGroup.Wait() 341 } 342 343 func (mux *udpgwPortForwardMultiplexer) removePortForward(connID uint16) { 344 mux.portForwardsMutex.Lock() 345 delete(mux.portForwards, connID) 346 mux.portForwardsMutex.Unlock() 347 } 348 349 type udpgwPortForward struct { 350 // Note: 64-bit ints used with atomic operations are placed 351 // at the start of struct to ensure 64-bit alignment. 352 // (https://golang.org/pkg/sync/atomic/#pkg-note-BUG) 353 dnsFirstWriteTime int64 354 dnsFirstReadTime int64 355 bytesUp int64 356 bytesDown int64 357 connID uint16 358 preambleSize int 359 remoteIP []byte 360 remotePort uint16 361 dialIP net.IP 362 conn net.Conn 363 lruEntry *common.LRUConnsEntry 364 relayWaitGroup *sync.WaitGroup 365 mux *udpgwPortForwardMultiplexer 366 } 367 368 func (portForward *udpgwPortForward) relayDownstream() { 369 defer portForward.relayWaitGroup.Done() 370 defer portForward.mux.relayWaitGroup.Done() 371 372 // Downstream UDP packets are read into the reusable memory 373 // in "buffer" starting at the offset past the udpgw message 374 // header and address, leaving enough space to write the udpgw 375 // values into the same buffer and use for writing to the ssh 376 // channel. 377 // 378 // Note: there is one downstream buffer per UDP port forward, 379 // while for upstream there is one buffer per client. 380 // TODO: is the buffer size larger than necessary? 381 buffer := make([]byte, udpgwProtocolMaxMessageSize) 382 packetBuffer := buffer[portForward.preambleSize:udpgwProtocolMaxMessageSize] 383 for { 384 // TODO: if read buffer is too small, excess bytes are discarded? 385 packetSize, err := portForward.conn.Read(packetBuffer) 386 if packetSize > udpgwProtocolMaxPayloadSize { 387 err = fmt.Errorf("unexpected packet size: %d", packetSize) 388 } 389 if err != nil { 390 if err != io.EOF { 391 // Debug since errors such as "use of closed network connection" occur during normal operation 392 log.WithTraceFields(LogFields{"error": err}).Debug("downstream UDP relay failed") 393 } 394 break 395 } 396 397 if atomic.LoadInt64(&portForward.dnsFirstWriteTime) > 0 && 398 atomic.LoadInt64(&portForward.dnsFirstReadTime) == 0 { // Check if already set before invoking Now. 399 atomic.CompareAndSwapInt64(&portForward.dnsFirstReadTime, 0, int64(monotime.Now())) 400 } 401 402 err = writeUdpgwPreamble( 403 portForward.preambleSize, 404 0, 405 portForward.connID, 406 portForward.remoteIP, 407 portForward.remotePort, 408 uint16(packetSize), 409 buffer) 410 if err == nil { 411 // ssh.Channel.Write cannot be called concurrently. 412 // See: https://github.com/Psiphon-Inc/crypto/blob/82d98b4c7c05e81f92545f6fddb45d4541e6da00/ssh/channel.go#L272, 413 // https://codereview.appspot.com/136420043/diff/80002/ssh/channel.go 414 portForward.mux.sshChannelWriteMutex.Lock() 415 _, err = portForward.mux.sshChannel.Write(buffer[0 : portForward.preambleSize+packetSize]) 416 portForward.mux.sshChannelWriteMutex.Unlock() 417 } 418 419 if err != nil { 420 // Close the channel, which will interrupt the main loop. 421 portForward.mux.sshChannel.Close() 422 log.WithTraceFields(LogFields{"error": err}).Debug("downstream UDP relay failed") 423 break 424 } 425 426 portForward.lruEntry.Touch() 427 428 atomic.AddInt64(&portForward.bytesDown, int64(packetSize)) 429 } 430 431 portForward.mux.removePortForward(portForward.connID) 432 433 portForward.lruEntry.Remove() 434 435 portForward.conn.Close() 436 437 bytesUp := atomic.LoadInt64(&portForward.bytesUp) 438 bytesDown := atomic.LoadInt64(&portForward.bytesDown) 439 portForward.mux.sshClient.closedPortForward(portForwardTypeUDP, bytesUp, bytesDown) 440 441 dnsStartTime := monotime.Time(atomic.LoadInt64(&portForward.dnsFirstWriteTime)) 442 if dnsStartTime > 0 { 443 444 // Record DNS metrics using a heuristic: if a UDP packet was written and 445 // then a packet was read, assume the DNS request successfully received a 446 // valid response; failure occurs when the resolver fails to provide a 447 // response; a "no such host" response is still a success. Limitations: we 448 // assume a resolver will not respond when, e.g., rate limiting; we ignore 449 // subsequent requests made via the same UDP port forward. 450 451 dnsEndTime := monotime.Time(atomic.LoadInt64(&portForward.dnsFirstReadTime)) 452 453 dnsSuccess := true 454 if dnsEndTime == 0 { 455 dnsSuccess = false 456 dnsEndTime = monotime.Now() 457 } 458 459 resolveElapsedTime := dnsEndTime.Sub(dnsStartTime) 460 461 portForward.mux.sshClient.updateQualityMetricsWithDNSResult( 462 dnsSuccess, 463 resolveElapsedTime, 464 net.IP(portForward.dialIP)) 465 } 466 467 log.WithTraceFields( 468 LogFields{ 469 "remoteAddr": net.JoinHostPort( 470 net.IP(portForward.remoteIP).String(), strconv.Itoa(int(portForward.remotePort))), 471 "bytesUp": bytesUp, 472 "bytesDown": bytesDown, 473 "connID": portForward.connID}).Debug("exiting") 474 } 475 476 // TODO: express and/or calculate udpgwProtocolMaxPayloadSize as function of MTU? 477 const ( 478 udpgwProtocolFlagKeepalive = 1 << 0 479 udpgwProtocolFlagRebind = 1 << 1 480 udpgwProtocolFlagDNS = 1 << 2 481 udpgwProtocolFlagIPv6 = 1 << 3 482 483 udpgwProtocolMaxPreambleSize = 23 484 udpgwProtocolMaxPayloadSize = 32768 485 udpgwProtocolMaxMessageSize = udpgwProtocolMaxPreambleSize + udpgwProtocolMaxPayloadSize 486 ) 487 488 type udpgwProtocolMessage struct { 489 connID uint16 490 preambleSize int 491 remoteIP []byte 492 remotePort uint16 493 discardExistingConn bool 494 forwardDNS bool 495 packet []byte 496 } 497 498 func readUdpgwMessage( 499 reader io.Reader, buffer []byte) (*udpgwProtocolMessage, error) { 500 501 // udpgw message layout: 502 // 503 // | 2 byte size | 3 byte header | 6 or 18 byte address | variable length packet | 504 505 for { 506 // Read message 507 508 _, err := io.ReadFull(reader, buffer[0:2]) 509 if err != nil { 510 if err != io.EOF { 511 err = errors.Trace(err) 512 } 513 return nil, err 514 } 515 516 size := binary.LittleEndian.Uint16(buffer[0:2]) 517 518 if size < 3 || int(size) > len(buffer)-2 { 519 return nil, errors.TraceNew("invalid udpgw message size") 520 } 521 522 _, err = io.ReadFull(reader, buffer[2:2+size]) 523 if err != nil { 524 if err != io.EOF { 525 err = errors.Trace(err) 526 } 527 return nil, err 528 } 529 530 flags := buffer[2] 531 532 connID := binary.LittleEndian.Uint16(buffer[3:5]) 533 534 // Ignore udpgw keep-alive messages -- read another message 535 536 if flags&udpgwProtocolFlagKeepalive == udpgwProtocolFlagKeepalive { 537 continue 538 } 539 540 // Read address 541 542 var remoteIP []byte 543 var remotePort uint16 544 var packetStart, packetEnd int 545 546 if flags&udpgwProtocolFlagIPv6 == udpgwProtocolFlagIPv6 { 547 548 if size < 21 { 549 return nil, errors.TraceNew("invalid udpgw message size") 550 } 551 552 remoteIP = make([]byte, 16) 553 copy(remoteIP, buffer[5:21]) 554 remotePort = binary.BigEndian.Uint16(buffer[21:23]) 555 packetStart = 23 556 packetEnd = 23 + int(size) - 21 557 558 } else { 559 560 if size < 9 { 561 return nil, errors.TraceNew("invalid udpgw message size") 562 } 563 564 remoteIP = make([]byte, 4) 565 copy(remoteIP, buffer[5:9]) 566 remotePort = binary.BigEndian.Uint16(buffer[9:11]) 567 packetStart = 11 568 packetEnd = 11 + int(size) - 9 569 } 570 571 // Assemble message 572 // Note: udpgwProtocolMessage.packet references memory in the input buffer 573 574 message := &udpgwProtocolMessage{ 575 connID: connID, 576 preambleSize: packetStart, 577 remoteIP: remoteIP, 578 remotePort: remotePort, 579 discardExistingConn: flags&udpgwProtocolFlagRebind == udpgwProtocolFlagRebind, 580 forwardDNS: flags&udpgwProtocolFlagDNS == udpgwProtocolFlagDNS, 581 packet: buffer[packetStart:packetEnd], 582 } 583 584 return message, nil 585 } 586 } 587 588 func writeUdpgwPreamble( 589 preambleSize int, 590 flags uint8, 591 connID uint16, 592 remoteIP []byte, 593 remotePort uint16, 594 packetSize uint16, 595 buffer []byte) error { 596 597 if preambleSize != 7+len(remoteIP) { 598 return errors.TraceNew("invalid udpgw preamble size") 599 } 600 601 size := uint16(preambleSize-2) + packetSize 602 603 // size 604 binary.LittleEndian.PutUint16(buffer[0:2], size) 605 606 // flags 607 buffer[2] = flags 608 609 // connID 610 binary.LittleEndian.PutUint16(buffer[3:5], connID) 611 612 // addr 613 copy(buffer[5:5+len(remoteIP)], remoteIP) 614 binary.BigEndian.PutUint16(buffer[5+len(remoteIP):7+len(remoteIP)], remotePort) 615 616 return nil 617 }