github.com/Psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/gquic-go/internal/wire/header.go (about)

     1  package wire
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/rand"
     6  	"errors"
     7  	"fmt"
     8  
     9  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/protocol"
    10  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/quic/gquic-go/internal/utils"
    11  )
    12  
// Header is the header of a QUIC packet.
// It holds the fields of both the gQUIC Public Header and the IETF draft
// header; which fields are meaningful depends on the header format in use.
type Header struct {
	// IsPublicHeader marks the header as a gQUIC Public Header. Write sets it
	// for versions that don't use the IETF header format, and Log uses it to
	// choose the log format.
	IsPublicHeader bool

	// Raw is neither read nor written in this file; presumably it holds the
	// raw header bytes as received — TODO confirm against the parser.
	Raw []byte

	Version protocol.VersionNumber

	DestConnectionID     protocol.ConnectionID
	SrcConnectionID      protocol.ConnectionID
	OrigDestConnectionID protocol.ConnectionID // only needed in the Retry packet

	// PacketNumberLen is the encoded size of PacketNumber in bytes. Note that
	// writeLongHeader always writes 4 bytes when the version does not use
	// varint packet numbers.
	PacketNumberLen protocol.PacketNumberLen
	PacketNumber    protocol.PacketNumber

	IsVersionNegotiation bool
	SupportedVersions    []protocol.VersionNumber // Version Number sent in a Version Negotiation Packet by the server

	// only needed for the gQUIC Public Header
	// (DiversificationNonce is also written in IETF long headers of 0-RTT
	// packets in Version44; when set it must be exactly 32 bytes long.)
	VersionFlag          bool
	ResetFlag            bool
	DiversificationNonce []byte

	// only needed for the IETF Header
	// KeyPhase is shifted left by 6 into the short header's type byte.
	// PayloadLen is written as a varint when the version carries a length field.
	// Token is written in Initial packets (length-prefixed) and Retry packets.
	Type         protocol.PacketType
	IsLongHeader bool
	KeyPhase     int
	PayloadLen   protocol.ByteCount
	Token        []byte
}
    44  
    45  var errInvalidPacketNumberLen = errors.New("invalid packet number length")
    46  
    47  // Write writes the Header.
    48  func (h *Header) Write(b *bytes.Buffer, pers protocol.Perspective, ver protocol.VersionNumber) error {
    49  	if !ver.UsesIETFHeaderFormat() {
    50  		h.IsPublicHeader = true // save that this is a Public Header, so we can log it correctly later
    51  		return h.writePublicHeader(b, pers, ver)
    52  	}
    53  	// write an IETF QUIC header
    54  	if h.IsLongHeader {
    55  		return h.writeLongHeader(b, ver)
    56  	}
    57  	return h.writeShortHeader(b, ver)
    58  }
    59  
