github.com/geph-official/geph2@v0.22.6-0.20210211030601-f527cb59b0df/libs/kcp-go/fec.go (about)

     1  package kcp
     2  
     3  import (
     4  	"encoding/binary"
     5  	"sync/atomic"
     6  
     7  	"github.com/klauspost/reedsolomon"
     8  )
     9  
// FEC wire-format and housekeeping constants.
const (
	fecHeaderSize      = 6                 // 4B seqid + 2B flag
	fecHeaderSizePlus2 = fecHeaderSize + 2 // plus 2B data size
	typeData           = 0xf1              // flag marking a data shard
	typeParity         = 0xf2              // flag marking a parity shard
	fecExpire          = 60000             // ms before a queued shard is discarded as stale
)
    17  
    18  // fecPacket is a decoded FEC packet
    19  type fecPacket []byte
    20  
    21  func (bts fecPacket) seqid() uint32 { return binary.LittleEndian.Uint32(bts) }
    22  func (bts fecPacket) flag() uint16  { return binary.LittleEndian.Uint16(bts[4:]) }
    23  func (bts fecPacket) data() []byte  { return bts[6:] }
    24  
// fecElement couples a queued fecPacket with an auxiliary timestamp
// used by the decoder's expiry policy.
type fecElement struct {
	fecPacket
	ts uint32 // arrival time in ms (set from currentMs at insertion)
}
    30  
// fecDecoder reassembles shard groups from incoming FEC packets and
// recovers lost data shards with Reed-Solomon decoding.
type fecDecoder struct {
	rxlimit      int          // queue size limit
	dataShards   int          // data shards per group
	parityShards int          // parity shards per group
	shardSize    int          // dataShards + parityShards
	rx           []fecElement // ordered receive queue

	// caches reused across decode calls to avoid per-packet allocation
	decodeCache [][]byte
	flagCache   []bool

	// zeros is an all-zero buffer used to pad shards to a common length
	zeros []byte

	// RS decoder
	codec reedsolomon.Encoder
}
    49  
    50  func newFECDecoder(rxlimit, dataShards, parityShards int) *fecDecoder {
    51  	if dataShards <= 0 || parityShards <= 0 {
    52  		return nil
    53  	}
    54  	if rxlimit < dataShards+parityShards {
    55  		return nil
    56  	}
    57  
    58  	dec := new(fecDecoder)
    59  	dec.rxlimit = rxlimit
    60  	dec.dataShards = dataShards
    61  	dec.parityShards = parityShards
    62  	dec.shardSize = dataShards + parityShards
    63  	codec, err := reedsolomon.New(dataShards, parityShards)
    64  	if err != nil {
    65  		return nil
    66  	}
    67  	dec.codec = codec
    68  	dec.decodeCache = make([][]byte, dec.shardSize)
    69  	dec.flagCache = make([]bool, dec.shardSize)
    70  	dec.zeros = make([]byte, mtuLimit)
    71  	return dec
    72  }
    73  
