github.com/Uhtred009/v2ray-core-1@v4.31.2+incompatible/transport/internet/kcp/connection.go (about)

     1  // +build !confonly
     2  
     3  package kcp
     4  
     5  import (
     6  	"bytes"
     7  	"io"
     8  	"net"
     9  	"runtime"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"v2ray.com/core/common/buf"
    15  	"v2ray.com/core/common/signal"
    16  	"v2ray.com/core/common/signal/semaphore"
    17  )
    18  
// Sentinel errors used by the KCP transport.
var (
	// ErrIOTimeout is returned by the read/write wait helpers once the
	// corresponding deadline has passed.
	ErrIOTimeout        = newError("Read/Write timeout")
	// ErrClosedListener is not referenced in this file; presumably it is
	// returned by the KCP listener after it has been closed — TODO confirm.
	ErrClosedListener   = newError("Listener closed.")
	// ErrClosedConnection is returned by Close and by the deadline setters when
	// the connection is nil or no longer in a usable state.
	ErrClosedConnection = newError("Connection closed.")
)
    24  
    25  // State of the connection
    26  type State int32
    27  
    28  // Is returns true if current State is one of the candidates.
    29  func (s State) Is(states ...State) bool {
    30  	for _, state := range states {
    31  		if s == state {
    32  			return true
    33  		}
    34  	}
    35  	return false
    36  }
    37  
    38  const (
    39  	StateActive          State = 0 // Connection is active
    40  	StateReadyToClose    State = 1 // Connection is closed locally
    41  	StatePeerClosed      State = 2 // Connection is closed on remote
    42  	StateTerminating     State = 3 // Connection is ready to be destroyed locally
    43  	StatePeerTerminating State = 4 // Connection is ready to be destroyed on remote
    44  	StateTerminated      State = 5 // Connection is destroyed.
    45  )
    46  
    47  func nowMillisec() int64 {
    48  	now := time.Now()
    49  	return now.Unix()*1000 + int64(now.Nanosecond()/1000000)
    50  }
    51  
    52  type RoundTripInfo struct {
    53  	sync.RWMutex
    54  	variation        uint32
    55  	srtt             uint32
    56  	rto              uint32
    57  	minRtt           uint32
    58  	updatedTimestamp uint32
    59  }
    60  
    61  func (info *RoundTripInfo) UpdatePeerRTO(rto uint32, current uint32) {
    62  	info.Lock()
    63  	defer info.Unlock()
    64  
    65  	if current-info.updatedTimestamp < 3000 {
    66  		return
    67  	}
    68  
    69  	info.updatedTimestamp = current
    70  	info.rto = rto
    71  }
    72  
    73  func (info *RoundTripInfo) Update(rtt uint32, current uint32) {
    74  	if rtt > 0x7FFFFFFF {
    75  		return
    76  	}
    77  	info.Lock()
    78  	defer info.Unlock()
    79  
    80  	// https://tools.ietf.org/html/rfc6298
    81  	if info.srtt == 0 {
    82  		info.srtt = rtt
    83  		info.variation = rtt / 2
    84  	} else {
    85  		delta := rtt - info.srtt
    86  		if info.srtt > rtt {
    87  			delta = info.srtt - rtt
    88  		}
    89  		info.variation = (3*info.variation + delta) / 4
    90  		info.srtt = (7*info.srtt + rtt) / 8
    91  		if info.srtt < info.minRtt {
    92  			info.srtt = info.minRtt
    93  		}
    94  	}
    95  	var rto uint32
    96  	if info.minRtt < 4*info.variation {
    97  		rto = info.srtt + 4*info.variation
    98  	} else {
    99  		rto = info.srtt + info.variation
   100  	}
   101  
   102  	if rto > 10000 {
   103  		rto = 10000
   104  	}
   105  	info.rto = rto * 5 / 4
   106  	info.updatedTimestamp = current
   107  }
   108  
   109  func (info *RoundTripInfo) Timeout() uint32 {
   110  	info.RLock()
   111  	defer info.RUnlock()
   112  
   113  	return info.rto
   114  }
   115  
   116  func (info *RoundTripInfo) SmoothedTime() uint32 {
   117  	info.RLock()
   118  	defer info.RUnlock()
   119  
   120  	return info.srtt
   121  }
   122  
   123  type Updater struct {
   124  	interval        int64
   125  	shouldContinue  func() bool
   126  	shouldTerminate func() bool
   127  	updateFunc      func()
   128  	notifier        *semaphore.Instance
   129  }
   130  
   131  func NewUpdater(interval uint32, shouldContinue func() bool, shouldTerminate func() bool, updateFunc func()) *Updater {
   132  	u := &Updater{
   133  		interval:        int64(time.Duration(interval) * time.Millisecond),
   134  		shouldContinue:  shouldContinue,
   135  		shouldTerminate: shouldTerminate,
   136  		updateFunc:      updateFunc,
   137  		notifier:        semaphore.New(1),
   138  	}
   139  	return u
   140  }
   141  
   142  func (u *Updater) WakeUp() {
   143  	select {
   144  	case <-u.notifier.Wait():
   145  		go u.run()
   146  	default:
   147  	}
   148  }
   149  
   150  func (u *Updater) run() {
   151  	defer u.notifier.Signal()
   152  
   153  	if u.shouldTerminate() {
   154  		return
   155  	}
   156  	ticker := time.NewTicker(u.Interval())
   157  	for u.shouldContinue() {
   158  		u.updateFunc()
   159  		<-ticker.C
   160  	}
   161  	ticker.Stop()
   162  }
   163  
   164  func (u *Updater) Interval() time.Duration {
   165  	return time.Duration(atomic.LoadInt64(&u.interval))
   166  }
   167  
   168  func (u *Updater) SetInterval(d time.Duration) {
   169  	atomic.StoreInt64(&u.interval, int64(d))
   170  }
   171  
