github.com/xraypb/Xray-core@v1.8.1/transport/internet/kcp/connection.go (about) 1 package kcp 2 3 import ( 4 "bytes" 5 "io" 6 "net" 7 "runtime" 8 "sync" 9 "sync/atomic" 10 "time" 11 12 "github.com/xraypb/Xray-core/common/buf" 13 "github.com/xraypb/Xray-core/common/signal" 14 "github.com/xraypb/Xray-core/common/signal/semaphore" 15 ) 16 17 var ( 18 ErrIOTimeout = newError("Read/Write timeout") 19 ErrClosedListener = newError("Listener closed.") 20 ErrClosedConnection = newError("Connection closed.") 21 ) 22 23 // State of the connection 24 type State int32 25 26 // Is returns true if current State is one of the candidates. 27 func (s State) Is(states ...State) bool { 28 for _, state := range states { 29 if s == state { 30 return true 31 } 32 } 33 return false 34 } 35 36 const ( 37 StateActive State = 0 // Connection is active 38 StateReadyToClose State = 1 // Connection is closed locally 39 StatePeerClosed State = 2 // Connection is closed on remote 40 StateTerminating State = 3 // Connection is ready to be destroyed locally 41 StatePeerTerminating State = 4 // Connection is ready to be destroyed on remote 42 StateTerminated State = 5 // Connection is destroyed. 43 ) 44 45 func nowMillisec() int64 { 46 now := time.Now() 47 return now.Unix()*1000 + int64(now.Nanosecond()/1000000) 48 } 49 50 type RoundTripInfo struct { 51 sync.RWMutex 52 variation uint32 53 srtt uint32 54 rto uint32 55 minRtt uint32 56 updatedTimestamp uint32 57 } 58 59 func (info *RoundTripInfo) UpdatePeerRTO(rto uint32, current uint32) { 60 info.Lock() 61 defer info.Unlock() 62 63 if current-info.updatedTimestamp < 3000 { 64 return 65 } 66 67 info.updatedTimestamp = current 68 info.rto = rto 69 } 70 71 func (info *RoundTripInfo) Update(rtt uint32, current uint32) { 72 if rtt > 0x7FFFFFFF { 73 return 74 } 75 info.Lock() 76 defer info.Unlock() 77 78 // https://tools.ietf.org/html/rfc6298 79 if info.srtt == 0 { 80 info.srtt = rtt 81 info.variation = rtt / 2 82 } else { 83 delta := rtt - info.srtt 84 if info.srtt > rtt { 85 delta = info.srtt - rtt 86 } 87 info.variation = (3*info.variation + delta) / 4 88 info.srtt = (7*info.srtt + rtt) / 8 89 if info.srtt < info.minRtt { 90 info.srtt = info.minRtt 91 } 92 } 93 var rto uint32 94 if info.minRtt < 4*info.variation { 95 rto = info.srtt + 4*info.variation 96 } else { 97 rto = info.srtt + info.variation 98 } 99 100 if rto > 10000 { 101 rto = 10000 102 } 103 info.rto = rto * 5 / 4 104 info.updatedTimestamp = current 105 } 106 107 func (info *RoundTripInfo) Timeout() uint32 { 108 info.RLock() 109 defer info.RUnlock() 110 111 return info.rto 112 } 113 114 func (info *RoundTripInfo) SmoothedTime() uint32 { 115 info.RLock() 116 defer info.RUnlock() 117 118 return info.srtt 119 } 120 121 type Updater struct { 122 interval int64 123 shouldContinue func() bool 124 shouldTerminate func() bool 125 updateFunc func() 126 notifier *semaphore.Instance 127 } 128 129 func NewUpdater(interval uint32, shouldContinue func() bool, shouldTerminate func() bool, updateFunc func()) *Updater { 130 u := &Updater{ 131 interval: int64(time.Duration(interval) * time.Millisecond), 132 shouldContinue: shouldContinue, 133 shouldTerminate: shouldTerminate, 134 updateFunc: updateFunc, 135 notifier: semaphore.New(1), 136 } 137 return u 138 } 139 140 func (u *Updater) WakeUp() { 141 select { 142 case <-u.notifier.Wait(): 143 go u.run() 144 default: 145 } 146 } 147 148 func (u *Updater) run() { 149 defer u.notifier.Signal() 150 151 if u.shouldTerminate() { 152 return 153 } 154 ticker := time.NewTicker(u.Interval()) 155 for u.shouldContinue() { 156 u.updateFunc() 157 <-ticker.C 158 } 159 ticker.Stop() 160 } 161 162 func (u *Updater) Interval() time.Duration { 163 return time.Duration(atomic.LoadInt64(&u.interval)) 164 } 165 166 func (u *Updater) SetInterval(d time.Duration) { 167 atomic.StoreInt64(&u.interval, int64(d)) 168 } 169 170 type ConnMetadata struct { 171 LocalAddr net.Addr 172 RemoteAddr net.Addr 173 Conversation uint16 174 } 175 176 // Connection is a KCP connection over UDP. 177 type Connection struct { 178 meta ConnMetadata 179 closer io.Closer 180 rd time.Time 181 wd time.Time // write deadline 182 since int64 183 dataInput *signal.Notifier 184 dataOutput *signal.Notifier 185 Config *Config 186 187 state State 188 stateBeginTime uint32 189 lastIncomingTime uint32 190 lastPingTime uint32 191 192 mss uint32 193 roundTrip *RoundTripInfo 194 195 receivingWorker *ReceivingWorker 196 sendingWorker *SendingWorker 197 198 output SegmentWriter 199 200 dataUpdater *Updater 201 pingUpdater *Updater 202 } 203 204 // NewConnection create a new KCP connection between local and remote. 205 func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, config *Config) *Connection { 206 newError("#", meta.Conversation, " creating connection to ", meta.RemoteAddr).WriteToLog() 207 208 conn := &Connection{ 209 meta: meta, 210 closer: closer, 211 since: nowMillisec(), 212 dataInput: signal.NewNotifier(), 213 dataOutput: signal.NewNotifier(), 214 Config: config, 215 output: NewRetryableWriter(NewSegmentWriter(writer)), 216 mss: config.GetMTUValue() - uint32(writer.Overhead()) - DataSegmentOverhead, 217 roundTrip: &RoundTripInfo{ 218 rto: 100, 219 minRtt: config.GetTTIValue(), 220 }, 221 } 222 223 conn.receivingWorker = NewReceivingWorker(conn) 224 conn.sendingWorker = NewSendingWorker(conn) 225 226 isTerminating := func() bool { 227 return conn.State().Is(StateTerminating, StateTerminated) 228 } 229 isTerminated := func() bool { 230 return conn.State() == StateTerminated 231 } 232 conn.dataUpdater = NewUpdater( 233 config.GetTTIValue(), 234 func() bool { 235 return !isTerminating() && (conn.sendingWorker.UpdateNecessary() || conn.receivingWorker.UpdateNecessary()) 236 }, 237 isTerminating, 238 conn.updateTask) 239 conn.pingUpdater = NewUpdater( 240 5000, // 5 seconds 241 func() bool { return !isTerminated() }, 242 isTerminated, 243 conn.updateTask) 244 conn.pingUpdater.WakeUp() 245 246 return conn 247 } 248 249 func (c *Connection) Elapsed() uint32 { 250 return uint32(nowMillisec() - c.since) 251 } 252 253 // ReadMultiBuffer implements buf.Reader. 254 func (c *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) { 255 if c == nil { 256 return nil, io.EOF 257 } 258 259 for { 260 if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) { 261 return nil, io.EOF 262 } 263 mb := c.receivingWorker.ReadMultiBuffer() 264 if !mb.IsEmpty() { 265 c.dataUpdater.WakeUp() 266 return mb, nil 267 } 268 269 if c.State() == StatePeerTerminating { 270 return nil, io.EOF 271 } 272 273 if err := c.waitForDataInput(); err != nil { 274 return nil, err 275 } 276 } 277 } 278 279 func (c *Connection) waitForDataInput() error { 280 for i := 0; i < 16; i++ { 281 select { 282 case <-c.dataInput.Wait(): 283 return nil 284 default: 285 runtime.Gosched() 286 } 287 } 288 289 duration := time.Second * 16 290 if !c.rd.IsZero() { 291 duration = time.Until(c.rd) 292 if duration < 0 { 293 return ErrIOTimeout 294 } 295 } 296 297 timeout := time.NewTimer(duration) 298 defer timeout.Stop() 299 300 select { 301 case <-c.dataInput.Wait(): 302 case <-timeout.C: 303 if !c.rd.IsZero() && c.rd.Before(time.Now()) { 304 return ErrIOTimeout 305 } 306 } 307 308 return nil 309 } 310 311 // Read implements the Conn Read method. 312 func (c *Connection) Read(b []byte) (int, error) { 313 if c == nil { 314 return 0, io.EOF 315 } 316 317 for { 318 if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) { 319 return 0, io.EOF 320 } 321 nBytes := c.receivingWorker.Read(b) 322 if nBytes > 0 { 323 c.dataUpdater.WakeUp() 324 return nBytes, nil 325 } 326 327 if err := c.waitForDataInput(); err != nil { 328 return 0, err 329 } 330 } 331 } 332 333 func (c *Connection) waitForDataOutput() error { 334 for i := 0; i < 16; i++ { 335 select { 336 case <-c.dataOutput.Wait(): 337 return nil 338 default: 339 runtime.Gosched() 340 } 341 } 342 343 duration := time.Second * 16 344 if !c.wd.IsZero() { 345 duration = time.Until(c.wd) 346 if duration < 0 { 347 return ErrIOTimeout 348 } 349 } 350 351 timeout := time.NewTimer(duration) 352 defer timeout.Stop() 353 354 select { 355 case <-c.dataOutput.Wait(): 356 case <-timeout.C: 357 if !c.wd.IsZero() && c.wd.Before(time.Now()) { 358 return ErrIOTimeout 359 } 360 } 361 362 return nil 363 } 364 365 // Write implements io.Writer. 366 func (c *Connection) Write(b []byte) (int, error) { 367 reader := bytes.NewReader(b) 368 if err := c.writeMultiBufferInternal(reader); err != nil { 369 return 0, err 370 } 371 return len(b), nil 372 } 373 374 // WriteMultiBuffer implements buf.Writer. 375 func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error { 376 reader := &buf.MultiBufferContainer{ 377 MultiBuffer: mb, 378 } 379 defer reader.Close() 380 381 return c.writeMultiBufferInternal(reader) 382 } 383 384 func (c *Connection) writeMultiBufferInternal(reader io.Reader) error { 385 updatePending := false 386 defer func() { 387 if updatePending { 388 c.dataUpdater.WakeUp() 389 } 390 }() 391 392 var b *buf.Buffer 393 defer b.Release() 394 395 for { 396 for { 397 if c == nil || c.State() != StateActive { 398 return io.ErrClosedPipe 399 } 400 401 if b == nil { 402 b = buf.New() 403 _, err := b.ReadFrom(io.LimitReader(reader, int64(c.mss))) 404 if err != nil { 405 return nil 406 } 407 } 408 409 if !c.sendingWorker.Push(b) { 410 break 411 } 412 updatePending = true 413 b = nil 414 } 415 416 if updatePending { 417 c.dataUpdater.WakeUp() 418 updatePending = false 419 } 420 421 if err := c.waitForDataOutput(); err != nil { 422 return err 423 } 424 } 425 } 426 427 func (c *Connection) SetState(state State) { 428 current := c.Elapsed() 429 atomic.StoreInt32((*int32)(&c.state), int32(state)) 430 atomic.StoreUint32(&c.stateBeginTime, current) 431 newError("#", c.meta.Conversation, " entering state ", state, " at ", current).AtDebug().WriteToLog() 432 433 switch state { 434 case StateReadyToClose: 435 c.receivingWorker.CloseRead() 436 case StatePeerClosed: 437 c.sendingWorker.CloseWrite() 438 case StateTerminating: 439 c.receivingWorker.CloseRead() 440 c.sendingWorker.CloseWrite() 441 c.pingUpdater.SetInterval(time.Second) 442 case StatePeerTerminating: 443 c.sendingWorker.CloseWrite() 444 c.pingUpdater.SetInterval(time.Second) 445 case StateTerminated: 446 c.receivingWorker.CloseRead() 447 c.sendingWorker.CloseWrite() 448 c.pingUpdater.SetInterval(time.Second) 449 c.dataUpdater.WakeUp() 450 c.pingUpdater.WakeUp() 451 go c.Terminate() 452 } 453 } 454 455 // Close closes the connection. 456 func (c *Connection) Close() error { 457 if c == nil { 458 return ErrClosedConnection 459 } 460 461 c.dataInput.Signal() 462 c.dataOutput.Signal() 463 464 switch c.State() { 465 case StateReadyToClose, StateTerminating, StateTerminated: 466 return ErrClosedConnection 467 case StateActive: 468 c.SetState(StateReadyToClose) 469 case StatePeerClosed: 470 c.SetState(StateTerminating) 471 case StatePeerTerminating: 472 c.SetState(StateTerminated) 473 } 474 475 newError("#", c.meta.Conversation, " closing connection to ", c.meta.RemoteAddr).WriteToLog() 476 477 return nil 478 } 479 480 // LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it. 481 func (c *Connection) LocalAddr() net.Addr { 482 if c == nil { 483 return nil 484 } 485 return c.meta.LocalAddr 486 } 487 488 // RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it. 489 func (c *Connection) RemoteAddr() net.Addr { 490 if c == nil { 491 return nil 492 } 493 return c.meta.RemoteAddr 494 } 495 496 // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline. 497 func (c *Connection) SetDeadline(t time.Time) error { 498 if err := c.SetReadDeadline(t); err != nil { 499 return err 500 } 501 return c.SetWriteDeadline(t) 502 } 503 504 // SetReadDeadline implements the Conn SetReadDeadline method. 505 func (c *Connection) SetReadDeadline(t time.Time) error { 506 if c == nil || c.State() != StateActive { 507 return ErrClosedConnection 508 } 509 c.rd = t 510 return nil 511 } 512 513 // SetWriteDeadline implements the Conn SetWriteDeadline method. 514 func (c *Connection) SetWriteDeadline(t time.Time) error { 515 if c == nil || c.State() != StateActive { 516 return ErrClosedConnection 517 } 518 c.wd = t 519 return nil 520 } 521 522 // kcp update, input loop 523 func (c *Connection) updateTask() { 524 c.flush() 525 } 526 527 func (c *Connection) Terminate() { 528 if c == nil { 529 return 530 } 531 newError("#", c.meta.Conversation, " terminating connection to ", c.RemoteAddr()).WriteToLog() 532 533 // v.SetState(StateTerminated) 534 c.dataInput.Signal() 535 c.dataOutput.Signal() 536 537 c.closer.Close() 538 c.sendingWorker.Release() 539 c.receivingWorker.Release() 540 } 541 542 func (c *Connection) HandleOption(opt SegmentOption) { 543 if (opt & SegmentOptionClose) == SegmentOptionClose { 544 c.OnPeerClosed() 545 } 546 } 547 548 func (c *Connection) OnPeerClosed() { 549 switch c.State() { 550 case StateReadyToClose: 551 c.SetState(StateTerminating) 552 case StateActive: 553 c.SetState(StatePeerClosed) 554 } 555 } 556 557 // Input when you received a low level packet (eg. UDP packet), call it 558 func (c *Connection) Input(segments []Segment) { 559 current := c.Elapsed() 560 atomic.StoreUint32(&c.lastIncomingTime, current) 561 562 for _, seg := range segments { 563 if seg.Conversation() != c.meta.Conversation { 564 break 565 } 566 567 switch seg := seg.(type) { 568 case *DataSegment: 569 c.HandleOption(seg.Option) 570 c.receivingWorker.ProcessSegment(seg) 571 if c.receivingWorker.IsDataAvailable() { 572 c.dataInput.Signal() 573 } 574 c.dataUpdater.WakeUp() 575 case *AckSegment: 576 c.HandleOption(seg.Option) 577 c.sendingWorker.ProcessSegment(current, seg, c.roundTrip.Timeout()) 578 c.dataOutput.Signal() 579 c.dataUpdater.WakeUp() 580 case *CmdOnlySegment: 581 c.HandleOption(seg.Option) 582 if seg.Command() == CommandTerminate { 583 switch c.State() { 584 case StateActive, StatePeerClosed: 585 c.SetState(StatePeerTerminating) 586 case StateReadyToClose: 587 c.SetState(StateTerminating) 588 case StateTerminating: 589 c.SetState(StateTerminated) 590 } 591 } 592 if seg.Option == SegmentOptionClose || seg.Command() == CommandTerminate { 593 c.dataInput.Signal() 594 c.dataOutput.Signal() 595 } 596 c.sendingWorker.ProcessReceivingNext(seg.ReceivingNext) 597 c.receivingWorker.ProcessSendingNext(seg.SendingNext) 598 c.roundTrip.UpdatePeerRTO(seg.PeerRTO, current) 599 seg.Release() 600 default: 601 } 602 } 603 } 604 605 func (c *Connection) flush() { 606 current := c.Elapsed() 607 608 if c.State() == StateTerminated { 609 return 610 } 611 if c.State() == StateActive && current-atomic.LoadUint32(&c.lastIncomingTime) >= 30000 { 612 c.Close() 613 } 614 if c.State() == StateReadyToClose && c.sendingWorker.IsEmpty() { 615 c.SetState(StateTerminating) 616 } 617 618 if c.State() == StateTerminating { 619 newError("#", c.meta.Conversation, " sending terminating cmd.").AtDebug().WriteToLog() 620 c.Ping(current, CommandTerminate) 621 622 if current-atomic.LoadUint32(&c.stateBeginTime) > 8000 { 623 c.SetState(StateTerminated) 624 } 625 return 626 } 627 if c.State() == StatePeerTerminating && current-atomic.LoadUint32(&c.stateBeginTime) > 4000 { 628 c.SetState(StateTerminating) 629 } 630 631 if c.State() == StateReadyToClose && current-atomic.LoadUint32(&c.stateBeginTime) > 15000 { 632 c.SetState(StateTerminating) 633 } 634 635 // flush acknowledges 636 c.receivingWorker.Flush(current) 637 c.sendingWorker.Flush(current) 638 639 if current-atomic.LoadUint32(&c.lastPingTime) >= 3000 { 640 c.Ping(current, CommandPing) 641 } 642 } 643 644 func (c *Connection) State() State { 645 return State(atomic.LoadInt32((*int32)(&c.state))) 646 } 647 648 func (c *Connection) Ping(current uint32, cmd Command) { 649 seg := NewCmdOnlySegment() 650 seg.Conv = c.meta.Conversation 651 seg.Cmd = cmd 652 seg.ReceivingNext = c.receivingWorker.NextNumber() 653 seg.SendingNext = c.sendingWorker.FirstUnacknowledged() 654 seg.PeerRTO = c.roundTrip.Timeout() 655 if c.State() == StateReadyToClose { 656 seg.Option = SegmentOptionClose 657 } 658 c.output.Write(seg) 659 atomic.StoreUint32(&c.lastPingTime, current) 660 seg.Release() 661 }