// decode accepts one FEC packet, inserts a copy into the ordered
// receive queue, and attempts Reed-Solomon reconstruction of the shard
// group the packet belongs to. It returns any recovered data shards
// (the caller takes ownership of those buffers); nil otherwise.
func (dec *fecDecoder) decode(in fecPacket) (recovered [][]byte) {
	// insertion: scan backwards for the slot that keeps rx sorted by seqid
	n := len(dec.rx) - 1
	insertIdx := 0
	for i := n; i >= 0; i-- {
		if in.seqid() == dec.rx[i].seqid() { // de-duplicate
			return nil
		} else if _itimediff(in.seqid(), dec.rx[i].seqid()) > 0 { // insertion
			insertIdx = i + 1
			break
		}
	}

	// make a copy — 'in' aliases the caller's buffer, which may be reused
	pkt := fecPacket(xmitBuf.Get().([]byte)[:len(in)])
	copy(pkt, in)
	elem := fecElement{pkt, currentMs()}

	// insert into ordered rx queue
	if insertIdx == n+1 {
		dec.rx = append(dec.rx, elem) // append at tail (common in-order case)
	} else {
		dec.rx = append(dec.rx, fecElement{})
		copy(dec.rx[insertIdx+1:], dec.rx[insertIdx:]) // shift right
		dec.rx[insertIdx] = elem
	}

	// shard range for current packet; groups are aligned to shardSize
	shardBegin := pkt.seqid() - pkt.seqid()%uint32(dec.shardSize)
	shardEnd := shardBegin + uint32(dec.shardSize) - 1

	// max search range in ordered queue for current shard
	searchBegin := insertIdx - int(pkt.seqid()%uint32(dec.shardSize))
	if searchBegin < 0 {
		searchBegin = 0
	}
	searchEnd := searchBegin + dec.shardSize - 1
	if searchEnd >= len(dec.rx) {
		searchEnd = len(dec.rx) - 1
	}

	// re-construct datashards only when enough shards could be queued
	if searchEnd-searchBegin+1 >= dec.dataShards {
		var numshard, numDataShard, first, maxlen int

		// zero caches from the previous call
		shards := dec.decodeCache
		shardsflag := dec.flagCache
		for k := range dec.decodeCache {
			shards[k] = nil
			shardsflag[k] = false
		}

		// shard assembly: collect queued shards inside [shardBegin, shardEnd]
		for i := searchBegin; i <= searchEnd; i++ {
			seqid := dec.rx[i].seqid()
			if _itimediff(seqid, shardEnd) > 0 {
				break
			} else if _itimediff(seqid, shardBegin) >= 0 {
				shards[seqid%uint32(dec.shardSize)] = dec.rx[i].data()
				shardsflag[seqid%uint32(dec.shardSize)] = true
				numshard++
				if dec.rx[i].flag() == typeData {
					numDataShard++
				}
				if numshard == 1 {
					first = i // remember the first queue index of this group
				}
				if len(dec.rx[i].data()) > maxlen {
					maxlen = len(dec.rx[i].data())
				}
			}
		}

		if numDataShard == dec.dataShards {
			// case 1: no loss on data shards — nothing to recover, just free
			dec.rx = dec.freeRange(first, numshard, dec.rx)
		} else if numshard >= dec.dataShards {
			// case 2: loss on data shards, but it's recoverable from parity shards
			for k := range shards {
				if shards[k] != nil {
					// pad present shards with zeros up to the common length
					dlen := len(shards[k])
					shards[k] = shards[k][:maxlen]
					copy(shards[k][dlen:], dec.zeros)
				} else if k < dec.dataShards {
					// missing data shard: hand the codec an empty buffer
					// (with capacity) to reconstruct into
					shards[k] = xmitBuf.Get().([]byte)[:0]
				}
			}
			if err := dec.codec.ReconstructData(shards); err == nil {
				for k := range shards[:dec.dataShards] {
					if !shardsflag[k] {
						// recovered data should be recycled
						recovered = append(recovered, shards[k])
					}
				}
			}
			dec.rx = dec.freeRange(first, numshard, dec.rx)
		}
	}

	// keep rxlimit: evict the oldest element once the queue overflows
	if len(dec.rx) > dec.rxlimit {
		if dec.rx[0].flag() == typeData { // track the unrecoverable data
			atomic.AddUint64(&DefaultSnmp.FECShortShards, 1)
		}
		dec.rx = dec.freeRange(0, 1, dec.rx)
	}

	// timeout policy: drop the expired prefix of the (time-ordered) queue
	current := currentMs()
	numExpired := 0
	for k := range dec.rx {
		if _itimediff(current, dec.rx[k].ts) > fecExpire {
			numExpired++
			continue
		}
		break
	}
	if numExpired > 0 {
		dec.rx = dec.freeRange(0, numExpired, dec.rx)
	}
	return
}
   198  
   199  // free a range of fecPacket
   200  func (dec *fecDecoder) freeRange(first, n int, q []fecElement) []fecElement {
   201  	for i := first; i < first+n; i++ { // recycle buffer
   202  		xmitBuf.Put([]byte(q[i].fecPacket))
   203  	}
   204  
   205  	if first == 0 && n < cap(q)/2 {
   206  		return q[n:]
   207  	}
   208  	copy(q[first:], q[first+n:])
   209  	return q[:len(q)-n]
   210  }
   211  
   212  // release all segments back to xmitBuf
   213  func (dec *fecDecoder) release() {
   214  	if n := len(dec.rx); n > 0 {
   215  		dec.rx = dec.freeRange(0, n, dec.rx)
   216  	}
   217  }
   218  
