github.com/geph-official/geph2@v0.22.6-0.20210211030601-f527cb59b0df/libs/backedtcp/backed.go (about)

     1  package backedtcp
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"io"
     8  	"net"
     9  	"strings"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	pool "github.com/libp2p/go-buffer-pool"
    15  	"gopkg.in/tomb.v1"
    16  )
    17  
    18  const maxBufferSize = 500 * 1000
    19  
// backedWriter remembers the most recently sent bytes so that they can be
// replayed to the peer after the underlying wire is replaced.
type backedWriter struct {
	// lastsn is the total number of bytes ever passed to addData, i.e. the
	// sequence number just past the newest buffered byte.
	lastsn uint64
	// buffer holds the newest sent bytes; its last byte is byte number
	// lastsn-1, so the oldest resumable position is lastsn-len(buffer).
	buffer []byte
	// NOTE(review): lk is never locked in this file; the methods here are
	// only called from the Socket's mainLoop goroutine. Confirm it is unused
	// elsewhere before relying on it.
	lk     sync.Mutex
}
    25  
    26  func (br *backedWriter) addData(ob []byte) {
    27  	b := make([]byte, len(ob))
    28  	copy(b, ob)
    29  	br.lastsn += uint64(len(b))
    30  	br.buffer = append(br.buffer, b...)
    31  	//log.Println("addData buffer size now", cap(br.buffer))
    32  	if len(br.buffer) >= maxBufferSize {
    33  		br.buffer = br.buffer[len(br.buffer)-maxBufferSize:]
    34  	}
    35  }
    36  
    37  func (br *backedWriter) reset() {
    38  	if len(br.buffer) > 10*1000 {
    39  		newlen := len(br.buffer) / 2
    40  		nbuf := make([]byte, newlen)
    41  		copy(nbuf, br.buffer[len(br.buffer)-newlen:])
    42  		br.buffer = nbuf
    43  	}
    44  }
    45  
    46  func (br *backedWriter) since(sn uint64) []byte {
    47  	if sn == br.lastsn {
    48  		return make([]byte, 0)
    49  	}
    50  	if sn > br.lastsn || sn < br.lastsn-uint64(len(br.buffer)) {
    51  		return nil
    52  	}
    53  	return br.buffer[len(br.buffer)-int(br.lastsn-sn):]
    54  }
    55  
