github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/transport/internet/kcp/connection.go (about)

     1  package kcp
     2  
     3  import (
     4  	"bytes"
     5  	"io"
     6  	"net"
     7  	"runtime"
     8  	"sync"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/xtls/xray-core/common/buf"
    13  	"github.com/xtls/xray-core/common/signal"
    14  	"github.com/xtls/xray-core/common/signal/semaphore"
    15  )
    16  
    17  var (
    18  	ErrIOTimeout        = newError("Read/Write timeout")
    19  	ErrClosedListener   = newError("Listener closed.")
    20  	ErrClosedConnection = newError("Connection closed.")
    21  )
    22  
    23  // State of the connection
    24  type State int32
    25  
    26  // Is returns true if current State is one of the candidates.
    27  func (s State) Is(states ...State) bool {
    28  	for _, state := range states {
    29  		if s == state {
    30  			return true
    31  		}
    32  	}
    33  	return false
    34  }
    35  
    36  const (
    37  	StateActive          State = 0 // Connection is active
    38  	StateReadyToClose    State = 1 // Connection is closed locally
    39  	StatePeerClosed      State = 2 // Connection is closed on remote
    40  	StateTerminating     State = 3 // Connection is ready to be destroyed locally
    41  	StatePeerTerminating State = 4 // Connection is ready to be destroyed on remote
    42  	StateTerminated      State = 5 // Connection is destroyed.
    43  )
    44  
    45  func nowMillisec() int64 {
    46  	now := time.Now()
    47  	return now.Unix()*1000 + int64(now.Nanosecond()/1000000)
    48  }
    49  
    50  type RoundTripInfo struct {
    51  	sync.RWMutex
    52  	variation        uint32
    53  	srtt             uint32
    54  	rto              uint32
    55  	minRtt           uint32
    56  	updatedTimestamp uint32
    57  }
    58  
    59  func (info *RoundTripInfo) UpdatePeerRTO(rto uint32, current uint32) {
    60  	info.Lock()
    61  	defer info.Unlock()
    62  
    63  	if current-info.updatedTimestamp < 3000 {
    64  		return
    65  	}
    66  
    67  	info.updatedTimestamp = current
    68  	info.rto = rto
    69  }
    70  
    71  func (info *RoundTripInfo) Update(rtt uint32, current uint32) {
    72  	if rtt > 0x7FFFFFFF {
    73  		return
    74  	}
    75  	info.Lock()
    76  	defer info.Unlock()
    77  
    78  	// https://tools.ietf.org/html/rfc6298
    79  	if info.srtt == 0 {
    80  		info.srtt = rtt
    81  		info.variation = rtt / 2
    82  	} else {
    83  		delta := rtt - info.srtt
    84  		if info.srtt > rtt {
    85  			delta = info.srtt - rtt
    86  		}
    87  		info.variation = (3*info.variation + delta) / 4
    88  		info.srtt = (7*info.srtt + rtt) / 8
    89  		if info.srtt < info.minRtt {
    90  			info.srtt = info.minRtt
    91  		}
    92  	}
    93  	var rto uint32
    94  	if info.minRtt < 4*info.variation {
    95  		rto = info.srtt + 4*info.variation
    96  	} else {
    97  		rto = info.srtt + info.variation
    98  	}
    99  
   100  	if rto > 10000 {
   101  		rto = 10000
   102  	}
   103  	info.rto = rto * 5 / 4
   104  	info.updatedTimestamp = current
   105  }
   106  
   107  func (info *RoundTripInfo) Timeout() uint32 {
   108  	info.RLock()
   109  	defer info.RUnlock()
   110  
   111  	return info.rto
   112  }
   113  
   114  func (info *RoundTripInfo) SmoothedTime() uint32 {
   115  	info.RLock()
   116  	defer info.RUnlock()
   117  
   118  	return info.srtt
   119  }
   120  
   121  type Updater struct {
   122  	interval        int64
   123  	shouldContinue  func() bool
   124  	shouldTerminate func() bool
   125  	updateFunc      func()
   126  	notifier        *semaphore.Instance
   127  }
   128  
   129  func NewUpdater(interval uint32, shouldContinue func() bool, shouldTerminate func() bool, updateFunc func()) *Updater {
   130  	u := &Updater{
   131  		interval:        int64(time.Duration(interval) * time.Millisecond),
   132  		shouldContinue:  shouldContinue,
   133  		shouldTerminate: shouldTerminate,
   134  		updateFunc:      updateFunc,
   135  		notifier:        semaphore.New(1),
   136  	}
   137  	return u
   138  }
   139  
   140  func (u *Updater) WakeUp() {
   141  	select {
   142  	case <-u.notifier.Wait():
   143  		go u.run()
   144  	default:
   145  	}
   146  }
   147  
   148  func (u *Updater) run() {
   149  	defer u.notifier.Signal()
   150  
   151  	if u.shouldTerminate() {
   152  		return
   153  	}
   154  	ticker := time.NewTicker(u.Interval())
   155  	for u.shouldContinue() {
   156  		u.updateFunc()
   157  		<-ticker.C
   158  	}
   159  	ticker.Stop()
   160  }
   161  
   162  func (u *Updater) Interval() time.Duration {
   163  	return time.Duration(atomic.LoadInt64(&u.interval))
   164  }
   165  
   166  func (u *Updater) SetInterval(d time.Duration) {
   167  	atomic.StoreInt64(&u.interval, int64(d))
   168  }
   169  
