github.com/danielpfeifer02/quic-go-prio-packs@v0.41.0-28/internal/protocol/connection_id.go (about)

     1  package protocol
     2  
     3  import (
     4  	"crypto/rand"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  )
     9  
    10  var ErrInvalidConnectionIDLen = errors.New("invalid Connection ID length")
    11  
    12  // An ArbitraryLenConnectionID is a QUIC Connection ID able to represent Connection IDs according to RFC 8999.
    13  // Future QUIC versions might allow connection ID lengths up to 255 bytes, while QUIC v1
    14  // restricts the length to 20 bytes.
    15  type ArbitraryLenConnectionID []byte
    16  
    17  // PRIO_PACKS_TAG
    18  const PriorityConnIDLen uint8 = 16
    19  
    20  func (c ArbitraryLenConnectionID) Len() int {
    21  	return len(c)
    22  }
    23  
    24  func (c ArbitraryLenConnectionID) Bytes() []byte {
    25  	return c
    26  }
    27  
    28  func (c ArbitraryLenConnectionID) String() string {
    29  	if c.Len() == 0 {
    30  		return "(empty)"
    31  	}
    32  	return fmt.Sprintf("%x", c.Bytes())
    33  }
    34  
    35  const maxConnectionIDLen = 20
    36  
    37  // A ConnectionID in QUIC
    38  type ConnectionID struct {
    39  	b [20]byte
    40  	l uint8
    41  }
    42  
    43  // GenerateConnectionID generates a connection ID using cryptographic random
    44  func GenerateConnectionID(l int) (ConnectionID, error) {
    45  	var c ConnectionID
    46  	c.l = uint8(l)
    47  	_, err := rand.Read(c.b[:l])
    48  	return c, err
    49  }
    50  
    51  // ParseConnectionID interprets b as a Connection ID.
    52  // It panics if b is longer than 20 bytes.
    53  func ParseConnectionID(b []byte) ConnectionID {
    54  	if len(b) > maxConnectionIDLen {
    55  		panic("invalid conn id length")
    56  	}
    57  	var c ConnectionID
    58  	c.l = uint8(len(b))
    59  	copy(c.b[:c.l], b)
    60  	return c
    61  }
    62  
    63  // GenerateConnectionIDForInitial generates a connection ID for the Initial packet.
    64  // It uses a length randomly chosen between 8 and 20 bytes.
    65  func GenerateConnectionIDForInitial() (ConnectionID, error) {
    66  	// r := make([]byte, 1)
    67  	// if _, err := rand.Read(r); err != nil {
    68  	// 	return ConnectionID{}, err
    69  	// }
    70  	// l := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
    71  
    72  	// PRIO_PACKS_TAG
    73  	l := int(PriorityConnIDLen)
    74  	return GenerateConnectionID(l)
    75  }
    76  
    77  // ReadConnectionID reads a connection ID of length len from the given io.Reader.
    78  // It returns io.EOF if there are not enough bytes to read.
    79  func ReadConnectionID(r io.Reader, l int) (ConnectionID, error) {
    80  	var c ConnectionID
    81  	if l == 0 {
    82  		return c, nil
    83  	}
    84  	if l > maxConnectionIDLen {
    85  		return c, ErrInvalidConnectionIDLen
    86  	}
    87  	c.l = uint8(l)
    88  	_, err := io.ReadFull(r, c.b[:l])
    89  	if err == io.ErrUnexpectedEOF {
    90  		return c, io.EOF
    91  	}
    92  	return c, err
    93  }
    94  
    95  // Len returns the length of the connection ID in bytes
    96  func (c ConnectionID) Len() int {
    97  	return int(c.l)
    98  }
    99  
   100  // Bytes returns the byte representation
   101  func (c ConnectionID) Bytes() []byte {
   102  	return c.b[:c.l]
   103  }
   104  
   105  func (c ConnectionID) String() string {
   106  	if c.Len() == 0 {
   107  		return "(empty)"
   108  	}
   109  	return fmt.Sprintf("%x", c.Bytes())
   110  }
   111  
   112  type DefaultConnectionIDGenerator struct {
   113  	ConnLen int
   114  }
   115  
   116  func (d *DefaultConnectionIDGenerator) GenerateConnectionID() (ConnectionID, error) {
   117  	return GenerateConnectionID(d.ConnLen)
   118  }
   119  
   120  func (d *DefaultConnectionIDGenerator) ConnectionIDLen() int {
   121  	return d.ConnLen
   122  }
   123  
   124  // PRIO_PACKS_TAG
   125  type PriorityConnectionIDGenerator struct {
   126  	ConnLen            int
   127  	NumberOfPriorities int
   128  	PriorityCounter    int8
   129  	NextPriority       int8
   130  	NextPriorityValid  bool
   131  }
   132  
   133  func (t *PriorityConnectionIDGenerator) GenerateConnectionID() (ConnectionID, error) {
   134  
   135  	// PRIO_PACKS_TAG
   136  	// TODOME: better way than casting?
   137  	if t.ConnLen != int(PriorityConnIDLen) {
   138  		fmt.Println("Priority-Connection ID length is not 16")
   139  		return ConnectionID{}, ErrInvalidConnectionIDLen
   140  	}
   141  
   142  	// PRIO_PACKS_TAG
   143  	// this part is for specifically setting the next priority
   144  	// which is used in the case that an older connection ID is
   145  	// retired and a new one with the same priority is needed
   146  	if t.NextPriorityValid {
   147  		t.PriorityCounter = t.NextPriority
   148  		t.NextPriorityValid = false
   149  	}
   150  
   151  	var c ConnectionID
   152  	c.l = uint8(t.ConnLen)
   153  	_, err := rand.Read(c.b[1:t.ConnLen])
   154  	if err != nil {
   155  		return c, err
   156  	}
   157  
   158  	// add priority counter as the first byte of the connection ID and
   159  	c.b[0] = byte(t.PriorityCounter)
   160  
   161  	// first modulo, then increment since 0 is encoding for NoPriority and
   162  	// actual priorities start at 1 and go up to NumberOfPriorities
   163  	t.PriorityCounter = (t.PriorityCounter % int8(t.NumberOfPriorities)) + 1
   164  	return c, nil
   165  }
   166  
   167  func (t *PriorityConnectionIDGenerator) ConnectionIDLen() int {
   168  	return t.ConnLen
   169  }