// Socket represents a single BackedTCP connection: a byte stream with the
// method set of net.Conn that survives replacement of its underlying wire by
// resending data the peer has not yet read.
type Socket struct {
	// bw buffers recently sent bytes for retransmission; it is only touched
	// from the mainLoop goroutine (writeLoop runs synchronously inside it).
	bw          backedWriter
	// getWire produces a fresh underlying connection.
	getWire     func() (net.Conn, error)
	// cachedWires holds wires pre-established by Reset; realGetWire prefers
	// them over calling getWire.
	cachedWires chan net.Conn
	// chWrite carries pooled copies of Write's input to writeLoop.
	chWrite     chan []byte
	// chRead carries pooled chunks received by readLoop to Read.
	chRead      chan []byte
	// chReplace is signalled by Reset to make writeLoop drop the current wire.
	chReplace   chan struct{}
	// readBuf holds received data not yet returned by Read.
	readBuf     bytes.Buffer
	// readBytes counts all bytes received across every wire; it is sent to
	// the peer during the resumption handshake in mainLoop.
	readBytes   uint64
	// death tracks the socket's lifetime; it is killed by Close or on fatal
	// errors, and its error is what Read/Write report afterwards.
	death       tomb.Tomb
	// remAddr and locAddr hold *net.Addr values of the most recently
	// obtained wire.
	remAddr     atomic.Value
	locAddr     atomic.Value

	// rDeadline and wDeadline hold time.Time values applied to each wire
	// read and write respectively.
	rDeadline atomic.Value
	wDeadline atomic.Value
}
    73  
    74  // NewSocket constructs a new BackedTCP connection.
    75  func NewSocket(getWire func() (net.Conn, error)) *Socket {
    76  	s := &Socket{
    77  		getWire:     getWire,
    78  		chWrite:     make(chan []byte),
    79  		cachedWires: make(chan net.Conn, 10000),
    80  		chRead:      make(chan []byte),
    81  		chReplace:   make(chan struct{}),
    82  	}
    83  	s.SetDeadline(time.Time{})
    84  	go s.mainLoop()
    85  	return s
    86  }
    87  
    88  func (sock *Socket) realGetWire() (net.Conn, error) {
    89  	select {
    90  	case c := <-sock.cachedWires:
    91  		return c, nil
    92  	default:
    93  		return sock.getWire()
    94  	}
    95  }
    96  
    97  func (sock *Socket) mainLoop() {
    98  	for {
    99  		select {
   100  		case <-sock.death.Dying():
   101  			return
   102  		default:
   103  		}
   104  		// first we get a wire
   105  		wire, err := sock.realGetWire()
   106  		if err != nil {
   107  			// this is fatal
   108  			sock.death.Kill(err)
   109  			return
   110  		}
   111  		wra := wire.RemoteAddr()
   112  		wla := wire.LocalAddr()
   113  		sock.remAddr.Store(&wra)
   114  		sock.locAddr.Store(&wla)
   115  		stopWrite := make(chan struct{})
   116  		// negotiation shouldn't take more than 10 secs
   117  		wire.SetDeadline(time.Now().Add(time.Second * 10))
   118  		sent := make(chan bool)
   119  		// negotiate
   120  		go func() {
   121  			defer close(sent)
   122  			// we write our total bytes read. in a new goroutine to prevent dedlock
   123  			binary.Write(wire, binary.BigEndian, sock.readBytes)
   124  		}()
   125  		// read the remote bytes read
   126  		var theirReadBytes uint64
   127  		err = binary.Read(wire, binary.BigEndian, &theirReadBytes)
   128  		if err != nil {
   129  			wire.Close()
   130  			continue
   131  		}
   132  		<-sent
   133  		// get the data that needs to be resent
   134  		toResend := sock.bw.since(theirReadBytes)
   135  		if toResend == nil {
   136  			// out of range
   137  			sock.death.Kill(errors.New("out of resumption range"))
   138  			return
   139  		}
   140  		wire.SetDeadline(time.Time{})
   141  		done := make(chan bool)
   142  		go func() {
   143  			defer close(done)
   144  			defer close(stopWrite)
   145  			sock.readLoop(wire)
   146  		}()
   147  		sock.writeLoop(toResend, wire, stopWrite)
   148  		<-done
   149  	}
   150  }
   151  
   152  func (sock *Socket) writeLoop(toResend []byte, wire net.Conn, stopWrite chan struct{}) {
   153  	defer wire.Close()
   154  	wire.SetWriteDeadline(sock.wDeadline.Load().(time.Time))
   155  	_, err := wire.Write(toResend)
   156  	if err != nil {
   157  		return
   158  	}
   159  	for {
   160  		var timeout <-chan time.Time
   161  		if sock.bw.buffer != nil {
   162  			timeout = time.After(time.Second * 10)
   163  		}
   164  		select {
   165  		case toWrite := <-sock.chWrite:
   166  			// first we remember this so that we can restore
   167  			sock.bw.addData(toWrite)
   168  			wire.SetWriteDeadline(sock.wDeadline.Load().(time.Time))
   169  			// then we try to write. it's okay if we fail!
   170  			_, err := wire.Write(toWrite)
   171  			pool.GlobalPool.Put(toWrite)
   172  			if err != nil {
   173  				if strings.Contains(err.Error(), "timeout") {
   174  					sock.death.Kill(err)
   175  				}
   176  				return
   177  			}
   178  		case <-timeout:
   179  			sock.bw.reset()
   180  		case <-stopWrite:
   181  			//log.Println("writeLoop stopped")
   182  			return
   183  		case <-sock.chReplace:
   184  			//log.Println("writeLoop stopped for replace")
   185  			return
   186  		case <-sock.death.Dying():
   187  			//log.Println("writeLoop forced to die", sock.death.Err())
   188  			return
   189  		}
   190  	}
   191  }
   192  
   193  func (sock *Socket) readLoop(wire net.Conn) {
   194  	defer wire.Close()
   195  	// just loop and read and feed into the channel
   196  	for {
   197  		wire.SetReadDeadline(sock.rDeadline.Load().(time.Time))
   198  		buf := pool.GlobalPool.Get(65536)
   199  		n, err := wire.Read(buf)
   200  		if err != nil {
   201  			return
   202  		}
   203  		sock.readBytes += uint64(n)
   204  		sock.chRead <- buf[:n]
   205  	}
   206  }
   207  
   208  // Reset forces the socket to discard its underlying connection and reconnect.
   209  func (sock *Socket) Reset() (err error) {
   210  	wire, err := sock.getWire()
   211  	if err != nil {
   212  		return
   213  	}
   214  	sock.cachedWires <- wire
   215  	select {
   216  	case sock.chReplace <- struct{}{}:
   217  		return
   218  	case <-sock.death.Dying():
   219  		err = sock.death.Err()
   220  		return
   221  	}
   222  }
   223  
   224  // Close closes the socket.
   225  func (sock *Socket) Close() (err error) {
   226  	sock.death.Kill(io.ErrClosedPipe)
   227  	return
   228  }
   229  
   230  func (sock *Socket) Read(p []byte) (n int, err error) {
   231  	for {
   232  		if sock.readBuf.Len() > 0 {
   233  			return sock.readBuf.Read(p)
   234  		}
   235  		select {
   236  		case <-sock.death.Dying():
   237  			err = sock.death.Err()
   238  			return
   239  		case bts := <-sock.chRead:
   240  			sock.readBuf.Write(bts)
   241  			pool.GlobalPool.Put(bts)
   242  		}
   243  	}
   244  }
   245  
   246  func (sock *Socket) Write(p []byte) (n int, err error) {
   247  	buf := pool.GlobalPool.Get(len(p))
   248  	copy(buf, p)
   249  	select {
   250  	case sock.chWrite <- buf:
   251  		n = len(p)
   252  		return
   253  	case <-sock.death.Dying():
   254  		err = sock.death.Err()
   255  		return
   256  	}
   257  }
   258  
   259  func (sock *Socket) LocalAddr() net.Addr {
   260  	zz := sock.locAddr.Load()
   261  	if zz == nil {
   262  		return dummyAddr("dummy-local")
   263  	}
   264  	return *(zz.(*net.Addr))
   265  }
   266  
   267  func (sock *Socket) RemoteAddr() net.Addr {
   268  	zz := sock.remAddr.Load()
   269  	if zz == nil {
   270  		return dummyAddr("dummy-remote")
   271  	}
   272  	return *(zz.(*net.Addr))
   273  }
   274  
   275  func (sock *Socket) SetDeadline(t time.Time) error {
   276  	sock.SetReadDeadline(t)
   277  	sock.SetWriteDeadline(t)
   278  	return nil
   279  }
   280  
   281  func (sock *Socket) SetReadDeadline(t time.Time) error {
   282  	sock.rDeadline.Store(t)
   283  	return nil
   284  }
   285  
   286  func (sock *Socket) SetWriteDeadline(t time.Time) error {
   287  	sock.wDeadline.Store(t)
   288  	return nil
   289  }
   290  
   291  type dummyAddr string
   292  
   293  func (da dummyAddr) String() string {
   294  	return string(da)
   295  }
   296  
   297  func (da dummyAddr) Network() string {
   298  	return string(da)
   299  }