golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/quic/packet_number_test.go (about)

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  //go:build go1.21
     6  
     7  package quic
     8  
     9  import (
    10  	"bytes"
    11  	"encoding/binary"
    12  	"testing"
    13  )
    14  
    15  func TestDecodePacketNumber(t *testing.T) {
    16  	for _, test := range []struct {
    17  		largest   packetNumber
    18  		truncated packetNumber
    19  		want      packetNumber
    20  		size      int
    21  	}{{
    22  		largest:   0,
    23  		truncated: 1,
    24  		size:      4,
    25  		want:      1,
    26  	}, {
    27  		largest:   0,
    28  		truncated: 0,
    29  		size:      1,
    30  		want:      0,
    31  	}, {
    32  		largest:   0x00,
    33  		truncated: 0x01,
    34  		size:      1,
    35  		want:      0x01,
    36  	}, {
    37  		largest:   0x00,
    38  		truncated: 0xff,
    39  		size:      1,
    40  		want:      0xff,
    41  	}, {
    42  		largest:   0xff,
    43  		truncated: 0x01,
    44  		size:      1,
    45  		want:      0x101,
    46  	}, {
    47  		largest:   0x1000,
    48  		truncated: 0xff,
    49  		size:      1,
    50  		want:      0xfff,
    51  	}, {
    52  		largest:   0xa82f30ea,
    53  		truncated: 0x9b32,
    54  		size:      2,
    55  		want:      0xa82f9b32,
    56  	}} {
    57  		got := decodePacketNumber(test.largest, test.truncated, test.size)
    58  		if got != test.want {
    59  			t.Errorf("decodePacketNumber(largest=0x%x, truncated=0x%x, size=%v) = 0x%x, want 0x%x", test.largest, test.truncated, test.size, got, test.want)
    60  		}
    61  	}
    62  }
    63  
    64  func TestEncodePacketNumber(t *testing.T) {
    65  	for _, test := range []struct {
    66  		largestAcked packetNumber
    67  		pnum         packetNumber
    68  		wantSize     int
    69  	}{{
    70  		largestAcked: -1,
    71  		pnum:         0,
    72  		wantSize:     1,
    73  	}, {
    74  		largestAcked: 1000,
    75  		pnum:         1000 + 0x7f,
    76  		wantSize:     1,
    77  	}, {
    78  		largestAcked: 1000,
    79  		pnum:         1000 + 0x80, // 0x468
    80  		wantSize:     2,
    81  	}, {
    82  		largestAcked: 0x12345678,
    83  		pnum:         0x12345678 + 0x7fff, // 0x305452663
    84  		wantSize:     2,
    85  	}, {
    86  		largestAcked: 0x12345678,
    87  		pnum:         0x12345678 + 0x8000,
    88  		wantSize:     3,
    89  	}, {
    90  		largestAcked: 0,
    91  		pnum:         0x7fffff,
    92  		wantSize:     3,
    93  	}, {
    94  		largestAcked: 0,
    95  		pnum:         0x800000,
    96  		wantSize:     4,
    97  	}, {
    98  		largestAcked: 0xabe8bc,
    99  		pnum:         0xac5c02,
   100  		wantSize:     2,
   101  	}, {
   102  		largestAcked: 0xabe8bc,
   103  		pnum:         0xace8fe,
   104  		wantSize:     3,
   105  	}} {
   106  		size := packetNumberLength(test.pnum, test.largestAcked)
   107  		if got, want := size, test.wantSize; got != want {
   108  			t.Errorf("packetNumberLength(num=%x, maxAck=%x) = %v, want %v", test.pnum, test.largestAcked, got, want)
   109  		}
   110  		var enc packetNumber
   111  		switch size {
   112  		case 1:
   113  			enc = test.pnum & 0xff
   114  		case 2:
   115  			enc = test.pnum & 0xffff
   116  		case 3:
   117  			enc = test.pnum & 0xffffff
   118  		case 4:
   119  			enc = test.pnum & 0xffffffff
   120  		}
   121  		wantBytes := binary.BigEndian.AppendUint32(nil, uint32(enc))[4-size:]
   122  		gotBytes := appendPacketNumber(nil, test.pnum, test.largestAcked)
   123  		if !bytes.Equal(gotBytes, wantBytes) {
   124  			t.Errorf("appendPacketNumber(num=%v, maxAck=%x) = {%x}, want {%x}", test.pnum, test.largestAcked, gotBytes, wantBytes)
   125  		}
   126  		gotNum := decodePacketNumber(test.largestAcked, enc, size)
   127  		if got, want := gotNum, test.pnum; got != want {
   128  			t.Errorf("packetNumberLength(num=%x, maxAck=%x) = %v, but decoded number=%x", test.pnum, test.largestAcked, size, got)
   129  		}
   130  	}
   131  }
   132  
   133  func FuzzPacketNumber(f *testing.F) {
   134  	truncatedNumber := func(in []byte) packetNumber {
   135  		var truncated packetNumber
   136  		for _, b := range in {
   137  			truncated = (truncated << 8) | packetNumber(b)
   138  		}
   139  		return truncated
   140  	}
   141  	f.Fuzz(func(t *testing.T, in []byte, largestAckedInt64 int64) {
   142  		largestAcked := packetNumber(largestAckedInt64)
   143  		if len(in) < 1 || len(in) > 4 || largestAcked < 0 || largestAcked > maxPacketNumber {
   144  			return
   145  		}
   146  		truncatedIn := truncatedNumber(in)
   147  		decoded := decodePacketNumber(largestAcked, truncatedIn, len(in))
   148  
   149  		// Check that the decoded packet number's least significant bits match the input.
   150  		var mask packetNumber
   151  		for i := 0; i < len(in); i++ {
   152  			mask = (mask << 8) | 0xff
   153  		}
   154  		if truncatedIn != decoded&mask {
   155  			t.Fatalf("decoding mismatch: input=%x largestAcked=%v decoded=0x%x", in, largestAcked, decoded)
   156  		}
   157  
   158  		// We don't support encoding packet numbers less than largestAcked (since packet numbers
   159  		// never decrease), so skip the encoder tests if this would make us go backwards.
   160  		if decoded < largestAcked {
   161  			return
   162  		}
   163  
   164  		// We might encode this number using a different length than we received,
   165  		// but the common portions should match.
   166  		encoded := appendPacketNumber(nil, decoded, largestAcked)
   167  		a, b := in, encoded
   168  		if len(b) < len(a) {
   169  			a, b = b, a
   170  		}
   171  		for len(a) < len(b) {
   172  			b = b[1:]
   173  		}
   174  		if len(a) == 0 || !bytes.Equal(a, b) {
   175  			t.Fatalf("encoding mismatch: input=%x largestAcked=%v decoded=%v reencoded=%x", in, largestAcked, decoded, encoded)
   176  		}
   177  
   178  		if g := decodePacketNumber(largestAcked, truncatedNumber(encoded), len(encoded)); g != decoded {
   179  			t.Fatalf("packet encode/decode mismatch: pnum=%v largestAcked=%v encoded=%x got=%v", decoded, largestAcked, encoded, g)
   180  		}
   181  		if l := packetNumberLength(decoded, largestAcked); l != len(encoded) {
   182  			t.Fatalf("packet number length mismatch: pnum=%v largestAcked=%v encoded=%x len=%v", decoded, largestAcked, encoded, l)
   183  		}
   184  	})
   185  }