github.com/geph-official/geph2@v0.22.6-0.20210211030601-f527cb59b0df/libs/niaucchi5/urtcp.go (about)

     1  package niaucchi5
     2  
import (
	"bufio"
	"encoding/binary"
	"errors"
	"io"
	"log"
	"math"
	"math/rand"
	"net"
	"sync"
	"time"

	"github.com/bep/debounce"
	pool "github.com/libp2p/go-buffer-pool"
	"gopkg.in/tomb.v1"
)
    18  
// poolSlice marks byte slices that are managed by the buffer pool rather than
// by the ordinary heap, so the two kinds are easy to tell apart in signatures.
// It is a defined type (not a true Go alias); because its underlying type is
// the unnamed type []byte, plain []byte values are still assignable to it.
type poolSlice []byte
    21  
    22  const maxSegmentSize = 16384
    23  const flagAck = 30001
    24  const flagPing = 40001
    25  const flagPong = 40002
    26  
// delayer debounces delayed acknowledgements with a 50 ms quiet period; it is
// invoked from delayedAck.
//
// NOTE(review): this is one package-level debouncer shared by every URTCP
// instance. The bep/debounce documentation says that when the debounced
// function is invoked with different functions "the last one will win", so
// with several connections active one connection's pending ack callback can
// presumably be displaced by another's. Confirm, and consider giving each
// URTCP its own debouncer.
var delayer = debounce.New(50 * time.Millisecond)
    28  
    29  // URTCP implements "unreliable TCP". This is an unreliable PacketWire implementation over reliable net.Conn's like TCP, yet avoids excessive bufferbloat, "TCP over TCP" problems, etc.
    30  type URTCP struct {
    31  	wire     net.Conn
    32  	wireRead io.Reader
    33  
    34  	// synchronization
    35  	cvar     *sync.Cond
    36  	death    tomb.Tomb
    37  	deatherr error
    38  
    39  	// sending variables
    40  	sv struct {
    41  		sendBuffer    chan poolSlice // includes 2-byte length header
    42  		inflight      int
    43  		inflightLimit int
    44  		delivered     uint64
    45  		lastAckNano   uint64
    46  
    47  		pingSendNano   uint64
    48  		ping           uint64
    49  		pingUpdateTime time.Time
    50  
    51  		bw           float64
    52  		bwUpdateTime time.Time
    53  	}
    54  
    55  	// receiving variables
    56  	rv struct {
    57  		recvBuffer  []poolSlice
    58  		unacked     int
    59  		lastAckTime time.Time
    60  	}
    61  }
    62  
    63  // NewURTCP creates a new URTCP instance.
    64  func NewURTCP(wire net.Conn) *URTCP {
    65  	tr := &URTCP{
    66  		wire:     wire,
    67  		wireRead: bufio.NewReaderSize(wire, 4096),
    68  	}
    69  	tr.cvar = sync.NewCond(new(sync.Mutex))
    70  	tr.sv.sendBuffer = make(chan poolSlice, 1048576)
    71  	tr.sv.inflightLimit = 100 * 1024
    72  	go func() {
    73  		<-tr.death.Dying()
    74  		tr.cvar.L.Lock()
    75  		tr.deatherr = tr.death.Err()
    76  		tr.cvar.Broadcast()
    77  		tr.cvar.L.Unlock()
    78  	}()
    79  	go tr.sendLoop()
    80  	go tr.recvLoop()
    81  	return tr
    82  }
    83  
    84  // SendSegment sends a single segment.
    85  func (tr *URTCP) SendSegment(seg []byte, block bool) (err error) {
    86  	tr.cvar.L.Lock()
    87  	defer tr.cvar.L.Unlock()
    88  	defer tr.cvar.Broadcast()
    89  	if block {
    90  		for tr.sv.inflight > tr.sv.inflightLimit && tr.deatherr == nil {
    91  			tr.cvar.Wait()
    92  		}
    93  	} else {
    94  		// random early detection
    95  		minThresh := 0.8
    96  		dropProb := math.Max(0, float64(tr.sv.inflight)/float64(tr.sv.inflightLimit)-minThresh) / (1 - minThresh)
    97  		if rand.Float64() < dropProb {
    98  			return
    99  		}
   100  	}
   101  
   102  	if tr.deatherr != nil {
   103  		err = tr.deatherr
   104  		return
   105  	}
   106  	// we have free space now, enqueue for sending
   107  	segclone := poolSlice(pool.Get(len(seg) + 2))
   108  	binary.LittleEndian.PutUint16(segclone[:2], uint16(len(seg)))
   109  	copy(segclone[2:], seg)
   110  	tr.queueForSend(segclone)
   111  	// send req for ack if needed
   112  	tr.sv.inflight += len(seg)
   113  	if tr.sv.pingSendNano == 0 {
   114  		tr.sv.pingSendNano = uint64(time.Now().UnixNano())
   115  		ping := poolSlice(pool.Get(2))
   116  		binary.LittleEndian.PutUint16(ping, flagPing)
   117  		tr.queueForSend(ping)
   118  	}
   119  	return
   120  }
   121  
   122  // RecvSegment receives a single segment.
   123  func (tr *URTCP) RecvSegment(seg []byte) (n int, err error) {
   124  	tr.cvar.L.Lock()
   125  	defer tr.cvar.L.Unlock()
   126  	for len(tr.rv.recvBuffer) == 0 && tr.deatherr == nil {
   127  		tr.cvar.Wait()
   128  	}
   129  	if tr.deatherr != nil {
   130  		err = tr.deatherr
   131  		return
   132  	}
   133  	fs := tr.rv.recvBuffer[0]
   134  	tr.rv.recvBuffer = tr.rv.recvBuffer[1:]
   135  	n = copy(seg, fs)
   136  	pool.Put(fs)
   137  	return
   138  }
   139  
   140  func (tr *URTCP) queueForSend(b poolSlice) {
   141  	select {
   142  	case tr.sv.sendBuffer <- b:
   143  	default:
   144  		go func() {
   145  			select {
   146  			case tr.sv.sendBuffer <- b:
   147  			case <-tr.death.Dying():
   148  			}
   149  		}()
   150  	}
   151  }
   152  
   153  func (tr *URTCP) sendLoop() {
   154  	defer tr.wire.Close()
   155  	for {
   156  		select {
   157  		case toSend := <-tr.sv.sendBuffer:
   158  			if len(toSend) > maxSegmentSize+2 {
   159  				panic("shouldn't happen")
   160  			}
   161  			_, err := tr.wire.Write(toSend)
   162  			if err != nil {
   163  				tr.death.Kill(err)
   164  				tr.cvar.Broadcast()
   165  				return
   166  			}
   167  			pool.Put(toSend)
   168  		case <-tr.death.Dying():
   169  			return
   170  		}
   171  	}
   172  }
   173  
   174  func (tr *URTCP) forceAck() {
   175  	if tr.rv.unacked > 0 {
   176  		ack := poolSlice(pool.Get(2 + 8 + 8))
   177  		binary.LittleEndian.PutUint16(ack[:2], flagAck)
   178  		binary.LittleEndian.PutUint64(ack[2:][:8], uint64(time.Now().UnixNano()))
   179  		binary.LittleEndian.PutUint64(ack[2:][8:], uint64(tr.rv.unacked))
   180  		tr.rv.unacked = 0
   181  		tr.queueForSend(ack)
   182  		tr.rv.lastAckTime = time.Now()
   183  	}
   184  }
   185  
   186  func (tr *URTCP) delayedAck() {
   187  	delayer(func() {
   188  		tr.cvar.L.Lock()
   189  		defer tr.cvar.L.Unlock()
   190  		tr.forceAck()
   191  	})
   192  }
   193  
   194  func (tr *URTCP) recvLoop() {
   195  	defer tr.wire.Close()
   196  	defer tr.cvar.Broadcast()
   197  	defer tr.death.Kill(io.ErrClosedPipe)
   198  	lenbuf := make([]byte, 2)
   199  	for {
   200  		_, err := io.ReadFull(tr.wireRead, lenbuf)
   201  		if err != nil {
   202  			return
   203  		}
   204  		lenint := binary.LittleEndian.Uint16(lenbuf)
   205  		switch lenint {
   206  		case flagPing:
   207  			pingPkt := pool.Get(8)
   208  			binary.LittleEndian.PutUint16(pingPkt, flagPong)
   209  			tr.queueForSend(pingPkt)
   210  		case flagPong:
   211  			tr.cvar.L.Lock()
   212  			now := time.Now()
   213  			pingSample := uint64(time.Now().UnixNano()) - tr.sv.pingSendNano
   214  			if pingSample < tr.sv.ping || now.Sub(tr.sv.pingUpdateTime) > time.Second*30 {
   215  				tr.sv.ping = pingSample
   216  				tr.sv.pingUpdateTime = now
   217  			}
   218  			log.Println("********* ping sample", float64(pingSample)*1e-9*1000)
   219  			tr.sv.pingSendNano = 0
   220  			tr.cvar.L.Unlock()
   221  		case flagAck:
   222  			rUnixNanoB := pool.Get(8 + 8)
   223  			_, err = io.ReadFull(tr.wireRead, rUnixNanoB)
   224  			if err != nil {
   225  				return
   226  			}
   227  			rUnixNano := binary.LittleEndian.Uint64(rUnixNanoB[:8])
   228  			rAckCount := binary.LittleEndian.Uint64(rUnixNanoB[8:])
   229  			pool.Put(rUnixNanoB)
   230  
   231  			tr.cvar.L.Lock()
   232  			//now := time.Now()
   233  
   234  			ping := float64(tr.sv.ping) * 1e-9
   235  
   236  			deltaD := float64(rAckCount)
   237  			tr.sv.delivered += uint64(deltaD)
   238  			tr.sv.inflight -= int(rAckCount)
   239  			deltaT := float64(rUnixNano - tr.sv.lastAckNano)
   240  			tr.sv.lastAckNano = rUnixNano
   241  			bwSample := 1e9 * deltaD / deltaT
   242  			// if bwSample > tr.sv.bw || now.Sub(tr.sv.bwUpdateTime).Seconds() > ping*10 {
   243  			// 	tr.sv.bw = bwSample
   244  			// 	tr.sv.bwUpdateTime = now
   245  			// }
   246  			if bwSample > tr.sv.bw {
   247  				tr.sv.bw = bwSample*0.5 + tr.sv.bw*0.5
   248  			} else {
   249  				tr.sv.bw = bwSample*0.1 + tr.sv.bw*0.9
   250  			}
   251  
   252  			Bps := tr.sv.bw
   253  			log.Println("bw sample", int(Bps/1000), "KB/s")
   254  			bdp := Bps * ping
   255  			tgtIFL := int(bdp*3) + 100*1024
   256  			tr.sv.inflightLimit = tgtIFL
   257  			// if rand.Int()%10 == 0 {
   258  			// 	tr.sv.inflightLimit = int(bdp)
   259  			// 	log.Println("SEVERE LIMITING")
   260  			// }
   261  
   262  			tr.cvar.Broadcast()
   263  			tr.cvar.L.Unlock()
   264  
   265  		default:
   266  			body := poolSlice(pool.Get(int(lenint)))
   267  			_, err = io.ReadFull(tr.wireRead, body)
   268  			if err != nil {
   269  				return
   270  			}
   271  			// notify the world
   272  			tr.cvar.L.Lock()
   273  			tr.rv.recvBuffer = append(tr.rv.recvBuffer, body)
   274  			tr.rv.unacked += len(body)
   275  			if time.Since(tr.rv.lastAckTime) > time.Millisecond*20 {
   276  				tr.forceAck()
   277  			} else {
   278  				tr.delayedAck()
   279  			}
   280  			tr.cvar.Broadcast()
   281  			tr.cvar.L.Unlock()
   282  		}
   283  	}
   284  }