github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/tcpip/testutil/testutil.go (about) 1 // Copyright 2021 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package testutil provides helper functions for netstack unit tests. 16 package testutil 17 18 import ( 19 "fmt" 20 "net" 21 "reflect" 22 "strings" 23 24 "github.com/SagerNet/gvisor/pkg/tcpip" 25 ) 26 27 // MustParse4 parses an IPv4 string (e.g. "192.168.1.1") into a tcpip.Address. 28 // Passing an IPv4-mapped IPv6 address will yield only the 4 IPv4 bytes. 29 func MustParse4(addr string) tcpip.Address { 30 ip := net.ParseIP(addr).To4() 31 if ip == nil { 32 panic(fmt.Sprintf("Parse4 expects IPv4 addresses, but was passed %q", addr)) 33 } 34 return tcpip.Address(ip) 35 } 36 37 // MustParse6 parses an IPv6 string (e.g. "fe80::1") into a tcpip.Address. Passing 38 // an IPv4 address will yield an IPv4-mapped IPv6 address. 39 func MustParse6(addr string) tcpip.Address { 40 ip := net.ParseIP(addr).To16() 41 if ip == nil { 42 panic(fmt.Sprintf("Parse6 was passed malformed address %q", addr)) 43 } 44 return tcpip.Address(ip) 45 } 46 47 func checkFieldCounts(ref, multi reflect.Value) error { 48 refTypeName := ref.Type().Name() 49 multiTypeName := multi.Type().Name() 50 refNumField := ref.NumField() 51 multiNumField := multi.NumField() 52 53 if refNumField != multiNumField { 54 return fmt.Errorf("type %s has an incorrect number of fields: got = %d, want = %d (same as type %s)", multiTypeName, multiNumField, refNumField, refTypeName) 55 } 56 57 return nil 58 } 59 60 func validateField(ref reflect.Value, refName string, m tcpip.MultiCounterStat, multiName string) error { 61 s, ok := ref.Addr().Interface().(**tcpip.StatCounter) 62 if !ok { 63 return fmt.Errorf("expected ref type's to be *StatCounter, but its type is %s", ref.Type().Elem().Name()) 64 } 65 66 // The field names are expected to match (case insensitive). 67 if !strings.EqualFold(refName, multiName) { 68 return fmt.Errorf("wrong field name: got = %s, want = %s", multiName, refName) 69 } 70 71 base := (*s).Value() 72 m.Increment() 73 if (*s).Value() != base+1 { 74 return fmt.Errorf("updates to the '%s MultiCounterStat' counters are not reflected in the '%s CounterStat'", multiName, refName) 75 } 76 77 return nil 78 } 79 80 // ValidateMultiCounterStats verifies that every counter stored in multi is 81 // correctly tracking its counterpart in the given counters. 82 func ValidateMultiCounterStats(multi reflect.Value, counters []reflect.Value) error { 83 for _, c := range counters { 84 if err := checkFieldCounts(c, multi); err != nil { 85 return err 86 } 87 } 88 89 for i := 0; i < multi.NumField(); i++ { 90 multiName := multi.Type().Field(i).Name 91 multiUnsafe := unsafeExposeUnexportedFields(multi.Field(i)) 92 93 if m, ok := multiUnsafe.Addr().Interface().(*tcpip.MultiCounterStat); ok { 94 for _, c := range counters { 95 if err := validateField(unsafeExposeUnexportedFields(c.Field(i)), c.Type().Field(i).Name, *m, multiName); err != nil { 96 return err 97 } 98 } 99 } else { 100 var countersNextField []reflect.Value 101 for _, c := range counters { 102 countersNextField = append(countersNextField, c.Field(i)) 103 } 104 if err := ValidateMultiCounterStats(multi.Field(i), countersNextField); err != nil { 105 return err 106 } 107 } 108 } 109 110 return nil 111 } 112 113 // MustParseLink parses a Link string into a tcpip.LinkAddress, panicking on 114 // error. 115 // 116 // The string must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff. 117 func MustParseLink(addr string) tcpip.LinkAddress { 118 parsed, err := tcpip.ParseMACAddress(addr) 119 if err != nil { 120 panic(fmt.Sprintf("tcpip.ParseMACAddress(%s): %s", addr, err)) 121 } 122 return parsed 123 }