gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/tcpip/testutil/testutil.go (about) 1 // Copyright 2021 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package testutil provides helper functions for netstack unit tests. 16 package testutil 17 18 import ( 19 "fmt" 20 "net" 21 "reflect" 22 "strconv" 23 "strings" 24 25 "gvisor.dev/gvisor/pkg/tcpip" 26 ) 27 28 // MustParse4 parses an IPv4 string (e.g. "192.168.1.1") into a tcpip.Address. 29 // Passing an IPv4-mapped IPv6 address will yield only the 4 IPv4 bytes. 30 func MustParse4(addr string) tcpip.Address { 31 ip := net.ParseIP(addr).To4() 32 if ip == nil { 33 panic(fmt.Sprintf("Parse4 expects IPv4 addresses, but was passed %q", addr)) 34 } 35 return tcpip.AddrFrom4Slice(ip) 36 } 37 38 // MustParse6 parses an IPv6 string (e.g. "fe80::1") into a tcpip.Address. Passing 39 // an IPv4 address will yield an IPv4-mapped IPv6 address. 40 func MustParse6(addr string) tcpip.Address { 41 ip := net.ParseIP(addr).To16() 42 if ip == nil { 43 panic(fmt.Sprintf("Parse6 was passed malformed address %q", addr)) 44 } 45 return tcpip.AddrFrom16Slice(ip) 46 } 47 48 // MustParseSubnet4 parses an IPv4 subnet string (e.g. "192.168.1.0/24") into a 49 // tcpip.Subnet. 50 func MustParseSubnet4(subnet string) tcpip.Subnet { 51 parts := strings.Split(subnet, "/") 52 if len(parts) != 2 { 53 panic(fmt.Sprintf("MustParseSubnet4 expected CIDR notation (<addr>/<prefixLen>), but got %q", subnet)) 54 } 55 addr := MustParse4(parts[0]) 56 prefixLen, err := strconv.Atoi(parts[1]) 57 if err != nil { 58 panic(fmt.Sprintf("Failed to parse prefix length %q: %v", parts[1], err)) 59 } 60 if prefixLen < 0 || prefixLen > 32 { 61 panic(fmt.Sprintf("Prefix length %d is invalid. It must be between 0 and 32", prefixLen)) 62 } 63 prefixed := tcpip.AddressWithPrefix{Address: addr, PrefixLen: prefixLen} 64 return prefixed.Subnet() 65 } 66 67 func checkFieldCounts(ref, multi reflect.Value) error { 68 refTypeName := ref.Type().Name() 69 multiTypeName := multi.Type().Name() 70 refNumField := ref.NumField() 71 multiNumField := multi.NumField() 72 73 if refNumField != multiNumField { 74 return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) 75 } 76 77 return nil 78 } 79 80 func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { 81 s, ok := ref.Addr().Interface().(**tcpip.StatCounter) 82 if !ok { 83 return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) 84 } 85 86 // The field names are expected to match (case insensitive). 87 if !strings.EqualFold(refName, multiName) { 88 return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) 89 } 90 91 base := (*s).Value() 92 m.Increment() 93 if (*s).Value() != base+1 { 94 return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) 95 } 96 97 return nil 98 } 99 100 func validateIntegralMapField(ref reflect.Value, refName string, m tcpip.MultiIntegralStatCounterMap, multiName string) error { 101 // The field names are expected to match (case insensitive). 102 if !strings.EqualFold(refName, multiName) { 103 return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) 104 } 105 s, ok := ref.Addr().Interface().(**tcpip.IntegralStatCounterMap) 106 if !ok { 107 return fmt.Errorf("field is not an IntegralStatCounterMap") 108 } 109 110 const key = 42 111 112 getValue := func() uint64 { 113 counter, ok := (*s).Get(key) 114 if !ok { 115 return 0 116 } 117 return counter.Value() 118 } 119 120 before := getValue() 121 122 m.Increment(key) 123 124 after := getValue() 125 126 if after != before+1 { 127 return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) 128 } 129 130 return nil 131 } 132 133 func validateMultiCounterStats(multi reflect.Value, counters []reflect.Value) (foundMultiCounterStat, foundMultiIntegralStatCounterMap bool, err error) { 134 for _, c := range counters { 135 if err := checkFieldCounts(c, multi); err != nil { 136 return false, false, err 137 } 138 } 139 140 for i := 0; i < multi.NumField(); i++ { 141 multiName := multi.Type().Field(i).Name 142 multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) 143 144 switch m := multiUnsafe.Addr().Interface().(type) { 145 case *tcpip.MultiCounterStat: 146 foundMultiCounterStat = true 147 for _, c := range counters { 148 if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { 149 return false, false, err 150 } 151 } 152 case *tcpip.MultiIntegralStatCounterMap: 153 foundMultiIntegralStatCounterMap = true 154 for _, c := range counters { 155 if err := validateIntegralMapField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { 156 return false, false, err 157 } 158 } 159 default: 160 var countersNextField []reflect.Value 161 for _, c := range counters { 162 countersNextField = append(countersNextField, c.Field(i)) 163 } 164 innerFoundMultiCounterStat, innerFoundMultiIntegralStatCounterMap, err := validateMultiCounterStats(multi.Field(i), countersNextField) 165 if err != nil { 166 return false, false, err 167 } 168 foundMultiCounterStat = foundMultiCounterStat || innerFoundMultiCounterStat 169 foundMultiIntegralStatCounterMap = foundMultiIntegralStatCounterMap || innerFoundMultiIntegralStatCounterMap 170 } 171 } 172 173 return foundMultiCounterStat, foundMultiIntegralStatCounterMap, nil 174 } 175 176 // ValidateMultiCounterStatsOptions holds options used when validating multi 177 // counter stat structs. 178 type ValidateMultiCounterStatsOptions struct { 179 ExpectMultiCounterStat bool 180 ExpectMultiIntegralStatCounterMap bool 181 } 182 183 // ValidateMultiCounterStats verifies that every counter stored in multi is 184 // correctly tracking its counterpart in the given counters. 185 func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value, options ValidateMultiCounterStatsOptions) error { 186 foundMultiCounterStat, foundMultiIntegralStatCounterMap, err := validateMultiCounterStats(multi, counters) 187 if err != nil { 188 return err 189 } 190 if foundMultiCounterStat != options.ExpectMultiCounterStat { 191 return fmt.Errorf("got %T presence: %t, want: %t", (*tcpip.MultiCounterStat)(nil), foundMultiCounterStat, options.ExpectMultiCounterStat) 192 } 193 if foundMultiIntegralStatCounterMap != options.ExpectMultiIntegralStatCounterMap { 194 return fmt.Errorf("got %T presence: %t, want: %t", (*tcpip.MultiIntegralStatCounterMap)(nil), foundMultiIntegralStatCounterMap, options.ExpectMultiIntegralStatCounterMap) 195 } 196 197 return nil 198 } 199 200 // MustParseLink parses a Link string into a tcpip.LinkAddress, panicking on 201 // error. 202 // 203 // The string must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff. 204 func MustParseLink(addr string) tcpip.LinkAddress { 205 parsed, err := tcpip.ParseMACAddress(addr) 206 if err != nil { 207 panic(fmt.Sprintf("tcpip.ParseMACAddress(%s): %s", addr, err)) 208 } 209 return parsed 210 }