github.com/geph-official/geph2@v0.22.6-0.20210211030601-f527cb59b0df/libs/kcp-go/sess.go (about) 1 // Package kcp-go is a Reliable-UDP library for golang. 2 // 3 // This library intents to provide a smooth, resilient, ordered, 4 // error-checked and anonymous delivery of streams over UDP packets. 5 // 6 // The interfaces of this package aims to be compatible with 7 // net.Conn in standard library, but offers powerful features for advanced users. 8 package kcp 9 10 import ( 11 "crypto/rand" 12 "encoding/binary" 13 "io" 14 "log" 15 "net" 16 "sync" 17 "sync/atomic" 18 "time" 19 20 "github.com/pkg/errors" 21 "golang.org/x/net/ipv4" 22 "golang.org/x/net/ipv6" 23 ) 24 25 const ( 26 // 16-bytes nonce for each packet 27 nonceSize = 16 28 29 // 4-bytes packet checksum 30 crcSize = 4 31 32 // overall crypto header size 33 cryptHeaderSize = nonceSize + crcSize 34 35 // maximum packet size 36 mtuLimit = 1500 37 38 // FEC keeps rxFECMulti* (dataShard+parityShard) ordered packets in memory 39 rxFECMulti = 3 40 41 // accept backlog 42 acceptBacklog = 128 43 ) 44 45 var ( 46 errInvalidOperation = errors.New("invalid operation") 47 errTimeout = errors.New("timeout") 48 ) 49 50 var ( 51 // a system-wide packet buffer shared among sending, receiving and FEC 52 // to mitigate high-frequency memory allocation for packets 53 xmitBuf sync.Pool 54 ) 55 56 func init() { 57 xmitBuf.New = func() interface{} { 58 return make([]byte, mtuLimit) 59 } 60 } 61 62 type ( 63 // LossReporter reports losses 64 LossReporter interface { 65 net.PacketConn 66 UnderlyingLoss(destAddr net.Addr) (frac float64) 67 } 68 // UDPSession defines a KCP session implemented by UDP 69 UDPSession struct { 70 lastFecRate float64 71 updaterIdx int // record slice index in updater 72 conn net.PacketConn // the underlying packet connection 73 lossReporter LossReporter 74 kcp *KCP // KCP ARQ protocol 75 l *Listener // pointing to the Listener object if it's been accepted by a Listener 76 77 // kcp receiving is based on packets 78 // recvbuf turns packets into stream 79 recvbuf []byte 80 bufptr []byte 81 82 // FEC codec 83 fecDecoder *fecDecoder 84 fecEncoder *fecEncoder 85 86 // settings 87 remote net.Addr // remote peer address 88 rd time.Time // read deadline 89 wd time.Time // write deadline 90 ackNoDelay bool // send ack immediately for each incoming packet(testing purpose) 91 writeDelay bool // delay kcp.flush() for Write() for bulk transfer 92 dup int // duplicate udp packets(testing purpose) 93 94 // notifications 95 die chan struct{} // notify current session has Closed 96 dieOnce sync.Once 97 chReadEvent chan struct{} // notify Read() can be called without blocking 98 chWriteEvent chan struct{} // notify Write() can be called without blocking 99 100 // socket error handling 101 socketReadError atomic.Value 102 socketWriteError atomic.Value 103 chSocketReadError chan struct{} 104 chSocketWriteError chan struct{} 105 socketReadErrorOnce sync.Once 106 socketWriteErrorOnce sync.Once 107 108 // nonce generator 109 nonce Entropy 110 111 // packets waiting to be sent on wire 112 txqueue []ipv4.Message 113 xconn batchConn // for x/net 114 xconnWriteError error 115 116 pacer struct { 117 nextSendTime time.Time 118 } 119 120 updater updateHeap 121 122 fecbuffer [][]byte 123 124 mu sync.Mutex 125 } 126 127 setReadBuffer interface { 128 SetReadBuffer(bytes int) error 129 } 130 131 setWriteBuffer interface { 132 SetWriteBuffer(bytes int) error 133 } 134 135 setDSCP interface { 136 SetDSCP(int) error 137 } 138 ) 139 140 func (s *UDPSession) headerSize() int { 141 if s.fecEncoder != nil { 142 return fecHeaderSizePlus2 143 } 144 return 0 145 } 146 147 // newUDPSession create a new udp session for client or server 148 func newUDPSession(conv uint32, dataShards, parityShards int, l *Listener, conn net.PacketConn, remote net.Addr, block BlockCrypt) *UDPSession { 149 sess := new(UDPSession) 150 sess.die = make(chan struct{}) 151 sess.nonce = new(nonceAES128) 152 sess.nonce.Init() 153 sess.chReadEvent = make(chan struct{}, 1) 154 sess.chWriteEvent = make(chan struct{}, 1) 155 sess.chSocketReadError = make(chan struct{}) 156 sess.chSocketWriteError = make(chan struct{}) 157 sess.remote = remote 158 sess.conn = conn 159 sess.l = l 160 sess.updater.init() 161 go sess.updater.updateTask() 162 sess.recvbuf = make([]byte, mtuLimit) 163 164 // cast to writebatch conn 165 if _, ok := conn.(*net.UDPConn); ok { 166 addr, err := net.ResolveUDPAddr("udp", conn.LocalAddr().String()) 167 if err == nil { 168 if addr.IP.To4() != nil { 169 sess.xconn = ipv4.NewPacketConn(conn) 170 } else { 171 sess.xconn = ipv6.NewPacketConn(conn) 172 } 173 } 174 } 175 176 sess.lossReporter, _ = conn.(LossReporter) 177 178 // FEC codec initialization 179 sess.fecDecoder = newFECDecoder(rxFECMulti*(dataShards+parityShards), dataShards, parityShards) 180 sess.fecEncoder = newFECEncoder(dataShards, parityShards, 0) 181 182 sess.kcp = NewKCP(conv, func(buf []byte, size int) { 183 if size >= IKCP_OVERHEAD+sess.headerSize() { 184 sess.output(buf[:size]) 185 } 186 }) 187 sess.kcp.ReserveBytes(sess.headerSize()) 188 189 // register current session to the global updater, 190 // which call sess.update() periodically. 191 sess.updater.addSession(sess) 192 193 if sess.l == nil { // it's a client connection 194 go sess.readLoop() 195 atomic.AddUint64(&DefaultSnmp.ActiveOpens, 1) 196 } else { 197 atomic.AddUint64(&DefaultSnmp.PassiveOpens, 1) 198 } 199 200 currestab := atomic.AddUint64(&DefaultSnmp.CurrEstab, 1) 201 maxconn := atomic.LoadUint64(&DefaultSnmp.MaxConn) 202 if currestab > maxconn { 203 atomic.CompareAndSwapUint64(&DefaultSnmp.MaxConn, maxconn, currestab) 204 } 205 206 return sess 207 } 208 209 // FlowStats summarizes flow statistics 210 func (s *UDPSession) FlowStats() (btlBw float64, latency float64, lossFrac float64) { 211 return s.kcp.DRE.maxAckRate, s.kcp.DRE.minRtt, float64(s.kcp.retrans) / float64(s.kcp.trans) 212 } 213 214 // Read implements net.Conn 215 func (s *UDPSession) Read(b []byte) (n int, err error) { 216 for { 217 s.mu.Lock() 218 if len(s.bufptr) > 0 { // copy from buffer into b 219 n = copy(b, s.bufptr) 220 s.bufptr = s.bufptr[n:] 221 s.mu.Unlock() 222 atomic.AddUint64(&DefaultSnmp.BytesReceived, uint64(n)) 223 return n, nil 224 } 225 if s.kcp.isDead { 226 s.mu.Unlock() 227 go s.Close() 228 err = io.ErrClosedPipe 229 return 230 } 231 232 if size := s.kcp.PeekSize(); size > 0 { // peek data size from kcp 233 if len(b) >= size { // receive data into 'b' directly 234 s.kcp.Recv(b) 235 s.mu.Unlock() 236 atomic.AddUint64(&DefaultSnmp.BytesReceived, uint64(size)) 237 return size, nil 238 } 239 240 // if necessary resize the stream buffer to guarantee a sufficent buffer space 241 if cap(s.recvbuf) < size { 242 s.recvbuf = make([]byte, size) 243 } 244 245 // resize the length of recvbuf to correspond to data size 246 s.recvbuf = s.recvbuf[:size] 247 s.kcp.Recv(s.recvbuf) 248 n = copy(b, s.recvbuf) // copy to 'b' 249 s.bufptr = s.recvbuf[n:] // pointer update 250 s.mu.Unlock() 251 atomic.AddUint64(&DefaultSnmp.BytesReceived, uint64(n)) 252 return n, nil 253 } 254 255 // deadline for current reading operation 256 var timeout *time.Timer 257 var c <-chan time.Time 258 if !s.rd.IsZero() { 259 if time.Now().After(s.rd) { 260 s.mu.Unlock() 261 return 0, errors.WithStack(errTimeout) 262 } 263 264 delay := s.rd.Sub(time.Now()) 265 timeout = time.NewTimer(delay) 266 c = timeout.C 267 } 268 s.mu.Unlock() 269 270 // wait for read event or timeout or error 271 select { 272 case <-s.chReadEvent: 273 if timeout != nil { 274 timeout.Stop() 275 } 276 case <-c: 277 return 0, errors.WithStack(errTimeout) 278 case <-s.chSocketReadError: 279 return 0, s.socketReadError.Load().(error) 280 case <-s.die: 281 return 0, errors.WithStack(io.ErrClosedPipe) 282 } 283 } 284 } 285 286 // Write implements net.Conn 287 func (s *UDPSession) Write(b []byte) (n int, err error) { 288 return s.WriteBuffers([][]byte{b}) 289 } 290 291 // WriteBuffers write a vector of byte slices to the underlying connection 292 func (s *UDPSession) WriteBuffers(v [][]byte) (n int, err error) { 293 defer s.updater.addSessionIfNotExists(s) 294 for { 295 select { 296 case <-s.chSocketWriteError: 297 return 0, s.socketWriteError.Load().(error) 298 case <-s.die: 299 return 0, errors.WithStack(io.ErrClosedPipe) 300 default: 301 } 302 303 s.mu.Lock() 304 305 // make sure write do not overflow the max sliding window on both side 306 waitsnd := s.kcp.WaitSnd() 307 if waitsnd < int(s.kcp.snd_wnd) && waitsnd < int(s.kcp.rmt_wnd) && 308 (s.kcp.nocwnd == 1 || waitsnd < int(s.kcp.cwnd)) { 309 count := 0 310 for _, b := range v { 311 n += len(b) 312 for { 313 count++ 314 if len(b) <= int(s.kcp.mss) { 315 s.kcp.Send(b) 316 break 317 } else { 318 s.kcp.Send(b[:s.kcp.mss]) 319 b = b[s.kcp.mss:] 320 } 321 } 322 } 323 waitsnd = s.kcp.WaitSnd() 324 if waitsnd >= int(s.kcp.snd_wnd) || waitsnd >= int(s.kcp.rmt_wnd) || !s.writeDelay { 325 s.kcp.flush(false) 326 s.uncork() 327 } 328 s.mu.Unlock() 329 atomic.AddUint64(&DefaultSnmp.BytesSent, uint64(n)) 330 331 return n, nil 332 } 333 334 var timeout *time.Timer 335 var c <-chan time.Time 336 if !s.wd.IsZero() { 337 if time.Now().After(s.wd) { 338 s.mu.Unlock() 339 return 0, errors.WithStack(errTimeout) 340 } 341 delay := s.wd.Sub(time.Now()) 342 timeout = time.NewTimer(delay) 343 c = timeout.C 344 } 345 s.mu.Unlock() 346 347 select { 348 case <-s.chWriteEvent: 349 if timeout != nil { 350 timeout.Stop() 351 } 352 case <-c: 353 return 0, errors.WithStack(errTimeout) 354 case <-s.chSocketWriteError: 355 return 0, s.socketWriteError.Load().(error) 356 case <-s.die: 357 return 0, errors.WithStack(io.ErrClosedPipe) 358 } 359 } 360 } 361 362 // uncork sends data in txqueue if there is any 363 func (s *UDPSession) uncork() { 364 if len(s.txqueue) > 0 { 365 s.tx(s.txqueue) 366 s.txqueue = s.txqueue[:0] 367 } 368 return 369 } 370 371 // Close closes the connection. 372 func (s *UDPSession) Close() error { 373 var once bool 374 s.dieOnce.Do(func() { 375 close(s.die) 376 once = true 377 }) 378 379 if once { 380 // remove from updater 381 s.updater.stop.Kill(io.EOF) 382 atomic.AddUint64(&DefaultSnmp.CurrEstab, ^uint64(0)) 383 384 if s.l != nil { // belongs to listener 385 s.l.closeSession(s.remote) 386 return nil 387 } else { // client socket close 388 return s.conn.Close() 389 } 390 } else { 391 return errors.WithStack(io.ErrClosedPipe) 392 } 393 } 394 395 // LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it. 396 func (s *UDPSession) LocalAddr() net.Addr { return s.conn.LocalAddr() } 397 398 // RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it. 399 func (s *UDPSession) RemoteAddr() net.Addr { return s.remote } 400 401 // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. 402 func (s *UDPSession) SetDeadline(t time.Time) error { 403 s.mu.Lock() 404 defer s.mu.Unlock() 405 s.rd = t 406 s.wd = t 407 s.notifyReadEvent() 408 s.notifyWriteEvent() 409 return nil 410 } 411 412 // SetReadDeadline implements the Conn SetReadDeadline method. 413 func (s *UDPSession) SetReadDeadline(t time.Time) error { 414 s.mu.Lock() 415 defer s.mu.Unlock() 416 s.rd = t 417 s.notifyReadEvent() 418 return nil 419 } 420 421 // SetWriteDeadline implements the Conn SetWriteDeadline method. 422 func (s *UDPSession) SetWriteDeadline(t time.Time) error { 423 s.mu.Lock() 424 defer s.mu.Unlock() 425 s.wd = t 426 s.notifyWriteEvent() 427 return nil 428 } 429 430 // SetWriteDelay delays write for bulk transfer until the next update interval 431 func (s *UDPSession) SetWriteDelay(delay bool) { 432 s.mu.Lock() 433 defer s.mu.Unlock() 434 s.writeDelay = delay 435 } 436 437 // SetWindowSize set maximum window size 438 func (s *UDPSession) SetWindowSize(sndwnd, rcvwnd int) { 439 s.mu.Lock() 440 defer s.mu.Unlock() 441 s.kcp.WndSize(sndwnd, rcvwnd) 442 } 443 444 // SetMtu sets the maximum transmission unit(not including UDP header) 445 func (s *UDPSession) SetMtu(mtu int) bool { 446 if mtu > mtuLimit { 447 return false 448 } 449 450 s.mu.Lock() 451 defer s.mu.Unlock() 452 s.kcp.SetMtu(mtu) 453 return true 454 } 455 456 // SetStreamMode toggles the stream mode on/off 457 func (s *UDPSession) SetStreamMode(enable bool) { 458 s.mu.Lock() 459 defer s.mu.Unlock() 460 if enable { 461 s.kcp.stream = 1 462 } else { 463 s.kcp.stream = 0 464 } 465 } 466 467 // SetACKNoDelay changes ack flush option, set true to flush ack immediately, 468 func (s *UDPSession) SetACKNoDelay(nodelay bool) { 469 s.mu.Lock() 470 defer s.mu.Unlock() 471 s.ackNoDelay = nodelay 472 } 473 474 // (deprecated) 475 // 476 // SetDUP duplicates udp packets for kcp output. 477 func (s *UDPSession) SetDUP(dup int) { 478 s.mu.Lock() 479 defer s.mu.Unlock() 480 s.dup = dup 481 } 482 483 // SetNoDelay calls nodelay() of kcp 484 // https://github.com/skywind3000/kcp/blob/master/README.en.md#protocol-configuration 485 func (s *UDPSession) SetNoDelay(nodelay, interval, resend, nc int) { 486 s.mu.Lock() 487 defer s.mu.Unlock() 488 s.kcp.NoDelay(nodelay, interval, resend, nc) 489 } 490 491 // SetDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header. 492 // 493 // if the underlying connection has implemented `func SetDSCP(int) error`, SetDSCP() will invoke 494 // this function instead. 495 // 496 // It has no effect if it's accepted from Listener. 497 func (s *UDPSession) SetDSCP(dscp int) error { 498 s.mu.Lock() 499 defer s.mu.Unlock() 500 if s.l != nil { 501 return errInvalidOperation 502 } 503 504 // interface enabled 505 if ts, ok := s.conn.(setDSCP); ok { 506 return ts.SetDSCP(dscp) 507 } 508 509 if nc, ok := s.conn.(net.Conn); ok { 510 var succeed bool 511 if err := ipv4.NewConn(nc).SetTOS(dscp << 2); err == nil { 512 succeed = true 513 } 514 if err := ipv6.NewConn(nc).SetTrafficClass(dscp); err == nil { 515 succeed = true 516 } 517 518 if succeed { 519 return nil 520 } 521 } 522 return errInvalidOperation 523 } 524 525 // SetReadBuffer sets the socket read buffer, no effect if it's accepted from Listener 526 func (s *UDPSession) SetReadBuffer(bytes int) error { 527 s.mu.Lock() 528 defer s.mu.Unlock() 529 if s.l == nil { 530 if nc, ok := s.conn.(setReadBuffer); ok { 531 return nc.SetReadBuffer(bytes) 532 } 533 } 534 return errInvalidOperation 535 } 536 537 // SetWriteBuffer sets the socket write buffer, no effect if it's accepted from Listener 538 func (s *UDPSession) SetWriteBuffer(bytes int) error { 539 s.mu.Lock() 540 defer s.mu.Unlock() 541 if s.l == nil { 542 if nc, ok := s.conn.(setWriteBuffer); ok { 543 return nc.SetWriteBuffer(bytes) 544 } 545 } 546 return errInvalidOperation 547 } 548 549 func loss2fecfrac(loss float64) float64 { 550 if loss < 0 { 551 return 0 552 } else if loss < 0.01 { 553 return 0 554 } else if loss < 0.02 { 555 return 0.0625 556 } else if loss < 0.03 { 557 return 0.0625 558 } else if loss < 0.04 { 559 return 0.0625 560 } else if loss < 0.05 { 561 return 0.125 562 } else if loss < 0.06 { 563 return 0.125 564 } else if loss < 0.07 { 565 return 0.125 566 } else if loss < 0.08 { 567 return 0.1875 568 } else if loss < 0.09 { 569 return 0.1875 570 } else if loss < 0.1 { 571 return 0.1875 572 } else if loss < 0.11 { 573 return 0.1875 574 } else if loss < 0.12 { 575 return 0.25 576 } else if loss < 0.13 { 577 return 0.25 578 } else if loss < 0.14 { 579 return 0.25 580 } else if loss < 0.15 { 581 return 0.3125 582 } else if loss < 0.16 { 583 return 0.3125 584 } else if loss < 0.17 { 585 return 0.3125 586 } else if loss < 0.18 { 587 return 0.375 588 } else if loss < 0.19 { 589 return 0.375 590 } else if loss < 0.2 { 591 return 0.375 592 } else if loss < 0.21 { 593 return 0.4375 594 } else if loss < 0.22 { 595 return 0.4375 596 } else if loss < 0.23 { 597 return 0.4375 598 } else if loss < 0.24 { 599 return 0.5 600 } else if loss < 0.25 { 601 return 0.5 602 } else if loss < 0.26 { 603 return 0.5 604 } else if loss < 0.27 { 605 return 0.5625 606 } else if loss < 0.28 { 607 return 0.5625 608 } else if loss < 0.29 { 609 return 0.625 610 } else if loss < 0.3 { 611 return 0.625 612 } else if loss < 0.31 { 613 return 0.6875 614 } else if loss < 0.32 { 615 return 0.6875 616 } else if loss < 0.33 { 617 return 0.6875 618 } else if loss < 0.34 { 619 return 0.75 620 } else if loss < 0.35 { 621 return 0.75 622 } else if loss < 0.36 { 623 return 0.8125 624 } else if loss < 0.37 { 625 return 0.8125 626 } else if loss < 0.38 { 627 return 0.875 628 } else if loss < 0.39 { 629 return 0.875 630 } else if loss < 0.4 { 631 return 0.9375 632 } 633 return 1.0 634 } 635 636 // post-processing for sending a packet from kcp core 637 // steps: 638 // 1. FEC packet generation 639 // 2. CRC32 integrity 640 // 3. Encryption 641 // 4. TxQueue 642 func (s *UDPSession) output(buf []byte) { 643 var ecc [][]byte 644 645 // 1. FEC encoding 646 if s.fecEncoder != nil { 647 ecc = s.fecEncoder.encode(buf) 648 } 649 650 // 2&3. crc32 & encryption 651 // if s.block != nil { 652 // s.nonce.Fill(buf[:nonceSize]) 653 // checksum := crc32.ChecksumIEEE(buf[cryptHeaderSize:]) 654 // binary.LittleEndian.PutUint32(buf[nonceSize:], checksum) 655 // s.block.Encrypt(buf, buf) 656 657 // for k := range ecc { 658 // s.nonce.Fill(ecc[k][:nonceSize]) 659 // checksum := crc32.ChecksumIEEE(ecc[k][cryptHeaderSize:]) 660 // binary.LittleEndian.PutUint32(ecc[k][nonceSize:], checksum) 661 // s.block.Encrypt(ecc[k], ecc[k]) 662 // } 663 // } 664 665 // 4. TxQueue 666 var msg ipv4.Message 667 for i := 0; i < s.dup+1; i++ { 668 bts := xmitBuf.Get().([]byte)[:len(buf)] 669 copy(bts, buf) 670 msg.Buffers = [][]byte{bts} 671 msg.Addr = s.remote 672 s.txqueue = append(s.txqueue, msg) 673 } 674 675 fecRate := 0.0 676 if s.lossReporter != nil { 677 //log.Println("LOSS REPORTER", s.lossReporter.UnderlyingLoss(s.remote)) 678 loss := s.lossReporter.UnderlyingLoss(s.remote) 679 fecRate = loss2fecfrac(loss) 680 if fecRate != s.lastFecRate { 681 if doLogging { 682 log.Println("fec rate is", fecRate, loss) 683 } 684 s.lastFecRate = fecRate 685 } 686 } 687 for k := range ecc { 688 if float64(k) <= float64(len(ecc))*fecRate { 689 bts := xmitBuf.Get().([]byte)[:len(ecc[k])] 690 copy(bts, ecc[k]) 691 s.fecbuffer = append(s.fecbuffer, bts) 692 } 693 } 694 if len(s.fecbuffer) > 0 { 695 bts := s.fecbuffer[0] 696 s.fecbuffer = s.fecbuffer[1:] 697 msg.Buffers = [][]byte{bts} 698 msg.Addr = s.remote 699 s.txqueue = append(s.txqueue, msg) 700 } 701 } 702 703 // kcp update, returns interval for next calling 704 func (s *UDPSession) update() (interval time.Duration) { 705 s.mu.Lock() 706 waitsnd := s.kcp.WaitSnd() 707 cwnd := s.kcp.cwnd 708 interval = time.Duration(s.kcp.flush(false)) * time.Millisecond 709 if s.kcp.WaitSnd() < waitsnd || s.kcp.cwnd != cwnd { 710 s.notifyWriteEvent() 711 s.notifyReadEvent() 712 } 713 s.uncork() 714 if s.kcp.quiescent <= 0 || s.kcp.isDead { 715 interval = 0 716 } 717 if s.kcp.isDead { 718 go s.Close() 719 } 720 s.mu.Unlock() 721 return 722 } 723 724 // GetConv gets conversation id of a session 725 func (s *UDPSession) GetConv() uint32 { return s.kcp.conv } 726 727 func (s *UDPSession) notifyReadEvent() { 728 select { 729 case s.chReadEvent <- struct{}{}: 730 default: 731 } 732 } 733 734 func (s *UDPSession) notifyWriteEvent() { 735 select { 736 case s.chWriteEvent <- struct{}{}: 737 default: 738 } 739 } 740 741 func (s *UDPSession) notifyReadError(err error) { 742 s.socketReadErrorOnce.Do(func() { 743 s.socketReadError.Store(err) 744 close(s.chSocketReadError) 745 }) 746 } 747 748 func (s *UDPSession) notifyWriteError(err error) { 749 s.socketWriteErrorOnce.Do(func() { 750 s.socketWriteError.Store(err) 751 close(s.chSocketWriteError) 752 }) 753 } 754 755 // packet input stage 756 func (s *UDPSession) packetInput(data []byte) { 757 s.kcpInput(data) 758 } 759 760 func (s *UDPSession) kcpInput(data []byte) { 761 defer s.updater.addSessionIfNotExists(s) 762 var kcpInErrors, fecErrs, fecRecovered, fecParityShards uint64 763 if s.fecDecoder != nil { 764 if len(data) > fecHeaderSize { // must be larger than fec header size 765 f := fecPacket(data) 766 if f.flag() == typeData || f.flag() == typeParity { // header check 767 if f.flag() == typeParity { 768 fecParityShards++ 769 } 770 771 s.mu.Lock() 772 recovers := s.fecDecoder.decode(f) 773 waitsnd := s.kcp.WaitSnd() 774 if f.flag() == typeData { 775 if ret := s.kcp.Input(data[fecHeaderSizePlus2:], true, s.ackNoDelay); ret != 0 { 776 kcpInErrors++ 777 } 778 } 779 780 for _, r := range recovers { 781 if len(r) >= 2 { // must be larger than 2bytes 782 sz := binary.LittleEndian.Uint16(r) 783 if int(sz) <= len(r) && sz >= 2 { 784 if ret := s.kcp.Input(r[2:sz], false, s.ackNoDelay); ret == 0 { 785 fecRecovered++ 786 } else { 787 kcpInErrors++ 788 } 789 } else { 790 fecErrs++ 791 } 792 } else { 793 fecErrs++ 794 } 795 // recycle the recovers 796 xmitBuf.Put(r) 797 } 798 799 // to notify the readers to receive the data 800 if n := s.kcp.PeekSize(); n > 0 { 801 s.notifyReadEvent() 802 } 803 // to notify the writers when queue is shorter(e.g. ACKed) 804 if s.kcp.WaitSnd() < waitsnd || true { 805 s.notifyWriteEvent() 806 } 807 s.uncork() 808 s.mu.Unlock() 809 } else { 810 kcpInErrors++ 811 } 812 } else { 813 kcpInErrors++ 814 } 815 } else { 816 s.mu.Lock() 817 waitsnd := s.kcp.WaitSnd() 818 if ret := s.kcp.Input(data, true, s.ackNoDelay); ret != 0 { 819 kcpInErrors++ 820 } 821 if n := s.kcp.PeekSize(); n > 0 { 822 s.notifyReadEvent() 823 } 824 if s.kcp.WaitSnd() < waitsnd || true { 825 s.notifyWriteEvent() 826 } 827 s.uncork() 828 s.mu.Unlock() 829 } 830 831 atomic.AddUint64(&DefaultSnmp.InPkts, 1) 832 atomic.AddUint64(&DefaultSnmp.InBytes, uint64(len(data))) 833 if fecParityShards > 0 { 834 atomic.AddUint64(&DefaultSnmp.FECParityShards, fecParityShards) 835 } 836 if kcpInErrors > 0 && s.fecEncoder != nil { 837 log.Println(kcpInErrors, "bad packets, TURNING OFF FEC") 838 s.mu.Lock() 839 s.fecDecoder = nil 840 s.fecEncoder = nil 841 s.kcp.ReserveBytes(s.headerSize()) 842 s.mu.Unlock() 843 atomic.AddUint64(&DefaultSnmp.KCPInErrors, kcpInErrors) 844 } 845 if fecErrs > 0 { 846 atomic.AddUint64(&DefaultSnmp.FECErrs, fecErrs) 847 } 848 if fecRecovered > 0 { 849 atomic.AddUint64(&DefaultSnmp.FECRecovered, fecRecovered) 850 } 851 852 } 853 854 type ( 855 // Listener defines a server which will be waiting to accept incoming connections 856 Listener struct { 857 block BlockCrypt // block encryption 858 dataShards int // FEC data shard 859 parityShards int // FEC parity shard 860 fecDecoder *fecDecoder // FEC mock initialization 861 conn net.PacketConn // the underlying packet connection 862 863 sessions map[string]*UDPSession // all sessions accepted by this Listener 864 sessionLock sync.Mutex 865 chAccepts chan *UDPSession // Listen() backlog 866 chSessionClosed chan net.Addr // session close queue 867 headerSize int // the additional header to a KCP frame 868 869 die chan struct{} // notify the listener has closed 870 dieOnce sync.Once 871 872 // socket error handling 873 socketReadError atomic.Value 874 chSocketReadError chan struct{} 875 socketReadErrorOnce sync.Once 876 877 rd atomic.Value // read deadline for Accept() 878 } 879 ) 880 881 // packet input stage 882 func (l *Listener) packetInput(data []byte, addr net.Addr) { 883 dataValid := false 884 dataValid = true 885 886 if dataValid { 887 l.sessionLock.Lock() 888 s, ok := l.sessions[addr.String()] 889 l.sessionLock.Unlock() 890 891 if !ok { // new address:port 892 if len(l.chAccepts) < cap(l.chAccepts) { // do not let the new sessions overwhelm accept queue 893 var conv uint32 894 convValid := false 895 if l.fecDecoder != nil { 896 isfec := binary.LittleEndian.Uint16(data[4:]) 897 if isfec == typeData { 898 conv = binary.LittleEndian.Uint32(data[fecHeaderSizePlus2:]) 899 convValid = true 900 } 901 } else { 902 conv = binary.LittleEndian.Uint32(data) 903 convValid = true 904 } 905 if !convValid { 906 conv = 814 907 } 908 909 log.Println("conv is", conv) 910 911 s := newUDPSession(conv, l.dataShards, l.parityShards, l, l.conn, addr, l.block) 912 s.kcpInput(data) 913 l.sessionLock.Lock() 914 l.sessions[addr.String()] = s 915 l.sessionLock.Unlock() 916 l.chAccepts <- s 917 } 918 } else { 919 s.kcpInput(data) 920 } 921 } 922 } 923 924 func (l *Listener) notifyReadError(err error) { 925 l.socketReadErrorOnce.Do(func() { 926 l.socketReadError.Store(err) 927 close(l.chSocketReadError) 928 929 // propagate read error to all sessions 930 l.sessionLock.Lock() 931 for _, s := range l.sessions { 932 s.notifyReadError(err) 933 } 934 l.sessionLock.Unlock() 935 }) 936 } 937 938 // SetReadBuffer sets the socket read buffer for the Listener 939 func (l *Listener) SetReadBuffer(bytes int) error { 940 if nc, ok := l.conn.(setReadBuffer); ok { 941 return nc.SetReadBuffer(bytes) 942 } 943 return errInvalidOperation 944 } 945 946 // SetWriteBuffer sets the socket write buffer for the Listener 947 func (l *Listener) SetWriteBuffer(bytes int) error { 948 if nc, ok := l.conn.(setWriteBuffer); ok { 949 return nc.SetWriteBuffer(bytes) 950 } 951 return errInvalidOperation 952 } 953 954 // SetDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header. 955 // 956 // if the underlying connection has implemented `func SetDSCP(int) error`, SetDSCP() will invoke 957 // this function instead. 958 func (l *Listener) SetDSCP(dscp int) error { 959 // interface enabled 960 if ts, ok := l.conn.(setDSCP); ok { 961 return ts.SetDSCP(dscp) 962 } 963 964 if nc, ok := l.conn.(net.Conn); ok { 965 var succeed bool 966 if err := ipv4.NewConn(nc).SetTOS(dscp << 2); err == nil { 967 succeed = true 968 } 969 if err := ipv6.NewConn(nc).SetTrafficClass(dscp); err == nil { 970 succeed = true 971 } 972 973 if succeed { 974 return nil 975 } 976 } 977 return errInvalidOperation 978 } 979 980 // Accept implements the Accept method in the Listener interface; it waits for the next call and returns a generic Conn. 981 func (l *Listener) Accept() (net.Conn, error) { 982 return l.AcceptKCP() 983 } 984 985 // AcceptKCP accepts a KCP connection 986 func (l *Listener) AcceptKCP() (*UDPSession, error) { 987 var timeout <-chan time.Time 988 if tdeadline, ok := l.rd.Load().(time.Time); ok && !tdeadline.IsZero() { 989 timeout = time.After(tdeadline.Sub(time.Now())) 990 } 991 992 select { 993 case <-timeout: 994 return nil, errors.WithStack(errTimeout) 995 case c := <-l.chAccepts: 996 return c, nil 997 case <-l.chSocketReadError: 998 return nil, l.socketReadError.Load().(error) 999 case <-l.die: 1000 return nil, errors.WithStack(io.ErrClosedPipe) 1001 } 1002 } 1003 1004 // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. 1005 func (l *Listener) SetDeadline(t time.Time) error { 1006 l.SetReadDeadline(t) 1007 l.SetWriteDeadline(t) 1008 return nil 1009 } 1010 1011 // SetReadDeadline implements the Conn SetReadDeadline method. 1012 func (l *Listener) SetReadDeadline(t time.Time) error { 1013 l.rd.Store(t) 1014 return nil 1015 } 1016 1017 // SetWriteDeadline implements the Conn SetWriteDeadline method. 1018 func (l *Listener) SetWriteDeadline(t time.Time) error { return errInvalidOperation } 1019 1020 // Close stops listening on the UDP address, and closes the socket 1021 func (l *Listener) Close() error { 1022 var once bool 1023 l.dieOnce.Do(func() { 1024 close(l.die) 1025 once = true 1026 }) 1027 1028 if once { 1029 return l.conn.Close() 1030 } else { 1031 return errors.WithStack(io.ErrClosedPipe) 1032 } 1033 } 1034 1035 // closeSession notify the listener that a session has closed 1036 func (l *Listener) closeSession(remote net.Addr) (ret bool) { 1037 l.sessionLock.Lock() 1038 defer l.sessionLock.Unlock() 1039 if _, ok := l.sessions[remote.String()]; ok { 1040 delete(l.sessions, remote.String()) 1041 return true 1042 } 1043 return false 1044 } 1045 1046 // Addr returns the listener's network address, The Addr returned is shared by all invocations of Addr, so do not modify it. 1047 func (l *Listener) Addr() net.Addr { return l.conn.LocalAddr() } 1048 1049 // Listen listens for incoming KCP packets addressed to the local address laddr on the network "udp", 1050 func Listen(laddr string) (net.Listener, error) { return ListenWithOptions(laddr, nil, 0, 0) } 1051 1052 // ListenWithOptions listens for incoming KCP packets addressed to the local address laddr on the network "udp" with packet encryption. 1053 // 1054 // 'block' is the block encryption algorithm to encrypt packets. 1055 // 1056 // 'dataShards', 'parityShards' specifiy how many parity packets will be generated following the data packets. 1057 // 1058 // Check https://github.com/klauspost/reedsolomon for details 1059 func ListenWithOptions(laddr string, block BlockCrypt, dataShards, parityShards int) (*Listener, error) { 1060 udpaddr, err := net.ResolveUDPAddr("udp", laddr) 1061 if err != nil { 1062 return nil, errors.WithStack(err) 1063 } 1064 conn, err := net.ListenUDP("udp", udpaddr) 1065 if err != nil { 1066 return nil, errors.WithStack(err) 1067 } 1068 1069 return ServeConn(block, dataShards, parityShards, conn) 1070 } 1071 1072 // ServeConn serves KCP protocol for a single packet connection. 1073 func ServeConn(block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*Listener, error) { 1074 l := new(Listener) 1075 l.conn = conn 1076 l.sessions = make(map[string]*UDPSession) 1077 l.chAccepts = make(chan *UDPSession, acceptBacklog) 1078 l.chSessionClosed = make(chan net.Addr) 1079 l.die = make(chan struct{}) 1080 l.dataShards = dataShards 1081 l.parityShards = parityShards 1082 l.block = block 1083 l.fecDecoder = newFECDecoder(rxFECMulti*(dataShards+parityShards), dataShards, parityShards) 1084 l.chSocketReadError = make(chan struct{}) 1085 1086 // calculate header size 1087 if l.block != nil { 1088 l.headerSize += cryptHeaderSize 1089 } 1090 if l.fecDecoder != nil { 1091 l.headerSize += fecHeaderSizePlus2 1092 } 1093 1094 go l.monitor() 1095 return l, nil 1096 } 1097 1098 // Dial connects to the remote address "raddr" on the network "udp" without encryption and FEC 1099 func Dial(raddr string) (net.Conn, error) { return DialWithOptions(raddr, nil, 0, 0) } 1100 1101 // DialWithOptions connects to the remote address "raddr" on the network "udp" with packet encryption 1102 // 1103 // 'block' is the block encryption algorithm to encrypt packets. 1104 // 1105 // 'dataShards', 'parityShards' specifiy how many parity packets will be generated following the data packets. 1106 // 1107 // Check https://github.com/klauspost/reedsolomon for details 1108 func DialWithOptions(raddr string, block BlockCrypt, dataShards, parityShards int) (*UDPSession, error) { 1109 // network type detection 1110 udpaddr, err := net.ResolveUDPAddr("udp", raddr) 1111 if err != nil { 1112 return nil, errors.WithStack(err) 1113 } 1114 network := "udp4" 1115 if udpaddr.IP.To4() == nil { 1116 network = "udp" 1117 } 1118 1119 conn, err := net.ListenUDP(network, nil) 1120 if err != nil { 1121 return nil, errors.WithStack(err) 1122 } 1123 1124 return NewConn(raddr, block, dataShards, parityShards, conn) 1125 } 1126 1127 // NewConn3 establishes a session and talks KCP protocol over a packet connection. 1128 func NewConn3(convid uint32, raddr net.Addr, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) { 1129 return newUDPSession(convid, dataShards, parityShards, nil, conn, raddr, block), nil 1130 } 1131 1132 // NewConn2 establishes a session and talks KCP protocol over a packet connection. 1133 func NewConn2(raddr net.Addr, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) { 1134 var convid uint32 1135 binary.Read(rand.Reader, binary.LittleEndian, &convid) 1136 return NewConn3(convid, raddr, block, dataShards, parityShards, conn) 1137 } 1138 1139 // NewConn establishes a session and talks KCP protocol over a packet connection. 1140 func NewConn(raddr string, block BlockCrypt, dataShards, parityShards int, conn net.PacketConn) (*UDPSession, error) { 1141 udpaddr, err := net.ResolveUDPAddr("udp", raddr) 1142 if err != nil { 1143 return nil, errors.WithStack(err) 1144 } 1145 return NewConn2(udpaddr, block, dataShards, parityShards, conn) 1146 }