github.com/sagernet/quic-go@v0.43.1-beta.1/ech/packet_unpacker.go (about) 1 package quic 2 3 import ( 4 "bytes" 5 "fmt" 6 "time" 7 8 "github.com/sagernet/quic-go/internal/handshake_ech" 9 "github.com/sagernet/quic-go/internal/protocol" 10 "github.com/sagernet/quic-go/internal/qerr" 11 "github.com/sagernet/quic-go/internal/wire" 12 ) 13 14 type headerDecryptor interface { 15 DecryptHeader(sample []byte, firstByte *byte, pnBytes []byte) 16 } 17 18 type headerParseError struct { 19 err error 20 } 21 22 func (e *headerParseError) Unwrap() error { 23 return e.err 24 } 25 26 func (e *headerParseError) Error() string { 27 return e.err.Error() 28 } 29 30 type unpackedPacket struct { 31 hdr *wire.ExtendedHeader 32 encryptionLevel protocol.EncryptionLevel 33 data []byte 34 } 35 36 // The packetUnpacker unpacks QUIC packets. 37 type packetUnpacker struct { 38 cs handshake.CryptoSetup 39 40 shortHdrConnIDLen int 41 } 42 43 var _ unpacker = &packetUnpacker{} 44 45 func newPacketUnpacker(cs handshake.CryptoSetup, shortHdrConnIDLen int) *packetUnpacker { 46 return &packetUnpacker{ 47 cs: cs, 48 shortHdrConnIDLen: shortHdrConnIDLen, 49 } 50 } 51 52 // UnpackLongHeader unpacks a Long Header packet. 53 // If the reserved bits are invalid, the error is wire.ErrInvalidReservedBits. 54 // If any other error occurred when parsing the header, the error is of type headerParseError. 55 // If decrypting the payload fails for any reason, the error is the error returned by the AEAD. 56 func (u *packetUnpacker) UnpackLongHeader(hdr *wire.Header, rcvTime time.Time, data []byte, v protocol.Version) (*unpackedPacket, error) { 57 var encLevel protocol.EncryptionLevel 58 var extHdr *wire.ExtendedHeader 59 var decrypted []byte 60 //nolint:exhaustive // Retry packets can't be unpacked. 61 switch hdr.Type { 62 case protocol.PacketTypeInitial: 63 encLevel = protocol.EncryptionInitial 64 opener, err := u.cs.GetInitialOpener() 65 if err != nil { 66 return nil, err 67 } 68 extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v) 69 if err != nil { 70 return nil, err 71 } 72 case protocol.PacketTypeHandshake: 73 encLevel = protocol.EncryptionHandshake 74 opener, err := u.cs.GetHandshakeOpener() 75 if err != nil { 76 return nil, err 77 } 78 extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v) 79 if err != nil { 80 return nil, err 81 } 82 case protocol.PacketType0RTT: 83 encLevel = protocol.Encryption0RTT 84 opener, err := u.cs.Get0RTTOpener() 85 if err != nil { 86 return nil, err 87 } 88 extHdr, decrypted, err = u.unpackLongHeaderPacket(opener, hdr, data, v) 89 if err != nil { 90 return nil, err 91 } 92 default: 93 return nil, fmt.Errorf("unknown packet type: %s", hdr.Type) 94 } 95 96 if len(decrypted) == 0 { 97 return nil, &qerr.TransportError{ 98 ErrorCode: qerr.ProtocolViolation, 99 ErrorMessage: "empty packet", 100 } 101 } 102 103 return &unpackedPacket{ 104 hdr: extHdr, 105 encryptionLevel: encLevel, 106 data: decrypted, 107 }, nil 108 } 109 110 func (u *packetUnpacker) UnpackShortHeader(rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { 111 opener, err := u.cs.Get1RTTOpener() 112 if err != nil { 113 return 0, 0, 0, nil, err 114 } 115 pn, pnLen, kp, decrypted, err := u.unpackShortHeaderPacket(opener, rcvTime, data) 116 if err != nil { 117 return 0, 0, 0, nil, err 118 } 119 if len(decrypted) == 0 { 120 return 0, 0, 0, nil, &qerr.TransportError{ 121 ErrorCode: qerr.ProtocolViolation, 122 ErrorMessage: "empty packet", 123 } 124 } 125 return pn, pnLen, kp, decrypted, nil 126 } 127 128 func (u *packetUnpacker) unpackLongHeaderPacket(opener handshake.LongHeaderOpener, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, []byte, error) { 129 extHdr, parseErr := u.unpackLongHeader(opener, hdr, data, v) 130 // If the reserved bits are set incorrectly, we still need to continue unpacking. 131 // This avoids a timing side-channel, which otherwise might allow an attacker 132 // to gain information about the header encryption. 133 if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { 134 return nil, nil, parseErr 135 } 136 extHdrLen := extHdr.ParsedLen() 137 extHdr.PacketNumber = opener.DecodePacketNumber(extHdr.PacketNumber, extHdr.PacketNumberLen) 138 decrypted, err := opener.Open(data[extHdrLen:extHdrLen], data[extHdrLen:], extHdr.PacketNumber, data[:extHdrLen]) 139 if err != nil { 140 return nil, nil, err 141 } 142 if parseErr != nil { 143 return nil, nil, parseErr 144 } 145 return extHdr, decrypted, nil 146 } 147 148 func (u *packetUnpacker) unpackShortHeaderPacket(opener handshake.ShortHeaderOpener, rcvTime time.Time, data []byte) (protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, []byte, error) { 149 l, pn, pnLen, kp, parseErr := u.unpackShortHeader(opener, data) 150 // If the reserved bits are set incorrectly, we still need to continue unpacking. 151 // This avoids a timing side-channel, which otherwise might allow an attacker 152 // to gain information about the header encryption. 153 if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { 154 return 0, 0, 0, nil, &headerParseError{parseErr} 155 } 156 pn = opener.DecodePacketNumber(pn, pnLen) 157 decrypted, err := opener.Open(data[l:l], data[l:], rcvTime, pn, kp, data[:l]) 158 if err != nil { 159 return 0, 0, 0, nil, err 160 } 161 return pn, pnLen, kp, decrypted, parseErr 162 } 163 164 func (u *packetUnpacker) unpackShortHeader(hd headerDecryptor, data []byte) (int, protocol.PacketNumber, protocol.PacketNumberLen, protocol.KeyPhaseBit, error) { 165 hdrLen := 1 /* first header byte */ + u.shortHdrConnIDLen 166 if len(data) < hdrLen+4+16 { 167 return 0, 0, 0, 0, fmt.Errorf("packet too small, expected at least 20 bytes after the header, got %d", len(data)-hdrLen) 168 } 169 origPNBytes := make([]byte, 4) 170 copy(origPNBytes, data[hdrLen:hdrLen+4]) 171 // 2. decrypt the header, assuming a 4 byte packet number 172 hd.DecryptHeader( 173 data[hdrLen+4:hdrLen+4+16], 174 &data[0], 175 data[hdrLen:hdrLen+4], 176 ) 177 // 3. parse the header (and learn the actual length of the packet number) 178 l, pn, pnLen, kp, parseErr := wire.ParseShortHeader(data, u.shortHdrConnIDLen) 179 if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { 180 return l, pn, pnLen, kp, parseErr 181 } 182 // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier 183 if pnLen != protocol.PacketNumberLen4 { 184 copy(data[hdrLen+int(pnLen):hdrLen+4], origPNBytes[int(pnLen):]) 185 } 186 return l, pn, pnLen, kp, parseErr 187 } 188 189 // The error is either nil, a wire.ErrInvalidReservedBits or of type headerParseError. 190 func (u *packetUnpacker) unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) { 191 extHdr, err := unpackLongHeader(hd, hdr, data, v) 192 if err != nil && err != wire.ErrInvalidReservedBits { 193 return nil, &headerParseError{err: err} 194 } 195 return extHdr, err 196 } 197 198 func unpackLongHeader(hd headerDecryptor, hdr *wire.Header, data []byte, v protocol.Version) (*wire.ExtendedHeader, error) { 199 r := bytes.NewReader(data) 200 201 hdrLen := hdr.ParsedLen() 202 if protocol.ByteCount(len(data)) < hdrLen+4+16 { 203 //nolint:stylecheck 204 return nil, fmt.Errorf("Packet too small. Expected at least 20 bytes after the header, got %d", protocol.ByteCount(len(data))-hdrLen) 205 } 206 // The packet number can be up to 4 bytes long, but we won't know the length until we decrypt it. 207 // 1. save a copy of the 4 bytes 208 origPNBytes := make([]byte, 4) 209 copy(origPNBytes, data[hdrLen:hdrLen+4]) 210 // 2. decrypt the header, assuming a 4 byte packet number 211 hd.DecryptHeader( 212 data[hdrLen+4:hdrLen+4+16], 213 &data[0], 214 data[hdrLen:hdrLen+4], 215 ) 216 // 3. parse the header (and learn the actual length of the packet number) 217 extHdr, parseErr := hdr.ParseExtended(r, v) 218 if parseErr != nil && parseErr != wire.ErrInvalidReservedBits { 219 return nil, parseErr 220 } 221 // 4. if the packet number is shorter than 4 bytes, replace the remaining bytes with the copy we saved earlier 222 if extHdr.PacketNumberLen != protocol.PacketNumberLen4 { 223 copy(data[extHdr.ParsedLen():hdrLen+4], origPNBytes[int(extHdr.PacketNumberLen):]) 224 } 225 return extHdr, parseErr 226 }