github.com/v2fly/v2ray-core/v4@v4.45.2/transport/internet/kcp/connection.go (about)

     1  //go:build !confonly
     2  // +build !confonly
     3  
     4  package kcp
     5  
     6  import (
     7  	"bytes"
     8  	"io"
     9  	"net"
    10  	"runtime"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"github.com/v2fly/v2ray-core/v4/common/buf"
    16  	"github.com/v2fly/v2ray-core/v4/common/signal"
    17  	"github.com/v2fly/v2ray-core/v4/common/signal/semaphore"
    18  )
    19  
    20  var (
    21  	ErrIOTimeout        = newError("Read/Write timeout")
    22  	ErrClosedListener   = newError("Listener closed.")
    23  	ErrClosedConnection = newError("Connection closed.")
    24  )
    25  
    26  // State of the connection
    27  type State int32
    28  
    29  // Is returns true if current State is one of the candidates.
    30  func (s State) Is(states ...State) bool {
    31  	for _, state := range states {
    32  		if s == state {
    33  			return true
    34  		}
    35  	}
    36  	return false
    37  }
    38  
    39  const (
    40  	StateActive          State = 0 // Connection is active
    41  	StateReadyToClose    State = 1 // Connection is closed locally
    42  	StatePeerClosed      State = 2 // Connection is closed on remote
    43  	StateTerminating     State = 3 // Connection is ready to be destroyed locally
    44  	StatePeerTerminating State = 4 // Connection is ready to be destroyed on remote
    45  	StateTerminated      State = 5 // Connection is destroyed.
    46  )
    47  
    48  func nowMillisec() int64 {
    49  	now := time.Now()
    50  	return now.Unix()*1000 + int64(now.Nanosecond()/1000000)
    51  }
    52  
    53  type RoundTripInfo struct {
    54  	sync.RWMutex
    55  	variation        uint32
    56  	srtt             uint32
    57  	rto              uint32
    58  	minRtt           uint32
    59  	updatedTimestamp uint32
    60  }
    61  
    62  func (info *RoundTripInfo) UpdatePeerRTO(rto uint32, current uint32) {
    63  	info.Lock()
    64  	defer info.Unlock()
    65  
    66  	if current-info.updatedTimestamp < 3000 {
    67  		return
    68  	}
    69  
    70  	info.updatedTimestamp = current
    71  	info.rto = rto
    72  }
    73  
    74  func (info *RoundTripInfo) Update(rtt uint32, current uint32) {
    75  	if rtt > 0x7FFFFFFF {
    76  		return
    77  	}
    78  	info.Lock()
    79  	defer info.Unlock()
    80  
    81  	// https://tools.ietf.org/html/rfc6298
    82  	if info.srtt == 0 {
    83  		info.srtt = rtt
    84  		info.variation = rtt / 2
    85  	} else {
    86  		delta := rtt - info.srtt
    87  		if info.srtt > rtt {
    88  			delta = info.srtt - rtt
    89  		}
    90  		info.variation = (3*info.variation + delta) / 4
    91  		info.srtt = (7*info.srtt + rtt) / 8
    92  		if info.srtt < info.minRtt {
    93  			info.srtt = info.minRtt
    94  		}
    95  	}
    96  	var rto uint32
    97  	if info.minRtt < 4*info.variation {
    98  		rto = info.srtt + 4*info.variation
    99  	} else {
   100  		rto = info.srtt + info.variation
   101  	}
   102  
   103  	if rto > 10000 {
   104  		rto = 10000
   105  	}
   106  	info.rto = rto * 5 / 4
   107  	info.updatedTimestamp = current
   108  }
   109  
   110  func (info *RoundTripInfo) Timeout() uint32 {
   111  	info.RLock()
   112  	defer info.RUnlock()
   113  
   114  	return info.rto
   115  }
   116  
   117  func (info *RoundTripInfo) SmoothedTime() uint32 {
   118  	info.RLock()
   119  	defer info.RUnlock()
   120  
   121  	return info.srtt
   122  }
   123  
   124  type Updater struct {
   125  	interval        int64
   126  	shouldContinue  func() bool
   127  	shouldTerminate func() bool
   128  	updateFunc      func()
   129  	notifier        *semaphore.Instance
   130  }
   131  
   132  func NewUpdater(interval uint32, shouldContinue func() bool, shouldTerminate func() bool, updateFunc func()) *Updater {
   133  	u := &Updater{
   134  		interval:        int64(time.Duration(interval) * time.Millisecond),
   135  		shouldContinue:  shouldContinue,
   136  		shouldTerminate: shouldTerminate,
   137  		updateFunc:      updateFunc,
   138  		notifier:        semaphore.New(1),
   139  	}
   140  	return u
   141  }
   142  
   143  func (u *Updater) WakeUp() {
   144  	select {
   145  	case <-u.notifier.Wait():
   146  		go u.run()
   147  	default:
   148  	}
   149  }
   150  
   151  func (u *Updater) run() {
   152  	defer u.notifier.Signal()
   153  
   154  	if u.shouldTerminate() {
   155  		return
   156  	}
   157  	ticker := time.NewTicker(u.Interval())
   158  	for u.shouldContinue() {
   159  		u.updateFunc()
   160  		<-ticker.C
   161  	}
   162  	ticker.Stop()
   163  }
   164  
   165  func (u *Updater) Interval() time.Duration {
   166  	return time.Duration(atomic.LoadInt64(&u.interval))
   167  }
   168  
   169  func (u *Updater) SetInterval(d time.Duration) {
   170  	atomic.StoreInt64(&u.interval, int64(d))
   171  }
   172  
   173  type ConnMetadata struct {
   174  	LocalAddr    net.Addr
   175  	RemoteAddr   net.Addr
   176  	Conversation uint16
   177  }
   178  
   179  // Connection is a KCP connection over UDP.
   180  type Connection struct {
   181  	meta       ConnMetadata
   182  	closer     io.Closer
   183  	rd         time.Time
   184  	wd         time.Time // write deadline
   185  	since      int64
   186  	dataInput  *signal.Notifier
   187  	dataOutput *signal.Notifier
   188  	Config     *Config
   189  
   190  	state            State
   191  	stateBeginTime   uint32
   192  	lastIncomingTime uint32
   193  	lastPingTime     uint32
   194  
   195  	mss       uint32
   196  	roundTrip *RoundTripInfo
   197  
   198  	receivingWorker *ReceivingWorker
   199  	sendingWorker   *SendingWorker
   200  
   201  	output SegmentWriter
   202  
   203  	dataUpdater *Updater
   204  	pingUpdater *Updater
   205  }
   206  
   207  // NewConnection create a new KCP connection between local and remote.
   208  func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, config *Config) *Connection {
   209  	newError("#", meta.Conversation, " creating connection to ", meta.RemoteAddr).WriteToLog()
   210  
   211  	conn := &Connection{
   212  		meta:       meta,
   213  		closer:     closer,
   214  		since:      nowMillisec(),
   215  		dataInput:  signal.NewNotifier(),
   216  		dataOutput: signal.NewNotifier(),
   217  		Config:     config,
   218  		output:     NewRetryableWriter(NewSegmentWriter(writer)),
   219  		mss:        config.GetMTUValue() - uint32(writer.Overhead()) - DataSegmentOverhead,
   220  		roundTrip: &RoundTripInfo{
   221  			rto:    100,
   222  			minRtt: config.GetTTIValue(),
   223  		},
   224  	}
   225  
   226  	conn.receivingWorker = NewReceivingWorker(conn)
   227  	conn.sendingWorker = NewSendingWorker(conn)
   228  
   229  	isTerminating := func() bool {
   230  		return conn.State().Is(StateTerminating, StateTerminated)
   231  	}
   232  	isTerminated := func() bool {
   233  		return conn.State() == StateTerminated
   234  	}
   235  	conn.dataUpdater = NewUpdater(
   236  		config.GetTTIValue(),
   237  		func() bool {
   238  			return !isTerminating() && (conn.sendingWorker.UpdateNecessary() || conn.receivingWorker.UpdateNecessary())
   239  		},
   240  		isTerminating,
   241  		conn.updateTask)
   242  	conn.pingUpdater = NewUpdater(
   243  		5000, // 5 seconds
   244  		func() bool { return !isTerminated() },
   245  		isTerminated,
   246  		conn.updateTask)
   247  	conn.pingUpdater.WakeUp()
   248  
   249  	return conn
   250  }
   251  
   252  func (c *Connection) Elapsed() uint32 {
   253  	return uint32(nowMillisec() - c.since)
   254  }
   255  
   256  // ReadMultiBuffer implements buf.Reader.
   257  func (c *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
   258  	if c == nil {
   259  		return nil, io.EOF
   260  	}
   261  
   262  	for {
   263  		if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) {
   264  			return nil, io.EOF
   265  		}
   266  		mb := c.receivingWorker.ReadMultiBuffer()
   267  		if !mb.IsEmpty() {
   268  			c.dataUpdater.WakeUp()
   269  			return mb, nil
   270  		}
   271  
   272  		if c.State() == StatePeerTerminating {
   273  			return nil, io.EOF
   274  		}
   275  
   276  		if err := c.waitForDataInput(); err != nil {
   277  			return nil, err
   278  		}
   279  	}
   280  }
   281  
   282  func (c *Connection) waitForDataInput() error {
   283  	for i := 0; i < 16; i++ {
   284  		select {
   285  		case <-c.dataInput.Wait():
   286  			return nil
   287  		default:
   288  			runtime.Gosched()
   289  		}
   290  	}
   291  
   292  	duration := time.Second * 16
   293  	if !c.rd.IsZero() {
   294  		duration = time.Until(c.rd)
   295  		if duration < 0 {
   296  			return ErrIOTimeout
   297  		}
   298  	}
   299  
   300  	timeout := time.NewTimer(duration)
   301  	defer timeout.Stop()
   302  
   303  	select {
   304  	case <-c.dataInput.Wait():
   305  	case <-timeout.C:
   306  		if !c.rd.IsZero() && c.rd.Before(time.Now()) {
   307  			return ErrIOTimeout
   308  		}
   309  	}
   310  
   311  	return nil
   312  }
   313  
   314  // Read implements the Conn Read method.
   315  func (c *Connection) Read(b []byte) (int, error) {
   316  	if c == nil {
   317  		return 0, io.EOF
   318  	}
   319  
   320  	for {
   321  		if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) {
   322  			return 0, io.EOF
   323  		}
   324  		nBytes := c.receivingWorker.Read(b)
   325  		if nBytes > 0 {
   326  			c.dataUpdater.WakeUp()
   327  			return nBytes, nil
   328  		}
   329  
   330  		if err := c.waitForDataInput(); err != nil {
   331  			return 0, err
   332  		}
   333  	}
   334  }
   335  
   336  func (c *Connection) waitForDataOutput() error {
   337  	for i := 0; i < 16; i++ {
   338  		select {
   339  		case <-c.dataOutput.Wait():
   340  			return nil
   341  		default:
   342  			runtime.Gosched()
   343  		}
   344  	}
   345  
   346  	duration := time.Second * 16
   347  	if !c.wd.IsZero() {
   348  		duration = time.Until(c.wd)
   349  		if duration < 0 {
   350  			return ErrIOTimeout
   351  		}
   352  	}
   353  
   354  	timeout := time.NewTimer(duration)
   355  	defer timeout.Stop()
   356  
   357  	select {
   358  	case <-c.dataOutput.Wait():
   359  	case <-timeout.C:
   360  		if !c.wd.IsZero() && c.wd.Before(time.Now()) {
   361  			return ErrIOTimeout
   362  		}
   363  	}
   364  
   365  	return nil
   366  }
   367  
   368  // Write implements io.Writer.
   369  func (c *Connection) Write(b []byte) (int, error) {
   370  	reader := bytes.NewReader(b)
   371  	if err := c.writeMultiBufferInternal(reader); err != nil {
   372  		return 0, err
   373  	}
   374  	return len(b), nil
   375  }
   376  
   377  // WriteMultiBuffer implements buf.Writer.
   378  func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
   379  	reader := &buf.MultiBufferContainer{
   380  		MultiBuffer: mb,
   381  	}
   382  	defer reader.Close()
   383  
   384  	return c.writeMultiBufferInternal(reader)
   385  }
   386  
   387  func (c *Connection) writeMultiBufferInternal(reader io.Reader) error {
   388  	updatePending := false
   389  	defer func() {
   390  		if updatePending {
   391  			c.dataUpdater.WakeUp()
   392  		}
   393  	}()
   394  
   395  	var b *buf.Buffer
   396  	defer b.Release()
   397  
   398  	for {
   399  		for {
   400  			if c == nil || c.State() != StateActive {
   401  				return io.ErrClosedPipe
   402  			}
   403  
   404  			if b == nil {
   405  				b = buf.New()
   406  				_, err := b.ReadFrom(io.LimitReader(reader, int64(c.mss)))
   407  				if err != nil {
   408  					return nil
   409  				}
   410  			}
   411  
   412  			if !c.sendingWorker.Push(b) {
   413  				break
   414  			}
   415  			updatePending = true
   416  			b = nil
   417  		}
   418  
   419  		if updatePending {
   420  			c.dataUpdater.WakeUp()
   421  			updatePending = false
   422  		}
   423  
   424  		if err := c.waitForDataOutput(); err != nil {
   425  			return err
   426  		}
   427  	}
   428  }
   429  
   430  func (c *Connection) SetState(state State) {
   431  	current := c.Elapsed()
   432  	atomic.StoreInt32((*int32)(&c.state), int32(state))
   433  	atomic.StoreUint32(&c.stateBeginTime, current)
   434  	newError("#", c.meta.Conversation, " entering state ", state, " at ", current).AtDebug().WriteToLog()
   435  
   436  	switch state {
   437  	case StateReadyToClose:
   438  		c.receivingWorker.CloseRead()
   439  	case StatePeerClosed:
   440  		c.sendingWorker.CloseWrite()
   441  	case StateTerminating:
   442  		c.receivingWorker.CloseRead()
   443  		c.sendingWorker.CloseWrite()
   444  		c.pingUpdater.SetInterval(time.Second)
   445  	case StatePeerTerminating:
   446  		c.sendingWorker.CloseWrite()
   447  		c.pingUpdater.SetInterval(time.Second)
   448  	case StateTerminated:
   449  		c.receivingWorker.CloseRead()
   450  		c.sendingWorker.CloseWrite()
   451  		c.pingUpdater.SetInterval(time.Second)
   452  		c.dataUpdater.WakeUp()
   453  		c.pingUpdater.WakeUp()
   454  		go c.Terminate()
   455  	}
   456  }
   457  
   458  // Close closes the connection.
   459  func (c *Connection) Close() error {
   460  	if c == nil {
   461  		return ErrClosedConnection
   462  	}
   463  
   464  	c.dataInput.Signal()
   465  	c.dataOutput.Signal()
   466  
   467  	switch c.State() {
   468  	case StateReadyToClose, StateTerminating, StateTerminated:
   469  		return ErrClosedConnection
   470  	case StateActive:
   471  		c.SetState(StateReadyToClose)
   472  	case StatePeerClosed:
   473  		c.SetState(StateTerminating)
   474  	case StatePeerTerminating:
   475  		c.SetState(StateTerminated)
   476  	}
   477  
   478  	newError("#", c.meta.Conversation, " closing connection to ", c.meta.RemoteAddr).WriteToLog()
   479  
   480  	return nil
   481  }
   482  
   483  // LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it.
   484  func (c *Connection) LocalAddr() net.Addr {
   485  	if c == nil {
   486  		return nil
   487  	}
   488  	return c.meta.LocalAddr
   489  }
   490  
   491  // RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it.
   492  func (c *Connection) RemoteAddr() net.Addr {
   493  	if c == nil {
   494  		return nil
   495  	}
   496  	return c.meta.RemoteAddr
   497  }
   498  
   499  // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
   500  func (c *Connection) SetDeadline(t time.Time) error {
   501  	if err := c.SetReadDeadline(t); err != nil {
   502  		return err
   503  	}
   504  	return c.SetWriteDeadline(t)
   505  }
   506  
   507  // SetReadDeadline implements the Conn SetReadDeadline method.
   508  func (c *Connection) SetReadDeadline(t time.Time) error {
   509  	if c == nil || c.State() != StateActive {
   510  		return ErrClosedConnection
   511  	}
   512  	c.rd = t
   513  	return nil
   514  }
   515  
   516  // SetWriteDeadline implements the Conn SetWriteDeadline method.
   517  func (c *Connection) SetWriteDeadline(t time.Time) error {
   518  	if c == nil || c.State() != StateActive {
   519  		return ErrClosedConnection
   520  	}
   521  	c.wd = t
   522  	return nil
   523  }
   524  
   525  // kcp update, input loop
   526  func (c *Connection) updateTask() {
   527  	c.flush()
   528  }
   529  
   530  func (c *Connection) Terminate() {
   531  	if c == nil {
   532  		return
   533  	}
   534  	newError("#", c.meta.Conversation, " terminating connection to ", c.RemoteAddr()).WriteToLog()
   535  
   536  	// v.SetState(StateTerminated)
   537  	c.dataInput.Signal()
   538  	c.dataOutput.Signal()
   539  
   540  	c.closer.Close()
   541  	c.sendingWorker.Release()
   542  	c.receivingWorker.Release()
   543  }
   544  
   545  func (c *Connection) HandleOption(opt SegmentOption) {
   546  	if (opt & SegmentOptionClose) == SegmentOptionClose {
   547  		c.OnPeerClosed()
   548  	}
   549  }
   550  
   551  func (c *Connection) OnPeerClosed() {
   552  	switch c.State() {
   553  	case StateReadyToClose:
   554  		c.SetState(StateTerminating)
   555  	case StateActive:
   556  		c.SetState(StatePeerClosed)
   557  	}
   558  }
   559  
   560  // Input when you received a low level packet (eg. UDP packet), call it
   561  func (c *Connection) Input(segments []Segment) {
   562  	current := c.Elapsed()
   563  	atomic.StoreUint32(&c.lastIncomingTime, current)
   564  
   565  	for _, seg := range segments {
   566  		if seg.Conversation() != c.meta.Conversation {
   567  			break
   568  		}
   569  
   570  		switch seg := seg.(type) {
   571  		case *DataSegment:
   572  			c.HandleOption(seg.Option)
   573  			c.receivingWorker.ProcessSegment(seg)
   574  			if c.receivingWorker.IsDataAvailable() {
   575  				c.dataInput.Signal()
   576  			}
   577  			c.dataUpdater.WakeUp()
   578  		case *AckSegment:
   579  			c.HandleOption(seg.Option)
   580  			c.sendingWorker.ProcessSegment(current, seg, c.roundTrip.Timeout())
   581  			c.dataOutput.Signal()
   582  			c.dataUpdater.WakeUp()
   583  		case *CmdOnlySegment:
   584  			c.HandleOption(seg.Option)
   585  			if seg.Command() == CommandTerminate {
   586  				switch c.State() {
   587  				case StateActive, StatePeerClosed:
   588  					c.SetState(StatePeerTerminating)
   589  				case StateReadyToClose:
   590  					c.SetState(StateTerminating)
   591  				case StateTerminating:
   592  					c.SetState(StateTerminated)
   593  				}
   594  			}
   595  			if seg.Option == SegmentOptionClose || seg.Command() == CommandTerminate {
   596  				c.dataInput.Signal()
   597  				c.dataOutput.Signal()
   598  			}
   599  			c.sendingWorker.ProcessReceivingNext(seg.ReceivingNext)
   600  			c.receivingWorker.ProcessSendingNext(seg.SendingNext)
   601  			c.roundTrip.UpdatePeerRTO(seg.PeerRTO, current)
   602  			seg.Release()
   603  		default:
   604  		}
   605  	}
   606  }
   607  