// ConnMetadata carries the identifying information of a KCP connection.
type ConnMetadata struct {
	// LocalAddr is the address reported by Connection.LocalAddr.
	LocalAddr    net.Addr
	// RemoteAddr is the address reported by Connection.RemoteAddr.
	RemoteAddr   net.Addr
	// Conversation is the conversation ID; incoming segments are matched
	// against it and outgoing command segments are stamped with it.
	Conversation uint16
}
   177  
// Connection is a KCP connection over UDP.
type Connection struct {
	// meta holds the local/remote addresses and the conversation ID.
	meta       ConnMetadata
	// closer is closed when the connection is terminated.
	closer     io.Closer
	// rd is the read deadline; the zero value means no deadline.
	rd         time.Time
	wd         time.Time // write deadline; the zero value means no deadline
	// since is the creation time in Unix milliseconds; Elapsed is measured from it.
	since      int64
	// dataInput wakes readers blocked waiting for incoming data.
	dataInput  *signal.Notifier
	// dataOutput wakes writers blocked waiting for room to send.
	dataOutput *signal.Notifier
	Config     *Config

	// state is the current life-cycle state; accessed atomically via State/SetState.
	state            State
	// stateBeginTime is the Elapsed value at which the current state was entered.
	stateBeginTime   uint32
	// lastIncomingTime is the Elapsed value of the most recent Input call.
	lastIncomingTime uint32
	// lastPingTime is the Elapsed value of the most recent Ping.
	lastPingTime     uint32

	// mss is the maximum number of payload bytes read into one outgoing buffer
	// (MTU minus writer overhead minus data segment overhead).
	mss       uint32
	// roundTrip tracks the RTT estimate and retransmission timeout.
	roundTrip *RoundTripInfo

	receivingWorker *ReceivingWorker
	sendingWorker   *SendingWorker

	// output is where command segments (see Ping) are written.
	output SegmentWriter

	// dataUpdater runs updateTask while either worker reports that an update is necessary.
	dataUpdater *Updater
	// pingUpdater runs updateTask periodically until the connection is terminated.
	pingUpdater *Updater
}
   205  
   206  // NewConnection create a new KCP connection between local and remote.
   207  func NewConnection(meta ConnMetadata, writer PacketWriter, closer io.Closer, config *Config) *Connection {
   208  	newError("#", meta.Conversation, " creating connection to ", meta.RemoteAddr).WriteToLog()
   209  
   210  	conn := &Connection{
   211  		meta:       meta,
   212  		closer:     closer,
   213  		since:      nowMillisec(),
   214  		dataInput:  signal.NewNotifier(),
   215  		dataOutput: signal.NewNotifier(),
   216  		Config:     config,
   217  		output:     NewRetryableWriter(NewSegmentWriter(writer)),
   218  		mss:        config.GetMTUValue() - uint32(writer.Overhead()) - DataSegmentOverhead,
   219  		roundTrip: &RoundTripInfo{
   220  			rto:    100,
   221  			minRtt: config.GetTTIValue(),
   222  		},
   223  	}
   224  
   225  	conn.receivingWorker = NewReceivingWorker(conn)
   226  	conn.sendingWorker = NewSendingWorker(conn)
   227  
   228  	isTerminating := func() bool {
   229  		return conn.State().Is(StateTerminating, StateTerminated)
   230  	}
   231  	isTerminated := func() bool {
   232  		return conn.State() == StateTerminated
   233  	}
   234  	conn.dataUpdater = NewUpdater(
   235  		config.GetTTIValue(),
   236  		func() bool {
   237  			return !isTerminating() && (conn.sendingWorker.UpdateNecessary() || conn.receivingWorker.UpdateNecessary())
   238  		},
   239  		isTerminating,
   240  		conn.updateTask)
   241  	conn.pingUpdater = NewUpdater(
   242  		5000, // 5 seconds
   243  		func() bool { return !isTerminated() },
   244  		isTerminated,
   245  		conn.updateTask)
   246  	conn.pingUpdater.WakeUp()
   247  
   248  	return conn
   249  }
   250  
   251  func (c *Connection) Elapsed() uint32 {
   252  	return uint32(nowMillisec() - c.since)
   253  }
   254  
   255  // ReadMultiBuffer implements buf.Reader.
   256  func (c *Connection) ReadMultiBuffer() (buf.MultiBuffer, error) {
   257  	if c == nil {
   258  		return nil, io.EOF
   259  	}
   260  
   261  	for {
   262  		if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) {
   263  			return nil, io.EOF
   264  		}
   265  		mb := c.receivingWorker.ReadMultiBuffer()
   266  		if !mb.IsEmpty() {
   267  			c.dataUpdater.WakeUp()
   268  			return mb, nil
   269  		}
   270  
   271  		if c.State() == StatePeerTerminating {
   272  			return nil, io.EOF
   273  		}
   274  
   275  		if err := c.waitForDataInput(); err != nil {
   276  			return nil, err
   277  		}
   278  	}
   279  }
   280  
   281  func (c *Connection) waitForDataInput() error {
   282  	for i := 0; i < 16; i++ {
   283  		select {
   284  		case <-c.dataInput.Wait():
   285  			return nil
   286  		default:
   287  			runtime.Gosched()
   288  		}
   289  	}
   290  
   291  	duration := time.Second * 16
   292  	if !c.rd.IsZero() {
   293  		duration = time.Until(c.rd)
   294  		if duration < 0 {
   295  			return ErrIOTimeout
   296  		}
   297  	}
   298  
   299  	timeout := time.NewTimer(duration)
   300  	defer timeout.Stop()
   301  
   302  	select {
   303  	case <-c.dataInput.Wait():
   304  	case <-timeout.C:
   305  		if !c.rd.IsZero() && c.rd.Before(time.Now()) {
   306  			return ErrIOTimeout
   307  		}
   308  	}
   309  
   310  	return nil
   311  }
   312  
   313  // Read implements the Conn Read method.
   314  func (c *Connection) Read(b []byte) (int, error) {
   315  	if c == nil {
   316  		return 0, io.EOF
   317  	}
   318  
   319  	for {
   320  		if c.State().Is(StateReadyToClose, StateTerminating, StateTerminated) {
   321  			return 0, io.EOF
   322  		}
   323  		nBytes := c.receivingWorker.Read(b)
   324  		if nBytes > 0 {
   325  			c.dataUpdater.WakeUp()
   326  			return nBytes, nil
   327  		}
   328  
   329  		if err := c.waitForDataInput(); err != nil {
   330  			return 0, err
   331  		}
   332  	}
   333  }
   334  
   335  func (c *Connection) waitForDataOutput() error {
   336  	for i := 0; i < 16; i++ {
   337  		select {
   338  		case <-c.dataOutput.Wait():
   339  			return nil
   340  		default:
   341  			runtime.Gosched()
   342  		}
   343  	}
   344  
   345  	duration := time.Second * 16
   346  	if !c.wd.IsZero() {
   347  		duration = time.Until(c.wd)
   348  		if duration < 0 {
   349  			return ErrIOTimeout
   350  		}
   351  	}
   352  
   353  	timeout := time.NewTimer(duration)
   354  	defer timeout.Stop()
   355  
   356  	select {
   357  	case <-c.dataOutput.Wait():
   358  	case <-timeout.C:
   359  		if !c.wd.IsZero() && c.wd.Before(time.Now()) {
   360  			return ErrIOTimeout
   361  		}
   362  	}
   363  
   364  	return nil
   365  }
   366  
   367  // Write implements io.Writer.
   368  func (c *Connection) Write(b []byte) (int, error) {
   369  	reader := bytes.NewReader(b)
   370  	if err := c.writeMultiBufferInternal(reader); err != nil {
   371  		return 0, err
   372  	}
   373  	return len(b), nil
   374  }
   375  
   376  // WriteMultiBuffer implements buf.Writer.
   377  func (c *Connection) WriteMultiBuffer(mb buf.MultiBuffer) error {
   378  	reader := &buf.MultiBufferContainer{
   379  		MultiBuffer: mb,
   380  	}
   381  	defer reader.Close()
   382  
   383  	return c.writeMultiBufferInternal(reader)
   384  }
   385  
   386  func (c *Connection) writeMultiBufferInternal(reader io.Reader) error {
   387  	updatePending := false
   388  	defer func() {
   389  		if updatePending {
   390  			c.dataUpdater.WakeUp()
   391  		}
   392  	}()
   393  
   394  	var b *buf.Buffer
   395  	defer b.Release()
   396  
   397  	for {
   398  		for {
   399  			if c == nil || c.State() != StateActive {
   400  				return io.ErrClosedPipe
   401  			}
   402  
   403  			if b == nil {
   404  				b = buf.New()
   405  				_, err := b.ReadFrom(io.LimitReader(reader, int64(c.mss)))
   406  				if err != nil {
   407  					return nil
   408  				}
   409  			}
   410  
   411  			if !c.sendingWorker.Push(b) {
   412  				break
   413  			}
   414  			updatePending = true
   415  			b = nil
   416  		}
   417  
   418  		if updatePending {
   419  			c.dataUpdater.WakeUp()
   420  			updatePending = false
   421  		}
   422  
   423  		if err := c.waitForDataOutput(); err != nil {
   424  			return err
   425  		}
   426  	}
   427  }
   428  