type (
	// fecEncoder collects outgoing packets into shard groups and emits
	// Reed-Solomon parity shards once a full group of data shards has
	// been gathered.
	fecEncoder struct {
		dataShards   int    // data shards per group
		parityShards int    // parity shards per group
		shardSize    int    // dataShards + parityShards
		paws         uint32 // Protect Against Wrapped Sequence numbers
		next         uint32 // next seqid

		shardCount int // count the number of datashards collected
		maxSize    int // track maximum data length in datashard

		headerOffset  int // FEC header offset
		payloadOffset int // FEC payload offset

		// caches reused across encode calls
		shardCache  [][]byte
		encodeCache [][]byte

		// zeros is an all-zero buffer for padding shards to equal length
		zeros []byte

		// RS encoder
		codec reedsolomon.Encoder
	}
)
   245  
   246  func newFECEncoder(dataShards, parityShards, offset int) *fecEncoder {
   247  	if dataShards <= 0 || parityShards <= 0 {
   248  		return nil
   249  	}
   250  	enc := new(fecEncoder)
   251  	enc.dataShards = dataShards
   252  	enc.parityShards = parityShards
   253  	enc.shardSize = dataShards + parityShards
   254  	enc.paws = 0xffffffff / uint32(enc.shardSize) * uint32(enc.shardSize)
   255  	enc.headerOffset = offset
   256  	enc.payloadOffset = enc.headerOffset + fecHeaderSize
   257  
   258  	codec, err := reedsolomon.New(dataShards, parityShards)
   259  	if err != nil {
   260  		return nil
   261  	}
   262  	enc.codec = codec
   263  
   264  	// caches
   265  	enc.encodeCache = make([][]byte, enc.shardSize)
   266  	enc.shardCache = make([][]byte, enc.shardSize)
   267  	for k := range enc.shardCache {
   268  		enc.shardCache[k] = make([]byte, mtuLimit)
   269  	}
   270  	enc.zeros = make([]byte, mtuLimit)
   271  	return enc
   272  }
   273  
// encode stamps b with a FEC data-shard header, buffers a copy of it,
// and — once dataShards packets have been collected — returns the
// freshly generated parity shards for this group.
// notice: the contents of 'ps' will be re-written in successive calling
func (enc *fecEncoder) encode(b []byte) (ps [][]byte) {
	// The header format:
	// | FEC SEQID(4B) | FEC TYPE(2B) | SIZE (2B) | PAYLOAD(SIZE-2) |
	// |<-headerOffset                |<-payloadOffset
	enc.markData(b[enc.headerOffset:])
	// SIZE counts itself plus the payload (everything from payloadOffset)
	binary.LittleEndian.PutUint16(b[enc.payloadOffset:], uint16(len(b[enc.payloadOffset:])))

	// copy data from payloadOffset to fec shard cache
	sz := len(b)
	enc.shardCache[enc.shardCount] = enc.shardCache[enc.shardCount][:sz]
	copy(enc.shardCache[enc.shardCount][enc.payloadOffset:], b[enc.payloadOffset:])
	enc.shardCount++

	// track max datashard length
	if sz > enc.maxSize {
		enc.maxSize = sz
	}

	//  Generation of Reed-Solomon Erasure Code
	if enc.shardCount == enc.dataShards {
		// fill '0' into the tail of each datashard so every shard in
		// the group has the same length (required by the RS codec)
		for i := 0; i < enc.dataShards; i++ {
			shard := enc.shardCache[i]
			slen := len(shard)
			copy(shard[slen:enc.maxSize], enc.zeros)
		}

		// construct equal-sized slice with stripped header
		cache := enc.encodeCache
		for k := range cache {
			cache[k] = enc.shardCache[k][enc.payloadOffset:enc.maxSize]
		}

		// encoding: parity is written in place into the tail entries of
		// shardCache (which cache aliases)
		if err := enc.codec.Encode(cache); err == nil {
			ps = enc.shardCache[enc.dataShards:]
			for k := range ps {
				enc.markParity(ps[k][enc.headerOffset:])
				ps[k] = ps[k][:enc.maxSize]
			}
		}

		// counters resetting for the next shard group
		enc.shardCount = 0
		enc.maxSize = 0
	}

	return
}
   325  
   326  func (enc *fecEncoder) markData(data []byte) {
   327  	binary.LittleEndian.PutUint32(data, enc.next)
   328  	binary.LittleEndian.PutUint16(data[4:], typeData)
   329  	enc.next++
   330  }
   331  
   332  func (enc *fecEncoder) markParity(data []byte) {
   333  	binary.LittleEndian.PutUint32(data, enc.next)
   334  	binary.LittleEndian.PutUint16(data[4:], typeParity)
   335  	// sequence wrap will only happen at parity shard
   336  	enc.next = (enc.next + 1) % enc.paws
   337  }