// ConnMetadata holds the addressing information of a KCP connection.
type ConnMetadata struct {
	// LocalAddr is returned by Connection.LocalAddr.
	LocalAddr    net.Addr
	// RemoteAddr is returned by Connection.RemoteAddr.
	RemoteAddr   net.Addr
	// Conversation is the KCP conversation ID; incoming segments carrying a
	// different ID are not processed by the connection.
	Conversation uint16
}
   175  
// Connection is a KCP connection over UDP.
type Connection struct {
	meta ConnMetadata
	// closer is closed when the connection is terminated.
	closer io.Closer
	// rd is the read deadline; the zero value means no deadline.
	// NOTE(review): rd and wd are accessed without synchronization.
	rd         time.Time
	wd         time.Time // write deadline
	// since is the creation time in Unix milliseconds; Elapsed is relative to it.
	since int64
	// dataInput wakes goroutines blocked in waitForDataInput.
	dataInput *signal.Notifier
	// dataOutput wakes goroutines blocked in waitForDataOutput.
	dataOutput *signal.Notifier
	Config     *Config

	// state is read and written atomically via State/SetState.
	state State
	// stateBeginTime, lastIncomingTime and lastPingTime hold Elapsed()
	// values and are accessed atomically.
	stateBeginTime   uint32
	lastIncomingTime uint32
	lastPingTime     uint32

	// mss is the maximum payload per data segment: the configured MTU minus
	// the writer overhead and DataSegmentOverhead.
	mss       uint32
	roundTrip *RoundTripInfo

	receivingWorker *ReceivingWorker
	sendingWorker   *SendingWorker

	// output writes outgoing segments, wrapped in a retrying writer.
	output SegmentWriter

	// dataUpdater runs flush every TTI while either worker needs an update;
	// pingUpdater runs flush periodically until the connection is terminated.
	dataUpdater *Updater
	pingUpdater *Updater
}
   203  
   204  // NewConnection create a new KCP connection between local and remote.
   205  func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, config *Config) *Connection {
   206  	newError("#", meta.Conversation, " creating connection to ", meta.RemoteAddr).WriteToLog()
   207  
   208  	conn := &Connection{
   209  		meta:       meta,
   210  		closer:     closer,
   211  		since:      nowMillisec(),
   212  		dataInput:  signal.NewNotifier(),
   213  		dataOutput: signal.NewNotifier(),
   214  		Config:     config,
   215  		output:     NewRetryableWriter(NewSegmentWriter(writer)),
   216  		mss:        config.GetMTUValue() - uint32(writer.Overhead()) - DataSegmentOverhead,
   217  		roundTrip: &RoundTripInfo{
   218  			rto:    100,
   219  			minRtt: config.GetTTIValue(),
   220  		},
   221  	}
   222  
   223  	conn.receivingWorker = NewReceivingWorker(conn)
   224  	conn.sendingWorker = NewSendingWorker(conn)
   225  
   226  	isTerminating := func() bool {
   227  		return conn.State().Is(StateTerminating, StateTerminated)
   228  	}
   229  	isTerminated := func() bool {
   230  		return conn.State() == StateTerminated
   231  	}
   232  	conn.dataUpdater = NewUpdater(
   233  		config.GetTTIValue(),
   234  		func() bool {
   235  			return !isTerminating() && (conn.sendingWorker.UpdateNecessary() || conn.receivingWorker.UpdateNecessary())
   236  		},
   237  		isTerminating,
   238  		conn.updateTask)
   239  	conn.pingUpdater = NewUpdater(
   240  		5000, // 5 seconds
   241  		func() bool { return !isTerminated() },
   242  		isTerminated,
   243  		conn.updateTask)
   244  	conn.pingUpdater.WakeUp()
   245  
   246  	return conn
   247  }
   248  
   249  func (c *Connection) Elapsed() uint32 {
   250  	return uint32(nowMillisec() - c.since)
   251  }
   252  
   253  // ReadMultiBuffer implements buf.Reader.
   254  func (c *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
   255  	if c == nil {
   256  		return nil, io.EOF
   257  	}
   258  
   259  	for {
   260  		if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) {
   261  			return nil, io.EOF
   262  		}
   263  		mb := c.receivingWorker.ReadMultiBuffer()
   264  		if !mb.IsEmpty() {
   265  			c.dataUpdater.WakeUp()
   266  			return mb, nil
   267  		}
   268  
   269  		if c.State() == StatePeerTerminating {
   270  			return nil, io.EOF
   271  		}
   272  
   273  		if err := c.waitForDataInput(); err != nil {
   274  			return nil, err
   275  		}
   276  	}
   277  }
   278  
   279  func (c *Connection) waitForDataInput() error {
   280  	for i := 0; i < 16; i++ {
   281  		select {
   282  		case <-c.dataInput.Wait():
   283  			return nil
   284  		default:
   285  			runtime.Gosched()
   286  		}
   287  	}
   288  
   289  	duration := time.Second * 16
   290  	if !c.rd.IsZero() {
   291  		duration = time.Until(c.rd)
   292  		if duration < 0 {
   293  			return ErrIOTimeout
   294  		}
   295  	}
   296  
   297  	timeout := time.NewTimer(duration)
   298  	defer timeout.Stop()
   299  
   300  	select {
   301  	case <-c.dataInput.Wait():
   302  	case <-timeout.C:
   303  		if !c.rd.IsZero() && c.rd.Before(time.Now()) {
   304  			return ErrIOTimeout
   305  		}
   306  	}
   307  
   308  	return nil
   309  }
   310  
   311  // Read implements the Conn Read method.
   312  func (c *Connection) Read(b []byte) (int, error) {
   313  	if c == nil {
   314  		return 0, io.EOF
   315  	}
   316  
   317  	for {
   318  		if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) {
   319  			return 0, io.EOF
   320  		}
   321  		nBytes := c.receivingWorker.Read(b)
   322  		if nBytes > 0 {
   323  			c.dataUpdater.WakeUp()
   324  			return nBytes, nil
   325  		}
   326  
   327  		if err := c.waitForDataInput(); err != nil {
   328  			return 0, err
   329  		}
   330  	}
   331  }
   332  
   333  func (c *Connection) waitForDataOutput() error {
   334  	for i := 0; i < 16; i++ {
   335  		select {
   336  		case <-c.dataOutput.Wait():
   337  			return nil
   338  		default:
   339  			runtime.Gosched()
   340  		}
   341  	}
   342  
   343  	duration := time.Second * 16
   344  	if !c.wd.IsZero() {
   345  		duration = time.Until(c.wd)
   346  		if duration < 0 {
   347  			return ErrIOTimeout
   348  		}
   349  	}
   350  
   351  	timeout := time.NewTimer(duration)
   352  	defer timeout.Stop()
   353  
   354  	select {
   355  	case <-c.dataOutput.Wait():
   356  	case <-timeout.C:
   357  		if !c.wd.IsZero() && c.wd.Before(time.Now()) {
   358  			return ErrIOTimeout
   359  		}
   360  	}
   361  
   362  	return nil
   363  }
   364  
   365  // Write implements io.Writer.
   366  func (c *Connection) Write(b []byte) (int, error) {
   367  	reader := bytes.NewReader(b)
   368  	if err := c.writeMultiBufferInternal(reader); err != nil {
   369  		return 0, err
   370  	}
   371  	return len(b), nil
   372  }
   373  
   374  // WriteMultiBuffer implements buf.Writer.
   375  func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
   376  	reader := &buf.MultiBufferContainer{
   377  		MultiBuffer: mb,
   378  	}
   379  	defer reader.Close()
   380  
   381  	return c.writeMultiBufferInternal(reader)
   382  }
   383  
   384  func (c *Connection) writeMultiBufferInternal(reader io.Reader) error {
   385  	updatePending := false
   386  	defer func() {
   387  		if updatePending {
   388  			c.dataUpdater.WakeUp()
   389  		}
   390  	}()
   391  
   392  	var b *buf.Buffer
   393  	defer b.Release()
   394  
   395  	for {
   396  		for {
   397  			if c == nil || c.State() != StateActive {
   398  				return io.ErrClosedPipe
   399  			}
   400  
   401  			if b == nil {
   402  				b = buf.New()
   403  				_, err := b.ReadFrom(io.LimitReader(reader, int64(c.mss)))
   404  				if err != nil {
   405  					return nil
   406  				}
   407  			}
   408  
   409  			if !c.sendingWorker.Push(b) {
   410  				break
   411  			}
   412  			updatePending = true
   413  			b = nil
   414  		}
   415  
   416  		if updatePending {
   417  			c.dataUpdater.WakeUp()
   418  			updatePending = false
   419  		}
   420  
   421  		if err := c.waitForDataOutput(); err != nil {
   422  			return err
   423  		}
   424  	}
   425  }
   426  