// SetState stores the new state and the elapsed time at which it was entered
// (as two separate atomic stores), logs the transition at debug level, and
// then applies the side effects associated with the new state.
func (c *Connection) SetState(state State) {
	current := c.Elapsed()
	atomic.StoreInt32((*int32)(&c.state), int32(state))
	atomic.StoreUint32(&c.stateBeginTime, current)
	newError("#", c.meta.Conversation, " entering state ", state, " at ", current).AtDebug().WriteToLog()

	switch state {
	case StateReadyToClose:
		// Closed locally: stop reading.
		c.receivingWorker.CloseRead()
	case StatePeerClosed:
		// Closed by the peer: stop writing.
		c.sendingWorker.CloseWrite()
	case StateTerminating:
		// Both directions are shut down; updates run every second from now on.
		c.receivingWorker.CloseRead()
		c.sendingWorker.CloseWrite()
		c.pingUpdater.SetInterval(time.Second)
	case StatePeerTerminating:
		c.sendingWorker.CloseWrite()
		c.pingUpdater.SetInterval(time.Second)
	case StateTerminated:
		c.receivingWorker.CloseRead()
		c.sendingWorker.CloseWrite()
		c.pingUpdater.SetInterval(time.Second)
		// Wake both updaters so they get a chance to run, then release the
		// connection's resources in the background.
		c.dataUpdater.WakeUp()
		c.pingUpdater.WakeUp()
		go c.Terminate()
	}
}
   456  
   457  // Close closes the connection.
   458  func (c *Connection) Close() error {
   459  	if c == nil {
   460  		return ErrClosedConnection
   461  	}
   462  
   463  	c.dataInput.Signal()
   464  	c.dataOutput.Signal()
   465  
   466  	switch c.State() {
   467  	case StateReadyToClose, StateTerminating, StateTerminated:
   468  		return ErrClosedConnection
   469  	case StateActive:
   470  		c.SetState(StateReadyToClose)
   471  	case StatePeerClosed:
   472  		c.SetState(StateTerminating)
   473  	case StatePeerTerminating:
   474  		c.SetState(StateTerminated)
   475  	}
   476  
   477  	newError("#", c.meta.Conversation, " closing connection to ", c.meta.RemoteAddr).WriteToLog()
   478  
   479  	return nil
   480  }
   481  
   482  // LocalAddr returns the local network address. The Addr returned is shared by all invocations of LocalAddr, so do not modify it.
   483  func (c *Connection) LocalAddr() net.Addr {
   484  	if c == nil {
   485  		return nil
   486  	}
   487  	return c.meta.LocalAddr
   488  }
   489  
   490  // RemoteAddr returns the remote network address. The Addr returned is shared by all invocations of RemoteAddr, so do not modify it.
   491  func (c *Connection) RemoteAddr() net.Addr {
   492  	if c == nil {
   493  		return nil
   494  	}
   495  	return c.meta.RemoteAddr
   496  }
   497  
   498  // SetDeadline sets the deadline associated with the listener. A zero time value disables the deadline.
   499  func (c *Connection) SetDeadline(t time.Time) error {
   500  	if err := c.SetReadDeadline(t); err != nil {
   501  		return err
   502  	}
   503  	return c.SetWriteDeadline(t)
   504  }
   505  
   506  // SetReadDeadline implements the Conn SetReadDeadline method.
   507  func (c *Connection) SetReadDeadline(t time.Time) error {
   508  	if c == nil || c.State() != StateActive {
   509  		return ErrClosedConnection
   510  	}
   511  	c.rd = t
   512  	return nil
   513  }
   514  
   515  // SetWriteDeadline implements the Conn SetWriteDeadline method.
   516  func (c *Connection) SetWriteDeadline(t time.Time) error {
   517  	if c == nil || c.State() != StateActive {
   518  		return ErrClosedConnection
   519  	}
   520  	c.wd = t
   521  	return nil
   522  }
   523  
