golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/quic/packet_test.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build go1.21
     6  
     7  package quic
     8  
     9  import (
    10  	"bytes"
    11  	"encoding/binary"
    12  	"encoding/hex"
    13  	"reflect"
    14  	"strings"
    15  	"testing"
    16  )
    17  
    18  func TestPacketHeader(t *testing.T) {
    19  	for _, test := range []struct {
    20  		name         string
    21  		packet       []byte
    22  		isLongHeader bool
    23  		packetType   packetType
    24  		dstConnID    []byte
    25  	}{{
    26  		// Initial packet from https://www.rfc-editor.org/rfc/rfc9001#section-a.1
    27  		// (truncated)
    28  		name: "rfc9001_a1",
    29  		packet: unhex(`
    30  			c000000001088394c8f03e5157080000 449e7b9aec34d1b1c98dd7689fb8ec11
    31  		`),
    32  		isLongHeader: true,
    33  		packetType:   packetTypeInitial,
    34  		dstConnID:    unhex(`8394c8f03e515708`),
    35  	}, {
    36  		// Initial packet from https://www.rfc-editor.org/rfc/rfc9001#section-a.3
    37  		// (truncated)
    38  		name: "rfc9001_a3",
    39  		packet: unhex(`
    40  			cf000000010008f067a5502a4262b500 4075c0d95a482cd0991cd25b0aac406a
    41  		`),
    42  		isLongHeader: true,
    43  		packetType:   packetTypeInitial,
    44  		dstConnID:    []byte{},
    45  	}, {
    46  		// Retry packet from https://www.rfc-editor.org/rfc/rfc9001#section-a.4
    47  		name: "rfc9001_a4",
    48  		packet: unhex(`
    49  			ff000000010008f067a5502a4262b574 6f6b656e04a265ba2eff4d829058fb3f
    50  			0f2496ba
    51  		`),
    52  		isLongHeader: true,
    53  		packetType:   packetTypeRetry,
    54  		dstConnID:    []byte{},
    55  	}, {
    56  		// Short header packet from https://www.rfc-editor.org/rfc/rfc9001#section-a.5
    57  		name: "rfc9001_a5",
    58  		packet: unhex(`
    59  			4cfe4189655e5cd55c41f69080575d7999c25a5bfb
    60  		`),
    61  		isLongHeader: false,
    62  		packetType:   packetType1RTT,
    63  		dstConnID:    unhex(`fe4189655e5cd55c`),
    64  	}, {
    65  		// Version Negotiation packet.
    66  		name: "version_negotiation",
    67  		packet: unhex(`
    68  			80 00000000 01ff0001020304
    69  		`),
    70  		isLongHeader: true,
    71  		packetType:   packetTypeVersionNegotiation,
    72  		dstConnID:    []byte{0xff},
    73  	}, {
    74  		// Too-short packet.
    75  		name: "truncated_after_connid_length",
    76  		packet: unhex(`
    77  			cf0000000105
    78  		`),
    79  		isLongHeader: true,
    80  		packetType:   packetTypeInitial,
    81  		dstConnID:    nil,
    82  	}, {
    83  		// Too-short packet.
    84  		name: "truncated_after_version",
    85  		packet: unhex(`
    86  			cf00000001
    87  		`),
    88  		isLongHeader: true,
    89  		packetType:   packetTypeInitial,
    90  		dstConnID:    nil,
    91  	}, {
    92  		// Much too short packet.
    93  		name: "truncated_in_version",
    94  		packet: unhex(`
    95  			cf000000
    96  		`),
    97  		isLongHeader: true,
    98  		packetType:   packetTypeInvalid,
    99  		dstConnID:    nil,
   100  	}} {
   101  		t.Run(test.name, func(t *testing.T) {
   102  			if got, want := isLongHeader(test.packet[0]), test.isLongHeader; got != want {
   103  				t.Errorf("packet %x:\nisLongHeader(packet) = %v, want %v", test.packet, got, want)
   104  			}
   105  			if got, want := getPacketType(test.packet), test.packetType; got != want {
   106  				t.Errorf("packet %x:\ngetPacketType(packet) = %v, want %v", test.packet, got, want)
   107  			}
   108  			gotConnID, gotOK := dstConnIDForDatagram(test.packet)
   109  			wantConnID, wantOK := test.dstConnID, test.dstConnID != nil
   110  			if !bytes.Equal(gotConnID, wantConnID) || gotOK != wantOK {
   111  				t.Errorf("packet %x:\ndstConnIDForDatagram(packet) = {%x}, %v; want {%x}, %v", test.packet, gotConnID, gotOK, wantConnID, wantOK)
   112  			}
   113  		})
   114  	}
   115  }
   116  
   117  func TestEncodeDecodeVersionNegotiation(t *testing.T) {
   118  	dstConnID := []byte("this is a very long destination connection id")
   119  	srcConnID := []byte("this is a very long source connection id")
   120  	versions := []uint32{1, 0xffffffff}
   121  	got := appendVersionNegotiation([]byte{}, dstConnID, srcConnID, versions...)
   122  	want := bytes.Join([][]byte{{
   123  		0b1100_0000, // header byte
   124  		0, 0, 0, 0,  // Version
   125  		byte(len(dstConnID)),
   126  	}, dstConnID, {
   127  		byte(len(srcConnID)),
   128  	}, srcConnID, {
   129  		0x00, 0x00, 0x00, 0x01,
   130  		0xff, 0xff, 0xff, 0xff,
   131  	}}, nil)
   132  	if !bytes.Equal(got, want) {
   133  		t.Fatalf("appendVersionNegotiation(nil, %x, %x, %v):\ngot  %x\nwant %x",
   134  			dstConnID, srcConnID, versions, got, want)
   135  	}
   136  	gotDst, gotSrc, gotVersionBytes := parseVersionNegotiation(got)
   137  	if got, want := gotDst, dstConnID; !bytes.Equal(got, want) {
   138  		t.Errorf("parseVersionNegotiation: got dstConnID = %x, want %x", got, want)
   139  	}
   140  	if got, want := gotSrc, srcConnID; !bytes.Equal(got, want) {
   141  		t.Errorf("parseVersionNegotiation: got srcConnID = %x, want %x", got, want)
   142  	}
   143  	var gotVersions []uint32
   144  	for len(gotVersionBytes) >= 4 {
   145  		gotVersions = append(gotVersions, binary.BigEndian.Uint32(gotVersionBytes))
   146  		gotVersionBytes = gotVersionBytes[4:]
   147  	}
   148  	if got, want := gotVersions, versions; !reflect.DeepEqual(got, want) {
   149  		t.Errorf("parseVersionNegotiation: got versions = %v, want %v", got, want)
   150  	}
   151  }
   152  
   153  func TestParseGenericLongHeaderPacket(t *testing.T) {
   154  	for _, test := range []struct {
   155  		name      string
   156  		packet    []byte
   157  		version   uint32
   158  		dstConnID []byte
   159  		srcConnID []byte
   160  		data      []byte
   161  	}{{
   162  		name: "long header packet",
   163  		packet: unhex(`
   164  			80 01020304 04a1a2a3a4 05b1b2b3b4b5 c1
   165  		`),
   166  		version:   0x01020304,
   167  		dstConnID: unhex(`a1a2a3a4`),
   168  		srcConnID: unhex(`b1b2b3b4b5`),
   169  		data:      unhex(`c1`),
   170  	}, {
   171  		name: "zero everything",
   172  		packet: unhex(`
   173  			80 00000000 00 00
   174  		`),
   175  		version:   0,
   176  		dstConnID: []byte{},
   177  		srcConnID: []byte{},
   178  		data:      []byte{},
   179  	}} {
   180  		t.Run(test.name, func(t *testing.T) {
   181  			p, ok := parseGenericLongHeaderPacket(test.packet)
   182  			if !ok {
   183  				t.Fatalf("parseGenericLongHeaderPacket() = _, false; want true")
   184  			}
   185  			if got, want := p.version, test.version; got != want {
   186  				t.Errorf("version = %v, want %v", got, want)
   187  			}
   188  			if got, want := p.dstConnID, test.dstConnID; !bytes.Equal(got, want) {
   189  				t.Errorf("Destination Connection ID = {%x}, want {%x}", got, want)
   190  			}
   191  			if got, want := p.srcConnID, test.srcConnID; !bytes.Equal(got, want) {
   192  				t.Errorf("Source Connection ID = {%x}, want {%x}", got, want)
   193  			}
   194  			if got, want := p.data, test.data; !bytes.Equal(got, want) {
   195  				t.Errorf("Data = {%x}, want {%x}", got, want)
   196  			}
   197  		})
   198  	}
   199  }
   200  
   201  func TestParseGenericLongHeaderPacketErrors(t *testing.T) {
   202  	for _, test := range []struct {
   203  		name   string
   204  		packet []byte
   205  	}{{
   206  		name: "short header packet",
   207  		packet: unhex(`
   208  			00 01020304 04a1a2a3a4 05b1b2b3b4b5 c1
   209  		`),
   210  	}, {
   211  		name: "packet too short",
   212  		packet: unhex(`
   213  			80 000000
   214  		`),
   215  	}, {
   216  		name: "destination id too long",
   217  		packet: unhex(`
   218  			80 00000000 02 00
   219  		`),
   220  	}, {
   221  		name: "source id too long",
   222  		packet: unhex(`
   223  			80 00000000 00 01
   224  		`),
   225  	}} {
   226  		t.Run(test.name, func(t *testing.T) {
   227  			_, ok := parseGenericLongHeaderPacket(test.packet)
   228  			if ok {
   229  				t.Fatalf("parseGenericLongHeaderPacket() = _, true; want false")
   230  			}
   231  		})
   232  	}
   233  }
   234  
   235  func unhex(s string) []byte {
   236  	b, err := hex.DecodeString(strings.Map(func(c rune) rune {
   237  		switch c {
   238  		case ' ', '\t', '\n':
   239  			return -1
   240  		}
   241  		return c
   242  	}, s))
   243  	if err != nil {
   244  		panic(err)
   245  	}
   246  	return b
   247  }