// writeLongHeader serializes h as an IETF-draft long header into b.
//
// Wire layout, in write order: the type byte (0x80 | Type), the 4-byte
// version, one byte packing both connection ID lengths (see encodeConnIDLen),
// and the destination and source connection IDs. After that:
//   - Retry packets: a byte whose low 4 bits encode the length of
//     OrigDestConnectionID (high 4 bits random), then OrigDestConnectionID,
//     then the Token without a length prefix. Nothing else is written.
//   - Other packets: the varint-length-prefixed Token (Initial packets only,
//     if v.UsesTokenInHeader), the varint PayloadLen (if v.UsesLengthInHeader),
//     and the packet number. The packet number is varint-encoded if
//     v.UsesVarintPacketNumbers. Otherwise it is a fixed 4 bytes, followed by
//     the 32-byte DiversificationNonce for 0-RTT packets in Version44.
//
// It returns an error if a connection ID length cannot be encoded, if the
// diversification nonce has the wrong length, or if
// utils.WriteVarIntPacketNumber fails. On error, b may already contain a
// partially written header.
// TODO: add support for the key phase
func (h *Header) writeLongHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
	b.WriteByte(byte(0x80 | h.Type))
	utils.BigEndian.WriteUint32(b, uint32(h.Version))
	connIDLen, err := encodeConnIDLen(h.DestConnectionID, h.SrcConnectionID)
	if err != nil {
		return err
	}
	b.WriteByte(connIDLen)
	b.Write(h.DestConnectionID.Bytes())
	b.Write(h.SrcConnectionID.Bytes())

	if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
		utils.WriteVarInt(b, uint64(len(h.Token)))
		b.Write(h.Token)
	}

	if h.Type == protocol.PacketTypeRetry {
		odcil, err := encodeSingleConnIDLen(h.OrigDestConnectionID)
		if err != nil {
			return err
		}
		// randomize the first (most significant) 4 bits; the low 4 bits carry
		// the encoded length of the original destination connection ID
		odcilByte := make([]byte, 1)
		_, _ = rand.Read(odcilByte) // it's safe to ignore the error here
		odcilByte[0] = (odcilByte[0] & 0xf0) | odcil
		b.Write(odcilByte)
		b.Write(h.OrigDestConnectionID.Bytes())
		// The Retry token runs to the end of the packet, so it has no length prefix.
		b.Write(h.Token)
		return nil
	}

	if v.UsesLengthInHeader() {
		utils.WriteVarInt(b, uint64(h.PayloadLen))
	}
	if v.UsesVarintPacketNumbers() {
		return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
	}
	// NOTE(review): always 4 bytes here, regardless of h.PacketNumberLen, while
	// getHeaderLength counts h.PacketNumberLen. Callers presumably set
	// PacketNumberLen4 for these versions — confirm.
	utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
	if h.Type == protocol.PacketType0RTT && v == protocol.Version44 {
		if len(h.DiversificationNonce) != 32 {
			return errors.New("invalid diversification nonce length")
		}
		b.Write(h.DiversificationNonce)
	}
	return nil
}
   107  
   108  func (h *Header) writeShortHeader(b *bytes.Buffer, v protocol.VersionNumber) error {
   109  	typeByte := byte(0x30)
   110  	typeByte |= byte(h.KeyPhase << 6)
   111  	if !v.UsesVarintPacketNumbers() {
   112  		switch h.PacketNumberLen {
   113  		case protocol.PacketNumberLen1:
   114  		case protocol.PacketNumberLen2:
   115  			typeByte |= 0x1
   116  		case protocol.PacketNumberLen4:
   117  			typeByte |= 0x2
   118  		default:
   119  			return errInvalidPacketNumberLen
   120  		}
   121  	}
   122  
   123  	b.WriteByte(typeByte)
   124  	b.Write(h.DestConnectionID.Bytes())
   125  
   126  	if !v.UsesVarintPacketNumbers() {
   127  		switch h.PacketNumberLen {
   128  		case protocol.PacketNumberLen1:
   129  			b.WriteByte(uint8(h.PacketNumber))
   130  		case protocol.PacketNumberLen2:
   131  			utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
   132  		case protocol.PacketNumberLen4:
   133  			utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
   134  		}
   135  		return nil
   136  	}
   137  	return utils.WriteVarIntPacketNumber(b, h.PacketNumber, h.PacketNumberLen)
   138  }
   139  
   140  // writePublicHeader writes a Public Header.
   141  func (h *Header) writePublicHeader(b *bytes.Buffer, pers protocol.Perspective, _ protocol.VersionNumber) error {
   142  	if h.ResetFlag || (h.VersionFlag && pers == protocol.PerspectiveServer) {
   143  		return errors.New("PublicHeader: Can only write regular packets")
   144  	}
   145  	if h.SrcConnectionID.Len() != 0 {
   146  		return errors.New("PublicHeader: SrcConnectionID must not be set")
   147  	}
   148  	if len(h.DestConnectionID) != 0 && len(h.DestConnectionID) != 8 {
   149  		return fmt.Errorf("PublicHeader: wrong length for Connection ID: %d (expected 8)", len(h.DestConnectionID))
   150  	}
   151  
   152  	publicFlagByte := uint8(0x00)
   153  	if h.VersionFlag {
   154  		publicFlagByte |= 0x01
   155  	}
   156  	if h.DestConnectionID.Len() > 0 {
   157  		publicFlagByte |= 0x08
   158  	}
   159  	if len(h.DiversificationNonce) > 0 {
   160  		if len(h.DiversificationNonce) != 32 {
   161  			return errors.New("invalid diversification nonce length")
   162  		}
   163  		publicFlagByte |= 0x04
   164  	}
   165  	switch h.PacketNumberLen {
   166  	case protocol.PacketNumberLen1:
   167  		publicFlagByte |= 0x00
   168  	case protocol.PacketNumberLen2:
   169  		publicFlagByte |= 0x10
   170  	case protocol.PacketNumberLen4:
   171  		publicFlagByte |= 0x20
   172  	}
   173  	b.WriteByte(publicFlagByte)
   174  
   175  	if h.DestConnectionID.Len() > 0 {
   176  		b.Write(h.DestConnectionID)
   177  	}
   178  	if h.VersionFlag && pers == protocol.PerspectiveClient {
   179  		utils.BigEndian.WriteUint32(b, uint32(h.Version))
   180  	}
   181  	if len(h.DiversificationNonce) > 0 {
   182  		b.Write(h.DiversificationNonce)
   183  	}
   184  
   185  	switch h.PacketNumberLen {
   186  	case protocol.PacketNumberLen1:
   187  		b.WriteByte(uint8(h.PacketNumber))
   188  	case protocol.PacketNumberLen2:
   189  		utils.BigEndian.WriteUint16(b, uint16(h.PacketNumber))
   190  	case protocol.PacketNumberLen4:
   191  		utils.BigEndian.WriteUint32(b, uint32(h.PacketNumber))
   192  	case protocol.PacketNumberLen6:
   193  		return errInvalidPacketNumberLen
   194  	default:
   195  		return errors.New("PublicHeader: PacketNumberLen not set")
   196  	}
   197  
   198  	return nil
   199  }
   200  
   201  // GetLength determines the length of the Header.
   202  func (h *Header) GetLength(v protocol.VersionNumber) (protocol.ByteCount, error) {
   203  	if !v.UsesIETFHeaderFormat() {
   204  		return h.getPublicHeaderLength()
   205  	}
   206  	return h.getHeaderLength(v)
   207  }
   208  
   209  func (h *Header) getHeaderLength(v protocol.VersionNumber) (protocol.ByteCount, error) {
   210  	if h.IsLongHeader {
   211  		length := 1 /* type byte */ + 4 /* version */ + 1 /* conn id len byte */ + protocol.ByteCount(h.DestConnectionID.Len()+h.SrcConnectionID.Len()) + protocol.ByteCount(h.PacketNumberLen)
   212  		if v.UsesLengthInHeader() {
   213  			length += utils.VarIntLen(uint64(h.PayloadLen))
   214  		}
   215  		if h.Type == protocol.PacketTypeInitial && v.UsesTokenInHeader() {
   216  			length += utils.VarIntLen(uint64(len(h.Token))) + protocol.ByteCount(len(h.Token))
   217  		}
   218  		if h.Type == protocol.PacketType0RTT && v == protocol.Version44 {
   219  			length += protocol.ByteCount(len(h.DiversificationNonce))
   220  		}
   221  		return length, nil
   222  	}
   223  
   224  	length := protocol.ByteCount(1 /* type byte */ + h.DestConnectionID.Len())
   225  	if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
   226  		return 0, fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
   227  	}
   228  	length += protocol.ByteCount(h.PacketNumberLen)
   229  	return length, nil
   230  }
   231  
   232  // getPublicHeaderLength gets the length of the publicHeader in bytes.
   233  // It can only be called for regular packets.
   234  func (h *Header) getPublicHeaderLength() (protocol.ByteCount, error) {
   235  	length := protocol.ByteCount(1) // 1 byte for public flags
   236  	if h.PacketNumberLen == protocol.PacketNumberLen6 {
   237  		return 0, errInvalidPacketNumberLen
   238  	}
   239  	if h.PacketNumberLen != protocol.PacketNumberLen1 && h.PacketNumberLen != protocol.PacketNumberLen2 && h.PacketNumberLen != protocol.PacketNumberLen4 {
   240  		return 0, errPacketNumberLenNotSet
   241  	}
   242  	length += protocol.ByteCount(h.PacketNumberLen)
   243  	length += protocol.ByteCount(h.DestConnectionID.Len())
   244  	// Version Number in packets sent by the client
   245  	if h.VersionFlag {
   246  		length += 4
   247  	}
   248  	length += protocol.ByteCount(len(h.DiversificationNonce))
   249  	return length, nil
   250  }
   251  
   252  // Log logs the Header
   253  func (h *Header) Log(logger utils.Logger) {
   254  	if h.IsPublicHeader {
   255  		h.logPublicHeader(logger)
   256  	} else {
   257  		h.logHeader(logger)
   258  	}
   259  }
   260  
   261  func (h *Header) logHeader(logger utils.Logger) {
   262  	if h.IsLongHeader {
   263  		if h.Version == 0 {
   264  			logger.Debugf("\tVersionNegotiationPacket{DestConnectionID: %s, SrcConnectionID: %s, SupportedVersions: %s}", h.DestConnectionID, h.SrcConnectionID, h.SupportedVersions)
   265  		} else {
   266  			var token string
   267  			if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
   268  				if len(h.Token) == 0 {
   269  					token = "Token: (empty), "
   270  				} else {
   271  					token = fmt.Sprintf("Token: %#x, ", h.Token)
   272  				}
   273  			}
   274  			if h.Type == protocol.PacketTypeRetry {
   275  				logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sOrigDestConnectionID: %s, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.OrigDestConnectionID, h.Version)
   276  				return
   277  			}
   278  			if h.Version == protocol.Version44 {
   279  				var divNonce string
   280  				if h.Type == protocol.PacketType0RTT {
   281  					divNonce = fmt.Sprintf("Diversification Nonce: %#x, ", h.DiversificationNonce)
   282  				}
   283  				logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, h.PacketNumber, h.PacketNumberLen, divNonce, h.Version)
   284  				return
   285  			}
   286  			logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %#x, PacketNumberLen: %d, PayloadLen: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.PayloadLen, h.Version)
   287  		}
   288  	} else {
   289  		logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, KeyPhase: %d}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
   290  	}
   291  }
   292  
   293  func (h *Header) logPublicHeader(logger utils.Logger) {
   294  	ver := "(unset)"
   295  	if h.Version != 0 {
   296  		ver = h.Version.String()
   297  	}
   298  	logger.Debugf("\tPublic Header{ConnectionID: %s, PacketNumber: %#x, PacketNumberLen: %d, Version: %s, DiversificationNonce: %#v}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, ver, h.DiversificationNonce)
   299  }
   300  
   301  func encodeConnIDLen(dest, src protocol.ConnectionID) (byte, error) {
   302  	dcil, err := encodeSingleConnIDLen(dest)
   303  	if err != nil {
   304  		return 0, err
   305  	}
   306  	scil, err := encodeSingleConnIDLen(src)
   307  	if err != nil {
   308  		return 0, err
   309  	}
   310  	return scil | dcil<<4, nil
   311  }
   312  
   313  func encodeSingleConnIDLen(id protocol.ConnectionID) (byte, error) {
   314  	len := id.Len()
   315  	if len == 0 {
   316  		return 0, nil
   317  	}
   318  	if len < 4 || len > 18 {
   319  		return 0, fmt.Errorf("invalid connection ID length: %d bytes", len)
   320  	}
   321  	return byte(len - 3), nil
   322  }
   323  
   324  func decodeConnIDLen(enc byte) (int /*dest conn id len*/, int /*src conn id len*/) {
   325  	return decodeSingleConnIDLen(enc >> 4), decodeSingleConnIDLen(enc & 0xf)
   326  }
   327  
   328  func decodeSingleConnIDLen(enc uint8) int {
   329  	if enc == 0 {
   330  		return 0
   331  	}
   332  	return int(enc) + 3
   333  }