// updateTask is the periodic callback run by both the data updater and the
// ping updater; it performs one flush pass over the connection.
func (c *Connection) updateTask() {
	c.flush()
}
   528  
// Terminate wakes any blocked readers and writers, closes the underlying
// closer, and releases both workers. It is a no-op on a nil connection.
func (c *Connection) Terminate() {
	if c == nil {
		return
	}
	newError("#", c.meta.Conversation, " terminating connection to ", c.RemoteAddr()).WriteToLog()

	// The state is not changed here: SetState(StateTerminated) is what
	// launches Terminate, not the other way around.
	c.dataInput.Signal()
	c.dataOutput.Signal()

	// NOTE(review): the error returned by Close is discarded.
	c.closer.Close()
	c.sendingWorker.Release()
	c.receivingWorker.Release()
}
   543  
   544  func (c *Connection) HandleOption(opt SegmentOption) {
   545  	if (opt & SegmentOptionClose) == SegmentOptionClose {
   546  		c.OnPeerClosed()
   547  	}
   548  }
   549  
   550  func (c *Connection) OnPeerClosed() {
   551  	switch c.State() {
   552  	case StateReadyToClose:
   553  		c.SetState(StateTerminating)
   554  	case StateActive:
   555  		c.SetState(StatePeerClosed)
   556  	}
   557  }
   558  