// flush is the periodic update run by both updaters (via updateTask). It
// advances the close/terminate state machine from elapsed time, flushes both
// workers, and sends a CommandPing once 3000 ms have passed since the last
// ping. The state is re-read before each step because the steps themselves
// call Close/SetState, and Input/Close may also change it from other call sites.
func (c *Connection) flush() {
	current := c.Elapsed()

	// Nothing to do once terminated; Terminate handles the teardown.
	if c.State() == StateTerminated {
		return
	}
	// Idle timeout: close an active connection that has received nothing for 30000 ms.
	if c.State() == StateActive && current-atomic.LoadUint32(&c.lastIncomingTime) >= 30000 {
		c.Close()
	}
	// A locally closed connection starts terminating once the sending worker is empty.
	if c.State() == StateReadyToClose && c.sendingWorker.IsEmpty() {
		c.SetState(StateTerminating)
	}

	// While terminating, send CommandTerminate on every tick and skip normal
	// flushing. After 8000 ms in this state, stop waiting and mark it terminated.
	if c.State() == StateTerminating {
		newError("#", c.meta.Conversation, " sending terminating cmd.").AtDebug().WriteToLog()
		c.Ping(current, CommandTerminate)

		if current-atomic.LoadUint32(&c.stateBeginTime) > 8000 {
			c.SetState(StateTerminated)
		}
		return
	}
	// After the peer started terminating, begin local termination after 4000 ms.
	if c.State() == StatePeerTerminating && current-atomic.LoadUint32(&c.stateBeginTime) > 4000 {
		c.SetState(StateTerminating)
	}

	// Force termination if a local close has not completed within 15000 ms.
	if c.State() == StateReadyToClose && current-atomic.LoadUint32(&c.stateBeginTime) > 15000 {
		c.SetState(StateTerminating)
	}

	// flush acknowledges
	// (receiving worker first, then the sending worker)
	c.receivingWorker.Flush(current)
	c.sendingWorker.Flush(current)

	// Keep-alive: Ping also updates lastPingTime.
	if current-atomic.LoadUint32(&c.lastPingTime) >= 3000 {
		c.Ping(current, CommandPing)
	}
}
   646  
   647  func (c *Connection) State() State {
   648  	return State(atomic.LoadInt32((*int32)(&c.state)))
   649  }
   650  
   651  func (c *Connection) Ping(current uint32, cmd Command) {
   652  	seg := NewCmdOnlySegment()
   653  	seg.Conv = c.meta.Conversation
   654  	seg.Cmd = cmd
   655  	seg.ReceivingNext = c.receivingWorker.NextNumber()
   656  	seg.SendingNext = c.sendingWorker.FirstUnacknowledged()
   657  	seg.PeerRTO = c.roundTrip.Timeout()
   658  	if c.State() == StateReadyToClose {
   659  		seg.Option = SegmentOptionClose
   660  	}
   661  	c.output.Write(seg)
   662  	atomic.StoreUint32(&c.lastPingTime, current)
   663  	seg.Release()
   664  }