github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/state/tests/tests.go (about) 1 // Copyright 2018 The gVisor Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // Package tests tests the state packages. 16 package tests 17 18 import ( 19 "bytes" 20 "context" 21 "fmt" 22 "math" 23 "reflect" 24 "testing" 25 26 "github.com/SagerNet/gvisor/pkg/state" 27 "github.com/SagerNet/gvisor/pkg/state/pretty" 28 ) 29 30 // discard is an implementation of wire.Writer. 31 type discard struct{} 32 33 // Write implements wire.Writer.Write. 34 func (discard) Write(p []byte) (int, error) { return len(p), nil } 35 36 // WriteByte implements wire.Writer.WriteByte. 37 func (discard) WriteByte(byte) error { return nil } 38 39 // checkEqual checks if two objects are equal. 40 // 41 // N.B. This only handles one level of dereferences for NaN. Otherwise we 42 // would need to fork the entire implementation of reflect.DeepEqual. 43 func checkEqual(root, loadedValue interface{}) bool { 44 if reflect.DeepEqual(root, loadedValue) { 45 return true 46 } 47 48 // NaN is not equal to itself. We handle the case of raw floating point 49 // primitives here, but don't handle this case nested. 
50 rf32, ok1 := root.(float32) 51 lf32, ok2 := loadedValue.(float32) 52 if ok1 && ok2 && math.IsNaN(float64(rf32)) && math.IsNaN(float64(lf32)) { 53 return true 54 } 55 rf64, ok1 := root.(float64) 56 lf64, ok2 := loadedValue.(float64) 57 if ok1 && ok2 && math.IsNaN(rf64) && math.IsNaN(lf64) { 58 return true 59 } 60 61 // Same real for complex numbers. 62 rc64, ok1 := root.(complex64) 63 lc64, ok2 := root.(complex64) 64 if ok1 && ok2 { 65 return checkEqual(real(rc64), real(lc64)) && checkEqual(imag(rc64), imag(lc64)) 66 } 67 rc128, ok1 := root.(complex128) 68 lc128, ok2 := root.(complex128) 69 if ok1 && ok2 { 70 return checkEqual(real(rc128), real(lc128)) && checkEqual(imag(rc128), imag(lc128)) 71 } 72 73 return false 74 } 75 76 // runTestCases runs a test for each object in objects. 77 func runTestCases(t *testing.T, shouldFail bool, prefix string, objects []interface{}) { 78 t.Helper() 79 for i, root := range objects { 80 t.Run(fmt.Sprintf("%s%d", prefix, i), func(t *testing.T) { 81 t.Logf("Original object:\n%#v", root) 82 83 // Save the passed object. 84 saveBuffer := &bytes.Buffer{} 85 saveObjectPtr := reflect.New(reflect.TypeOf(root)) 86 saveObjectPtr.Elem().Set(reflect.ValueOf(root)) 87 saveStats, err := state.Save(context.Background(), saveBuffer, saveObjectPtr.Interface()) 88 if err != nil { 89 if shouldFail { 90 return 91 } 92 t.Fatalf("Save failed unexpectedly: %v", err) 93 } 94 95 // Dump the serialized proto to aid with debugging. 96 var ppBuf bytes.Buffer 97 t.Logf("Raw state:\n%v", saveBuffer.Bytes()) 98 if err := pretty.PrintText(&ppBuf, bytes.NewReader(saveBuffer.Bytes())); err != nil { 99 // We don't count this as a test failure if we 100 // have shouldFail set, but we will count as a 101 // failure if we were not expecting to fail. 102 if !shouldFail { 103 t.Errorf("PrettyPrint(html=false) failed unexpected: %v", err) 104 } 105 } 106 if err := pretty.PrintHTML(discard{}, bytes.NewReader(saveBuffer.Bytes())); err != nil { 107 // See above. 
108 if !shouldFail { 109 t.Errorf("PrettyPrint(html=true) failed unexpected: %v", err) 110 } 111 } 112 t.Logf("Encoded state:\n%s", ppBuf.String()) 113 t.Logf("Save stats:\n%s", saveStats.String()) 114 115 // Load a new copy of the object. 116 loadObjectPtr := reflect.New(reflect.TypeOf(root)) 117 loadStats, err := state.Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface()) 118 if err != nil { 119 if shouldFail { 120 return 121 } 122 t.Fatalf("Load failed unexpectedly: %v", err) 123 } 124 125 // Compare the values. 126 loadedValue := loadObjectPtr.Elem().Interface() 127 if !checkEqual(root, loadedValue) { 128 if shouldFail { 129 return 130 } 131 t.Fatalf("Objects differ:\n\toriginal: %#v\n\tloaded: %#v\n", root, loadedValue) 132 } 133 134 // Everything went okay. Is that good? 135 if shouldFail { 136 t.Fatalf("This test was expected to fail, but didn't.") 137 } 138 t.Logf("Load stats:\n%s", loadStats.String()) 139 140 // Truncate half the bytes in the byte stream, 141 // and ensure that we can't restore. Then 142 // truncate only the final byte and ensure that 143 // we can't restore. 144 l := saveBuffer.Len() 145 halfReader := bytes.NewReader(saveBuffer.Bytes()[:l/2]) 146 if _, err := state.Load(context.Background(), halfReader, loadObjectPtr.Interface()); err == nil { 147 t.Errorf("Load with half bytes succeeded unexpectedly.") 148 } 149 missingByteReader := bytes.NewReader(saveBuffer.Bytes()[:l-1]) 150 if _, err := state.Load(context.Background(), missingByteReader, loadObjectPtr.Interface()); err == nil { 151 t.Errorf("Load with missing byte succeeded unexpectedly.") 152 } 153 }) 154 } 155 } 156 157 // convert converts the slice to an []interface{}. 158 func convert(v interface{}) (r []interface{}) { 159 s := reflect.ValueOf(v) // Must be slice. 160 for i := 0; i < s.Len(); i++ { 161 r = append(r, s.Index(i).Interface()) 162 } 163 return r 164 } 165 166 // flatten flattens multiple slices. 
func flatten(vs ...interface{}) (r []interface{}) {
	for _, slice := range vs {
		// Each argument must be a slice; its elements are appended in order.
		sv := reflect.ValueOf(slice)
		for idx := 0; idx < sv.Len(); idx++ {
			r = append(r, sv.Index(idx).Interface())
		}
	}
	return r
}