// Input feeds the segments decoded from a low-level packet (e.g. a UDP
// packet) into the connection. It records the arrival time and dispatches
// each segment to the receiving or sending worker according to its type.
func (c *Connection) Input(segments []Segment) {
	current := c.Elapsed()
	atomic.StoreUint32(&c.lastIncomingTime, current)

	for _, seg := range segments {
		// A segment belonging to another conversation stops processing of the
		// whole batch; the remaining segments are skipped.
		if seg.Conversation() != c.meta.Conversation {
			break
		}

		switch seg := seg.(type) {
		case *DataSegment:
			c.HandleOption(seg.Option)
			c.receivingWorker.ProcessSegment(seg)
			// Wake readers only when there is something for them to read.
			if c.receivingWorker.IsDataAvailable() {
				c.dataInput.Signal()
			}
			c.dataUpdater.WakeUp()
		case *AckSegment:
			c.HandleOption(seg.Option)
			c.sendingWorker.ProcessSegment(current, seg, c.roundTrip.Timeout())
			// Wake writers after the acknowledgement has been processed.
			c.dataOutput.Signal()
			c.dataUpdater.WakeUp()
		case *CmdOnlySegment:
			c.HandleOption(seg.Option)
			if seg.Command() == CommandTerminate {
				// The peer asked to terminate: advance our state to match.
				switch c.State() {
				case StateActive, StatePeerClosed:
					c.SetState(StatePeerTerminating)
				case StateReadyToClose:
					c.SetState(StateTerminating)
				case StateTerminating:
					c.SetState(StateTerminated)
				}
			}
			// Wake blocked readers and writers so they can observe the new state.
			if seg.Option == SegmentOptionClose || seg.Command() == CommandTerminate {
				c.dataInput.Signal()
				c.dataOutput.Signal()
			}
			c.sendingWorker.ProcessReceivingNext(seg.ReceivingNext)
			c.receivingWorker.ProcessSendingNext(seg.SendingNext)
			c.roundTrip.UpdatePeerRTO(seg.PeerRTO, current)
			seg.Release()
		default:
		}
	}
}
   606  