// SetState atomically switches the connection to state and records the switch
// time, as Elapsed() milliseconds, in stateBeginTime. It then performs the
// actions tied to entering that state:
//   - StateReadyToClose: close the receiving worker for reads.
//   - StatePeerClosed: close the sending worker for writes.
//   - StateTerminating: close both workers and set the ping interval to 1s.
//   - StatePeerTerminating: close the sending worker and set the ping
//     interval to 1s.
//   - StateTerminated: close both workers, set the ping interval to 1s, wake
//     both updaters and run Terminate on a new goroutine.
//
// NOTE(review): state and stateBeginTime are stored by two separate atomic
// operations, so a concurrent reader may briefly see the new state with the
// old begin time.
func (c *Connection) SetState(state State) {
	current := c.Elapsed()
	atomic.StoreInt32((*int32)(&c.state), int32(state))
	atomic.StoreUint32(&c.stateBeginTime, current)
	newError("#", c.meta.Conversation, " entering state ", state, " at ", current).AtDebug().WriteToLog()

	switch state {
	case StateReadyToClose:
		c.receivingWorker.CloseRead()
	case StatePeerClosed:
		c.sendingWorker.CloseWrite()
	case StateTerminating:
		c.receivingWorker.CloseRead()
		c.sendingWorker.CloseWrite()
		c.pingUpdater.SetInterval(time.Second)
	case StatePeerTerminating:
		c.sendingWorker.CloseWrite()
		c.pingUpdater.SetInterval(time.Second)
	case StateTerminated:
		c.receivingWorker.CloseRead()
		c.sendingWorker.CloseWrite()
		c.pingUpdater.SetInterval(time.Second)
		c.dataUpdater.WakeUp()
		c.pingUpdater.WakeUp()
		// Terminate runs asynchronously so the caller (often flush) is not
		// blocked by closing the underlying resources.
		go c.Terminate()
	}
}
   454  
   455  // Close closes the connection.
   456  func (c *Connection) Close() error {
   457  	if c == nil {
   458  		return ErrClosedConnection
   459  	}
   460  
   461  	c.dataInput.Signal()
   462  	c.dataOutput.Signal()
   463  
   464  	switch c.State() {
   465  	case StateReadyToClose, StateTerminating, StateTerminated:
   466  		return ErrClosedConnection
   467  	case StateActive:
   468  		c.SetState(StateReadyToClose)
   469  	case StatePeerClosed:
   470  		c.SetState(StateTerminating)
   471  	case StatePeerTerminating:
   472  		c.SetState(StateTerminated)
   473  	}
   474  
   475  	newError("#", c.meta.Conversation, " closing connection to ", c.meta.RemoteAddr).WriteToLog()
   476  
   477  	return nil
   478  }
   479  
   480  // LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it.
   481  func (c *Connection) LocalAddr() net.Addr {
   482  	if c == nil {
   483  		return nil
   484  	}
   485  	return c.meta.LocalAddr
   486  }
   487  
   488  // RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it.
   489  func (c *Connection) RemoteAddr() net.Addr {
   490  	if c == nil {
   491  		return nil
   492  	}
   493  	return c.meta.RemoteAddr
   494  }
   495  
   496  // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
   497  func (c *Connection) SetDeadline(t time.Time) error {
   498  	if err := c.SetReadDeadline(t); err != nil {
   499  		return err
   500  	}
   501  	return c.SetWriteDeadline(t)
   502  }
   503  
   504  // SetReadDeadline implements the Conn SetReadDeadline method.
   505  func (c *Connection) SetReadDeadline(t time.Time) error {
   506  	if c == nil || c.State() != StateActive {
   507  		return ErrClosedConnection
   508  	}
   509  	c.rd = t
   510  	return nil
   511  }
   512  
   513  // SetWriteDeadline implements the Conn SetWriteDeadline method.
   514  func (c *Connection) SetWriteDeadline(t time.Time) error {
   515  	if c == nil || c.State() != StateActive {
   516  		return ErrClosedConnection
   517  	}
   518  	c.wd = t
   519  	return nil
   520  }
   521  