// filter maps every element of the slice vs through fn and collects the
// mapped values for which fn reports true. vs must be a slice; the result is
// nil when nothing is kept.
func filter(vs interface{}, fn func(interface{}) (interface{}, bool)) (r []interface{}) {
	in := reflect.ValueOf(vs)
	count := in.Len()
	for idx := 0; idx < count; idx++ {
		if mapped, keep := fn(in.Index(idx).Interface()); keep {
			r = append(r, mapped)
		}
	}
	return r
}

// combine returns fn applied to every pair in the cross product of the
// slices v1 and v2. Pairs are visited with v1 as the outer loop and v2 as the
// inner loop, so the result has len(v1)*len(v2) entries in that order.
func combine(v1, v2 interface{}, fn func(_, _ interface{}) interface{}) (r []interface{}) {
	outer, inner := reflect.ValueOf(v1), reflect.ValueOf(v2)
	for i := 0; i < outer.Len(); i++ {
		for j := 0; j < inner.Len(); j++ {
			left := outer.Index(i).Interface()
			right := inner.Index(j).Interface()
			r = append(r, fn(left, right))
		}
	}
	return r
}

// pointersTo returns, for each element of the slice vs, a pointer to a
// freshly allocated copy of that element.
func pointersTo(vs interface{}) []interface{} {
	toPointer := func(o interface{}) (interface{}, bool) {
		ptr := reflect.New(reflect.TypeOf(o))
		ptr.Elem().Set(reflect.ValueOf(o))
		return ptr.Interface(), true
	}
	return filter(vs, toPointer)
}

// interfacesTo wraps each element of the slice vs in a one-element
// [1]interface{} array, so the element is stored behind an interface value.
func interfacesTo(vs interface{}) []interface{} {
	wrap := func(o interface{}) (interface{}, bool) {
		return [1]interface{}{o}, true
	}
	return filter(vs, wrap)
}