// flush performs one periodic update pass: it applies time-based state
// transitions, sends a terminate command while terminating, flushes both
// workers, and pings the peer when the last ping is at least 3000 ms old.
// All time comparisons use Elapsed milliseconds.
func (c *Connection) flush() {
	current := c.Elapsed()

	if c.State() == StateTerminated {
		return
	}
	// Close an active connection that has received nothing for 30 seconds.
	if c.State() == StateActive && current-atomic.LoadUint32(&c.lastIncomingTime) >= 30000 {
		c.Close()
	}
	// Closed locally and nothing left to send: start terminating.
	if c.State() == StateReadyToClose && c.sendingWorker.IsEmpty() {
		c.SetState(StateTerminating)
	}

	if c.State() == StateTerminating {
		newError("#", c.meta.Conversation, " sending terminating cmd.").AtDebug().WriteToLog()
		c.Ping(current, CommandTerminate)

		// Give up on the terminating handshake after more than 8 seconds.
		if current-atomic.LoadUint32(&c.stateBeginTime) > 8000 {
			c.SetState(StateTerminated)
		}
		return
	}
	// More than 4 seconds after the peer started terminating, terminate locally too.
	if c.State() == StatePeerTerminating && current-atomic.LoadUint32(&c.stateBeginTime) > 4000 {
		c.SetState(StateTerminating)
	}

	// Closed locally for more than 15 seconds: terminate even if data is still pending.
	if c.State() == StateReadyToClose && current-atomic.LoadUint32(&c.stateBeginTime) > 15000 {
		c.SetState(StateTerminating)
	}

	// flush acknowledges
	c.receivingWorker.Flush(current)
	c.sendingWorker.Flush(current)

	if current-atomic.LoadUint32(&c.lastPingTime) >= 3000 {
		c.Ping(current, CommandPing)
	}
}
   645  
   646  func (c *Connection) State() State {
   647  	return State(atomic.LoadInt32((*int32)(&c.state)))
   648  }
   649  
   650  func (c *Connection) Ping(current uint32, cmd Command) {
   651  	seg := NewCmdOnlySegment()
   652  	seg.Conv = c.meta.Conversation
   653  	seg.Cmd = cmd
   654  	seg.ReceivingNext = c.receivingWorker.NextNumber()
   655  	seg.SendingNext = c.sendingWorker.FirstUnacknowledged()
   656  	seg.PeerRTO = c.roundTrip.Timeout()
   657  	if c.State() == StateReadyToClose {
   658  		seg.Option = SegmentOptionClose
   659  	}
   660  	c.output.Write(seg)
   661  	atomic.StoreUint32(&c.lastPingTime, current)
   662  	seg.Release()
   663  }