// updateTask is the work function given to both the data and the ping
// updater; each call performs one flush pass.
func (c *Connection) updateTask() {
	c.flush()
}
   526  
   527  func (c *Connection) Terminate() {
   528  	if c == nil {
   529  		return
   530  	}
   531  	newError("#", c.meta.Conversation, " terminating connection to ", c.RemoteAddr()).WriteToLog()
   532  
   533  	// v.SetState(StateTerminated)
   534  	c.dataInput.Signal()
   535  	c.dataOutput.Signal()
   536  
   537  	c.closer.Close()
   538  	c.sendingWorker.Release()
   539  	c.receivingWorker.Release()
   540  }
   541  
   542  func (c *Connection) HandleOption(opt SegmentOption) {
   543  	if (opt & SegmentOptionClose) == SegmentOptionClose {
   544  		c.OnPeerClosed()
   545  	}
   546  }
   547  
   548  func (c *Connection) OnPeerClosed() {
   549  	switch c.State() {
   550  	case StateReadyToClose:
   551  		c.SetState(StateTerminating)
   552  	case StateActive:
   553  		c.SetState(StatePeerClosed)
   554  	}
   555  }
   556  
// Input processes the segments decoded from one incoming low-level packet
// (e.g. a UDP datagram). It records the arrival time in lastIncomingTime,
// then dispatches each segment by type:
//   - *DataSegment goes to the receiving worker;
//   - *AckSegment goes to the sending worker;
//   - *CmdOnlySegment advances the terminate handshake and updates peer
//     positions and RTO.
//
// Processing stops at the first segment whose conversation ID differs from
// this connection's.
func (c *Connection) Input(segments []Segment) {
	current := c.Elapsed()
	// flush compares this against the current time to close idle connections.
	atomic.StoreUint32(&c.lastIncomingTime, current)

	for _, seg := range segments {
		// NOTE(review): break (not continue) drops the rest of the packet on
		// the first mismatch; presumably a packet never mixes conversations.
		if seg.Conversation() != c.meta.Conversation {
			break
		}

		switch seg := seg.(type) {
		case *DataSegment:
			c.HandleOption(seg.Option)
			c.receivingWorker.ProcessSegment(seg)
			// Wake a blocked reader only if something can now be read.
			if c.receivingWorker.IsDataAvailable() {
				c.dataInput.Signal()
			}
			c.dataUpdater.WakeUp()
		case *AckSegment:
			c.HandleOption(seg.Option)
			c.sendingWorker.ProcessSegment(current, seg, c.roundTrip.Timeout())
			// Wake a blocked writer; the ack may have freed send capacity.
			c.dataOutput.Signal()
			c.dataUpdater.WakeUp()
		case *CmdOnlySegment:
			c.HandleOption(seg.Option)
			if seg.Command() == CommandTerminate {
				// Advance the termination handshake from the local state.
				switch c.State() {
				case StateActive, StatePeerClosed:
					c.SetState(StatePeerTerminating)
				case StateReadyToClose:
					c.SetState(StateTerminating)
				case StateTerminating:
					c.SetState(StateTerminated)
				}
			}
			// Let blocked readers/writers observe the close or terminate.
			if seg.Option == SegmentOptionClose || seg.Command() == CommandTerminate {
				c.dataInput.Signal()
				c.dataOutput.Signal()
			}
			c.sendingWorker.ProcessReceivingNext(seg.ReceivingNext)
			c.receivingWorker.ProcessSendingNext(seg.SendingNext)
			c.roundTrip.UpdatePeerRTO(seg.PeerRTO, current)
			// Only command segments are released here; data and ack segments
			// are handed to the workers above.
			seg.Release()
		default:
		}
	}
}
   604  
