github.com/AntonOrnatskyi/goproxy@v0.0.0-20190205095733-4526a9fa18b4/core/dst/conn.go (about)

     1  // Copyright 2014 The DST Authors. All rights reserved.
     2  // Use of this source code is governed by an MIT-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package dst
     6  
     7  import (
     8  	"bytes"
     9  	crand "crypto/rand"
    10  	"encoding/binary"
    11  	"fmt"
    12  	"io"
    13  	"runtime/debug"
    14  
    15  	"math/rand"
    16  	"net"
    17  	"sync"
    18  	"sync/atomic"
    19  	"time"
    20  )
    21  
    22  const (
    23  	defExpTime       = 100 * time.Millisecond // N * (4 * RTT + RTTVar + SYN)
    24  	expCountClose    = 8                      // close connection after this many Exps
    25  	minTimeClose     = 5 * time.Second        // if at least this long has passed
    26  	maxInputBuffer   = 8 << 20                // bytes
    27  	muxBufferPackets = 128                    // buffer size of channel between mux and reader routine
    28  	rttMeasureWindow = 32                     // number of packets to track for RTT averaging
    29  	rttMeasureSample = 128                    // Sample every ... packet for RTT
    30  
    31  	// number of bytes to subtract from MTU when chunking data, to try to
    32  	// avoid fragmentation
    33  	sliceOverhead = 8 /*pppoe, similar*/ + 20 /*ipv4*/ + 8 /*udp*/ + 16 /*dst*/
    34  )
    35  
    36  func init() {
    37  	// Properly seed the random number generator that we use for sequence
    38  	// numbers and stuff.
    39  	buf := make([]byte, 8)
    40  	if n, err := crand.Read(buf); n != 8 || err != nil {
    41  		panic("init random failure")
    42  	}
    43  	rand.Seed(int64(binary.BigEndian.Uint64(buf)))
    44  }
    45  
    46  // TODO: export this interface when it's usable from the outside
    47  type congestionController interface {
    48  	Ack()
    49  	NegAck()
    50  	Exp()
    51  	SendWindow() int
    52  	PacketRate() int // PPS
    53  	UpdateRTT(time.Duration)
    54  }
    55  
    56  // Conn is an SDT connection carried over a Mux.
    57  type Conn struct {
    58  	// Set at creation, thereafter immutable:
    59  
    60  	mux          *Mux
    61  	dst          net.Addr
    62  	connID       connectionID
    63  	remoteConnID connectionID
    64  	in           chan packet
    65  	cc           congestionController
    66  	packetSize   int
    67  	closed       chan struct{}
    68  	closeOnce    sync.Once
    69  
    70  	// Touched by more than one goroutine, needs locking.
    71  
    72  	nextSeqNoMut sync.Mutex
    73  	nextSeqNo    sequenceNo
    74  
    75  	inbufMut  sync.Mutex
    76  	inbufCond *sync.Cond
    77  	inbuf     bytes.Buffer
    78  
    79  	expMut sync.Mutex
    80  	exp    *time.Timer
    81  
    82  	sendBuffer *sendBuffer // goroutine safe
    83  
    84  	packetDelays     [rttMeasureWindow]time.Duration
    85  	packetDelaysSlot int
    86  	packetDelaysMut  sync.Mutex
    87  
    88  	// Owned by the reader routine, needs no locking
    89  
    90  	recvBuffer        packetList
    91  	nextRecvSeqNo     sequenceNo
    92  	lastAckedSeqNo    sequenceNo
    93  	lastNegAckedSeqNo sequenceNo
    94  	expCount          int
    95  	expReset          time.Time
    96  
    97  	// Only accessed atomically
    98  
    99  	packetsIn         int64
   100  	packetsOut        int64
   101  	bytesIn           int64
   102  	bytesOut          int64
   103  	resentPackets     int64
   104  	droppedPackets    int64
   105  	outOfOrderPackets int64
   106  
   107  	// Special
   108  
   109  	debugResetRecvSeqNo chan sequenceNo
   110  }
   111  
   112  func newConn(m *Mux, dst net.Addr) *Conn {
   113  	conn := &Conn{
   114  		mux:                 m,
   115  		dst:                 dst,
   116  		nextSeqNo:           sequenceNo(rand.Uint32()),
   117  		packetSize:          maxPacketSize,
   118  		in:                  make(chan packet, muxBufferPackets),
   119  		closed:              make(chan struct{}),
   120  		sendBuffer:          newSendBuffer(m),
   121  		exp:                 time.NewTimer(defExpTime),
   122  		debugResetRecvSeqNo: make(chan sequenceNo),
   123  		expReset:            time.Now(),
   124  	}
   125  
   126  	conn.lastAckedSeqNo = conn.nextSeqNo - 1
   127  	conn.inbufCond = sync.NewCond(&conn.inbufMut)
   128  
   129  	conn.cc = newWindowCC()
   130  	conn.sendBuffer.SetWindowAndRate(conn.cc.SendWindow(), conn.cc.PacketRate())
   131  	conn.recvBuffer.Resize(128)
   132  
   133  	return conn
   134  }
   135  
   136  func (c *Conn) start() {
   137  	go func() {
   138  		defer func() {
   139  			if e := recover(); e != nil {
   140  				fmt.Printf("crashed, err: %s\nstack:\n%s", e, string(debug.Stack()))
   141  			}
   142  		}()
   143  		c.reader()
   144  	}()
   145  }
   146  
   147  func (c *Conn) reader() {
   148  	if debugConnection {
   149  		log.Println(c, "reader() starting")
   150  		defer log.Println(c, "reader() exiting")
   151  	}
   152  
   153  	for {
   154  		select {
   155  		case <-c.closed:
   156  			// Ack any received but not yet acked messages.
   157  			c.sendAck(0)
   158  
   159  			// Send a shutdown message.
   160  			c.nextSeqNoMut.Lock()
   161  			c.mux.write(packet{
   162  				src: c.connID,
   163  				dst: c.dst,
   164  				hdr: header{
   165  					packetType: typeShutdown,
   166  					connID:     c.remoteConnID,
   167  					sequenceNo: c.nextSeqNo,
   168  				},
   169  			})
   170  			c.nextSeqNo++
   171  			c.nextSeqNoMut.Unlock()
   172  			atomic.AddInt64(&c.packetsOut, 1)
   173  			atomic.AddInt64(&c.bytesOut, dstHeaderLen)
   174  			return
   175  
   176  		case pkt := <-c.in:
   177  			atomic.AddInt64(&c.packetsIn, 1)
   178  			atomic.AddInt64(&c.bytesIn, dstHeaderLen+int64(len(pkt.data)))
   179  
   180  			c.expCount = 1
   181  
   182  			switch pkt.hdr.packetType {
   183  			case typeData:
   184  				c.rcvData(pkt)
   185  			case typeAck:
   186  				c.rcvAck(pkt)
   187  			case typeNegAck:
   188  				c.rcvNegAck(pkt)
   189  			case typeShutdown:
   190  				c.rcvShutdown(pkt)
   191  			default:
   192  				log.Println("Unhandled packet", pkt)
   193  				continue
   194  			}
   195  
   196  		case <-c.exp.C:
   197  			c.eventExp()
   198  			c.resetExp()
   199  
   200  		case n := <-c.debugResetRecvSeqNo:
   201  			// Back door for testing
   202  			c.lastAckedSeqNo = n - 1
   203  			c.nextRecvSeqNo = n
   204  		}
   205  	}
   206  }
   207  
   208  func (c *Conn) eventExp() {
   209  	c.expCount++
   210  
   211  	if c.sendBuffer.lost.Len() > 0 || c.sendBuffer.send.Len() > 0 {
   212  		c.cc.Exp()
   213  		c.sendBuffer.SetWindowAndRate(c.cc.SendWindow(), c.cc.PacketRate())
   214  		c.sendBuffer.ScheduleResend()
   215  
   216  		if debugConnection {
   217  			log.Println(c, "did resends due to Exp")
   218  		}
   219  
   220  		if c.expCount > expCountClose && time.Since(c.expReset) > minTimeClose {
   221  			if debugConnection {
   222  				log.Println(c, "close due to Exp")
   223  			}
   224  
   225  			// We're shutting down due to repeated exp:s. Don't wait for the
   226  			// send buffer to drain, which it would otherwise do in
   227  			// c.Close()..
   228  			c.sendBuffer.CrashStop()
   229  
   230  			c.Close()
   231  		}
   232  	}
   233  }
   234  
   235  func (c *Conn) rcvAck(pkt packet) {
   236  	ack := pkt.hdr.sequenceNo
   237  
   238  	if debugConnection {
   239  		log.Printf("%v read Ack %v", c, ack)
   240  	}
   241  
   242  	c.cc.Ack()
   243  
   244  	if ack%rttMeasureSample == 0 {
   245  		if ts := timestamp(binary.BigEndian.Uint32(pkt.data)); ts > 0 {
   246  			if delay := time.Duration(timestampMicros()-ts) * time.Microsecond; delay > 0 {
   247  				c.packetDelaysMut.Lock()
   248  				c.packetDelays[c.packetDelaysSlot] = delay
   249  				c.packetDelaysSlot = (c.packetDelaysSlot + 1) % len(c.packetDelays)
   250  				c.packetDelaysMut.Unlock()
   251  
   252  				if rtt, n := c.averageDelay(); n > 8 {
   253  					c.cc.UpdateRTT(rtt)
   254  				}
   255  			}
   256  		}
   257  	}
   258  
   259  	c.sendBuffer.Acknowledge(ack)
   260  	c.sendBuffer.SetWindowAndRate(c.cc.SendWindow(), c.cc.PacketRate())
   261  
   262  	c.resetExp()
   263  }
   264  
   265  func (c *Conn) averageDelay() (time.Duration, int) {
   266  	var total time.Duration
   267  	var n int
   268  
   269  	c.packetDelaysMut.Lock()
   270  	for _, d := range c.packetDelays {
   271  		if d != 0 {
   272  			total += d
   273  			n++
   274  		}
   275  	}
   276  	c.packetDelaysMut.Unlock()
   277  
   278  	if n == 0 {
   279  		return 0, 0
   280  	}
   281  	return total / time.Duration(n), n
   282  }
   283  
   284  func (c *Conn) rcvNegAck(pkt packet) {
   285  	nak := pkt.hdr.sequenceNo
   286  
   287  	if debugConnection {
   288  		log.Printf("%v read NegAck %v", c, nak)
   289  	}
   290  
   291  	c.sendBuffer.NegativeAck(nak)
   292  
   293  	//c.cc.NegAck()
   294  	c.resetExp()
   295  }
   296  
   297  func (c *Conn) rcvShutdown(pkt packet) {
   298  	// XXX: We accept shutdown packets somewhat from the future since the
   299  	// sender will number the shutdown after any packets that might still be
   300  	// in the write buffer. This should be fixed to let the write buffer empty
   301  	// on close and reduce the window here.
   302  	if pkt.LessSeq(c.nextRecvSeqNo + 128) {
   303  		if debugConnection {
   304  			log.Println(c, "close due to shutdown")
   305  		}
   306  		c.Close()
   307  	}
   308  }
   309  
   310  func (c *Conn) rcvData(pkt packet) {
   311  	if debugConnection {
   312  		log.Println(c, "recv data", pkt.hdr)
   313  	}
   314  
   315  	if pkt.LessSeq(c.nextRecvSeqNo) {
   316  		if debugConnection {
   317  			log.Printf("%v old packet received; seq %v, expected %v", c, pkt.hdr.sequenceNo, c.nextRecvSeqNo)
   318  		}
   319  		atomic.AddInt64(&c.droppedPackets, 1)
   320  		return
   321  	}
   322  
   323  	if debugConnection {
   324  		log.Println(c, "into recv buffer:", pkt)
   325  	}
   326  	c.recvBuffer.InsertSorted(pkt)
   327  	if c.recvBuffer.LowestSeq() == c.nextRecvSeqNo {
   328  		for _, pkt := range c.recvBuffer.PopSequence(^sequenceNo(0)) {
   329  			if debugConnection {
   330  				log.Println(c, "from recv buffer:", pkt)
   331  			}
   332  
   333  			// An in-sequence packet.
   334  
   335  			c.nextRecvSeqNo = pkt.hdr.sequenceNo + 1
   336  
   337  			c.sendAck(pkt.hdr.timestamp)
   338  
   339  			c.inbufMut.Lock()
   340  			for c.inbuf.Len() > len(pkt.data)+maxInputBuffer {
   341  				c.inbufCond.Wait()
   342  				select {
   343  				case <-c.closed:
   344  					return
   345  				default:
   346  				}
   347  			}
   348  
   349  			c.inbuf.Write(pkt.data)
   350  			c.inbufCond.Broadcast()
   351  			c.inbufMut.Unlock()
   352  		}
   353  	} else {
   354  		if debugConnection {
   355  			log.Printf("%v lost; seq %v, expected %v", c, pkt.hdr.sequenceNo, c.nextRecvSeqNo)
   356  		}
   357  		c.recvBuffer.InsertSorted(pkt)
   358  		c.sendNegAck()
   359  		atomic.AddInt64(&c.outOfOrderPackets, 1)
   360  	}
   361  }
   362  
   363  func (c *Conn) sendAck(ts timestamp) {
   364  	if c.lastAckedSeqNo == c.nextRecvSeqNo {
   365  		return
   366  	}
   367  
   368  	var buf [4]byte
   369  	binary.BigEndian.PutUint32(buf[:], uint32(ts))
   370  	c.mux.write(packet{
   371  		src: c.connID,
   372  		dst: c.dst,
   373  		hdr: header{
   374  			packetType: typeAck,
   375  			connID:     c.remoteConnID,
   376  			sequenceNo: c.nextRecvSeqNo,
   377  		},
   378  		data: buf[:],
   379  	})
   380  
   381  	atomic.AddInt64(&c.packetsOut, 1)
   382  	atomic.AddInt64(&c.bytesOut, dstHeaderLen)
   383  	if debugConnection {
   384  		log.Printf("%v send Ack %v", c, c.nextRecvSeqNo)
   385  	}
   386  
   387  	c.lastAckedSeqNo = c.nextRecvSeqNo
   388  }
   389  
   390  func (c *Conn) sendNegAck() {
   391  	if c.lastNegAckedSeqNo == c.nextRecvSeqNo {
   392  		return
   393  	}
   394  
   395  	c.mux.write(packet{
   396  		src: c.connID,
   397  		dst: c.dst,
   398  		hdr: header{
   399  			packetType: typeNegAck,
   400  			connID:     c.remoteConnID,
   401  			sequenceNo: c.nextRecvSeqNo,
   402  		},
   403  	})
   404  
   405  	atomic.AddInt64(&c.packetsOut, 1)
   406  	atomic.AddInt64(&c.bytesOut, dstHeaderLen)
   407  	if debugConnection {
   408  		log.Printf("%v send NegAck %v", c, c.nextRecvSeqNo)
   409  	}
   410  
   411  	c.lastNegAckedSeqNo = c.nextRecvSeqNo
   412  }
   413  
   414  func (c *Conn) resetExp() {
   415  	d, _ := c.averageDelay()
   416  	d = d*4 + 10*time.Millisecond
   417  
   418  	if d < defExpTime {
   419  		d = defExpTime
   420  	}
   421  
   422  	c.expMut.Lock()
   423  	c.exp.Reset(d)
   424  	c.expMut.Unlock()
   425  }
   426  
   427  // String returns a string representation of the connection.
   428  func (c *Conn) String() string {
   429  	return fmt.Sprintf("%v/%v/%v", c.connID, c.LocalAddr(), c.RemoteAddr())
   430  }
   431  
   432  // Read reads data from the connection.
   433  // Read can be made to time out and return a Error with Timeout() == true
   434  // after a fixed time limit; see SetDeadline and SetReadDeadline.
   435  func (c *Conn) Read(b []byte) (n int, err error) {
   436  	defer func() {
   437  		if e := recover(); e != nil {
   438  			n = 0
   439  			err = io.EOF
   440  		}
   441  	}()
   442  	c.inbufMut.Lock()
   443  	defer c.inbufMut.Unlock()
   444  	for c.inbuf.Len() == 0 {
   445  		select {
   446  		case <-c.closed:
   447  			return 0, io.EOF
   448  		default:
   449  		}
   450  		c.inbufCond.Wait()
   451  	}
   452  	return c.inbuf.Read(b)
   453  }
   454  
   455  // Write writes data to the connection.
   456  // Write can be made to time out and return a Error with Timeout() == true
   457  // after a fixed time limit; see SetDeadline and SetWriteDeadline.
   458  func (c *Conn) Write(b []byte) (n int, err error) {
   459  	select {
   460  	case <-c.closed:
   461  		return 0, ErrClosedConn
   462  	default:
   463  	}
   464  
   465  	sent := 0
   466  	sliceSize := c.packetSize - sliceOverhead
   467  	for i := 0; i < len(b); i += sliceSize {
   468  		nxt := i + sliceSize
   469  		if nxt > len(b) {
   470  			nxt = len(b)
   471  		}
   472  		slice := b[i:nxt]
   473  		sliceCopy := c.mux.buffers.Get().([]byte)[:len(slice)]
   474  		copy(sliceCopy, slice)
   475  
   476  		c.nextSeqNoMut.Lock()
   477  		pkt := packet{
   478  			src: c.connID,
   479  			dst: c.dst,
   480  			hdr: header{
   481  				packetType: typeData,
   482  				sequenceNo: c.nextSeqNo,
   483  				connID:     c.remoteConnID,
   484  			},
   485  			data: sliceCopy,
   486  		}
   487  		c.nextSeqNo++
   488  		c.nextSeqNoMut.Unlock()
   489  
   490  		if err := c.sendBuffer.Write(pkt); err != nil {
   491  			return sent, err
   492  		}
   493  
   494  		atomic.AddInt64(&c.packetsOut, 1)
   495  		atomic.AddInt64(&c.bytesOut, int64(len(slice)+dstHeaderLen))
   496  
   497  		sent += len(slice)
   498  		c.resetExp()
   499  	}
   500  	return sent, nil
   501  }
   502  
   503  // Close closes the connection.
   504  // Any blocked Read or Write operations will be unblocked and return errors.
   505  func (c *Conn) Close() error {
   506  	defer func() {
   507  		_ = recover()
   508  	}()
   509  	c.closeOnce.Do(func() {
   510  		if debugConnection {
   511  			log.Println(c, "explicit close start")
   512  			defer log.Println(c, "explicit close done")
   513  		}
   514  
   515  		// XXX: Ugly hack to implement lingering sockets...
   516  		time.Sleep(4 * defExpTime)
   517  
   518  		c.sendBuffer.Stop()
   519  		c.mux.removeConn(c)
   520  		close(c.closed)
   521  
   522  		c.inbufMut.Lock()
   523  		c.inbufCond.Broadcast()
   524  		c.inbufMut.Unlock()
   525  	})
   526  	return nil
   527  }
   528  
   529  // LocalAddr returns the local network address.
   530  func (c *Conn) LocalAddr() net.Addr {
   531  	return c.mux.Addr()
   532  }
   533  
   534  // RemoteAddr returns the remote network address.
   535  func (c *Conn) RemoteAddr() net.Addr {
   536  	return c.dst
   537  }
   538  
   539  // SetDeadline sets the read and write deadlines associated
   540  // with the connection. It is equivalent to calling both
   541  // SetReadDeadline and SetWriteDeadline.
   542  //
   543  // A deadline is an absolute time after which I/O operations
   544  // fail with a timeout (see type Error) instead of
   545  // blocking. The deadline applies to all future I/O, not just
   546  // the immediately following call to Read or Write.
   547  //
   548  // An idle timeout can be implemented by repeatedly extending
   549  // the deadline after successful Read or Write calls.
   550  //
   551  // A zero value for t means I/O operations will not time out.
   552  //
   553  // BUG(jb): SetDeadline is not implemented.
   554  func (c *Conn) SetDeadline(t time.Time) error {
   555  	return ErrNotImplemented
   556  }
   557  
   558  // SetReadDeadline sets the deadline for future Read calls.
   559  // A zero value for t means Read will not time out.
   560  //
   561  // BUG(jb): SetReadDeadline is not implemented.
   562  func (c *Conn) SetReadDeadline(t time.Time) error {
   563  	return ErrNotImplemented
   564  }
   565  
   566  // SetWriteDeadline sets the deadline for future Write calls.
   567  // Even if write times out, it may return n > 0, indicating that
   568  // some of the data was successfully written.
   569  // A zero value for t means Write will not time out.
   570  //
   571  // BUG(jb): SetWriteDeadline is not implemented.
   572  func (c *Conn) SetWriteDeadline(t time.Time) error {
   573  	return ErrNotImplemented
   574  }
   575  
   576  type Statistics struct {
   577  	DataPacketsIn     int64
   578  	DataPacketsOut    int64
   579  	DataBytesIn       int64
   580  	DataBytesOut      int64
   581  	ResentPackets     int64
   582  	DroppedPackets    int64
   583  	OutOfOrderPackets int64
   584  }
   585  
   586  // String returns a printable represetnation of the Statistics.
   587  func (s Statistics) String() string {
   588  	return fmt.Sprintf("PktsIn: %d, PktsOut: %d, BytesIn: %d, BytesOut: %d, PktsResent: %d, PktsDropped: %d, PktsOutOfOrder: %d",
   589  		s.DataPacketsIn, s.DataPacketsOut, s.DataBytesIn, s.DataBytesOut, s.ResentPackets, s.DroppedPackets, s.OutOfOrderPackets)
   590  }
   591  
   592  // GetStatistics returns a snapsht of the current connection statistics.
   593  func (c *Conn) GetStatistics() Statistics {
   594  	return Statistics{
   595  		DataPacketsIn:     atomic.LoadInt64(&c.packetsIn),
   596  		DataPacketsOut:    atomic.LoadInt64(&c.packetsOut),
   597  		DataBytesIn:       atomic.LoadInt64(&c.bytesIn),
   598  		DataBytesOut:      atomic.LoadInt64(&c.bytesOut),
   599  		ResentPackets:     atomic.LoadInt64(&c.resentPackets),
   600  		DroppedPackets:    atomic.LoadInt64(&c.droppedPackets),
   601  		OutOfOrderPackets: atomic.LoadInt64(&c.outOfOrderPackets),
   602  	}
   603  }