// flush performs one update pass and is run via updateTask by both updaters.
// All times are Elapsed() values, i.e. milliseconds since creation. In order:
//   - nothing is done once the connection is StateTerminated;
//   - an active connection with no input for 30000 ms is closed;
//   - StateReadyToClose moves to StateTerminating once the sending worker
//     is empty;
//   - in StateTerminating, a CommandTerminate is sent. The state becomes
//     StateTerminated after more than 8000 ms, and flush returns without
//     flushing data;
//   - StatePeerTerminating (after 4000 ms) and StateReadyToClose (after
//     15000 ms) move to StateTerminating;
//   - both workers are flushed, and a CommandPing is sent if none went out
//     in the last 3000 ms.
func (c *Connection) flush() {
	current := c.Elapsed()

	if c.State() == StateTerminated {
		return
	}
	// Idle timeout: no segments received for 30 seconds.
	if c.State() == StateActive && current-atomic.LoadUint32(&c.lastIncomingTime) >= 30000 {
		c.Close()
	}
	if c.State() == StateReadyToClose && c.sendingWorker.IsEmpty() {
		c.SetState(StateTerminating)
	}

	if c.State() == StateTerminating {
		newError("#", c.meta.Conversation, " sending terminating cmd.").AtDebug().WriteToLog()
		c.Ping(current, CommandTerminate)

		if current-atomic.LoadUint32(&c.stateBeginTime) > 8000 {
			c.SetState(StateTerminated)
		}
		return
	}
	if c.State() == StatePeerTerminating && current-atomic.LoadUint32(&c.stateBeginTime) > 4000 {
		c.SetState(StateTerminating)
	}

	if c.State() == StateReadyToClose && current-atomic.LoadUint32(&c.stateBeginTime) > 15000 {
		c.SetState(StateTerminating)
	}

	// Flush the receiving side (presumably pending acknowledgements) and the
	// sending side.
	c.receivingWorker.Flush(current)
	c.sendingWorker.Flush(current)

	// Keep-alive: ping if nothing was pinged in the last 3 seconds.
	if current-atomic.LoadUint32(&c.lastPingTime) >= 3000 {
		c.Ping(current, CommandPing)
	}
}
   643  
   644  func (c *Connection) State() State {
   645  	return State(atomic.LoadInt32((*int32)(&c.state)))
   646  }
   647  
   648  func (c *Connection) Ping(current uint32, cmd Command) {
   649  	seg := NewCmdOnlySegment()
   650  	seg.Conv = c.meta.Conversation
   651  	seg.Cmd = cmd
   652  	seg.ReceivingNext = c.receivingWorker.NextNumber()
   653  	seg.SendingNext = c.sendingWorker.FirstUnacknowledged()
   654  	seg.PeerRTO = c.roundTrip.Timeout()
   655  	if c.State() == StateReadyToClose {
   656  		seg.Option = SegmentOptionClose
   657  	}
   658  	c.output.Write(seg)
   659  	atomic.StoreUint32(&c.lastPingTime, current)
   660  	seg.Release()
   661  }