gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/state/tests/tests.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package tests tests the state packages.
    16  package tests
    17  
    18  import (
    19  	"bytes"
    20  	"context"
    21  	"fmt"
    22  	"math"
    23  	"reflect"
    24  	"testing"
    25  
    26  	"gvisor.dev/gvisor/pkg/state"
    27  	"gvisor.dev/gvisor/pkg/state/pretty"
    28  )
    29  
    30  // discard is an implementation of wire.Writer.
    31  type discard struct{}
    32  
    33  // Write implements wire.Writer.Write.
    34  func (discard) Write(p []byte) (int, error) { return len(p), nil }
    35  
    36  // WriteByte implements wire.Writer.WriteByte.
    37  func (discard) WriteByte(byte) error { return nil }
    38  
    39  // checkEqual checks if two objects are equal.
    40  //
    41  // N.B. This only handles one level of dereferences for NaN. Otherwise we
    42  // would need to fork the entire implementation of reflect.DeepEqual.
    43  func checkEqual(root, loadedValue any) bool {
    44  	if reflect.DeepEqual(root, loadedValue) {
    45  		return true
    46  	}
    47  
    48  	// NaN is not equal to itself. We handle the case of raw floating point
    49  	// primitives here, but don't handle this case nested.
    50  	rf32, ok1 := root.(float32)
    51  	lf32, ok2 := loadedValue.(float32)
    52  	if ok1 && ok2 && math.IsNaN(float64(rf32)) && math.IsNaN(float64(lf32)) {
    53  		return true
    54  	}
    55  	rf64, ok1 := root.(float64)
    56  	lf64, ok2 := loadedValue.(float64)
    57  	if ok1 && ok2 && math.IsNaN(rf64) && math.IsNaN(lf64) {
    58  		return true
    59  	}
    60  
    61  	// Same real for complex numbers.
    62  	rc64, ok1 := root.(complex64)
    63  	lc64, ok2 := root.(complex64)
    64  	if ok1 && ok2 {
    65  		return checkEqual(real(rc64), real(lc64)) && checkEqual(imag(rc64), imag(lc64))
    66  	}
    67  	rc128, ok1 := root.(complex128)
    68  	lc128, ok2 := root.(complex128)
    69  	if ok1 && ok2 {
    70  		return checkEqual(real(rc128), real(lc128)) && checkEqual(imag(rc128), imag(lc128))
    71  	}
    72  
    73  	return false
    74  }
    75  
    76  // runTestCases runs a test for each object in objects.
    77  func runTestCases(t *testing.T, shouldFail bool, prefix string, objects []any) {
    78  	t.Helper()
    79  	for i, root := range objects {
    80  		t.Run(fmt.Sprintf("%s%d", prefix, i), func(t *testing.T) {
    81  			t.Logf("Original object:\n%#v", root)
    82  
    83  			// Save the passed object.
    84  			saveBuffer := &bytes.Buffer{}
    85  			saveObjectPtr := reflect.New(reflect.TypeOf(root))
    86  			saveObjectPtr.Elem().Set(reflect.ValueOf(root))
    87  			saveStats, err := state.Save(context.Background(), saveBuffer, saveObjectPtr.Interface())
    88  			if err != nil {
    89  				if shouldFail {
    90  					return
    91  				}
    92  				t.Fatalf("Save failed unexpectedly: %v", err)
    93  			}
    94  
    95  			// Dump the serialized proto to aid with debugging.
    96  			var ppBuf bytes.Buffer
    97  			t.Logf("Raw state:\n%v", saveBuffer.Bytes())
    98  			if err := pretty.PrintText(&ppBuf, bytes.NewReader(saveBuffer.Bytes())); err != nil {
    99  				// We don't count this as a test failure if we
   100  				// have shouldFail set, but we will count as a
   101  				// failure if we were not expecting to fail.
   102  				if !shouldFail {
   103  					t.Errorf("PrettyPrint(html=false) failed unexpected: %v", err)
   104  				}
   105  			}
   106  			if err := pretty.PrintHTML(discard{}, bytes.NewReader(saveBuffer.Bytes())); err != nil {
   107  				// See above.
   108  				if !shouldFail {
   109  					t.Errorf("PrettyPrint(html=true) failed unexpected: %v", err)
   110  				}
   111  			}
   112  			t.Logf("Encoded state:\n%s", ppBuf.String())
   113  			t.Logf("Save stats:\n%s", saveStats.String())
   114  
   115  			// Load a new copy of the object.
   116  			loadObjectPtr := reflect.New(reflect.TypeOf(root))
   117  			loadStats, err := state.Load(context.Background(), bytes.NewReader(saveBuffer.Bytes()), loadObjectPtr.Interface())
   118  			if err != nil {
   119  				if shouldFail {
   120  					return
   121  				}
   122  				t.Fatalf("Load failed unexpectedly: %v", err)
   123  			}
   124  
   125  			// Compare the values.
   126  			loadedValue := loadObjectPtr.Elem().Interface()
   127  			if !checkEqual(root, loadedValue) {
   128  				if shouldFail {
   129  					return
   130  				}
   131  				t.Fatalf("Objects differ:\n\toriginal: %#v\n\tloaded:   %#v\n", root, loadedValue)
   132  			}
   133  
   134  			// Everything went okay. Is that good?
   135  			if shouldFail {
   136  				t.Fatalf("This test was expected to fail, but didn't.")
   137  			}
   138  			t.Logf("Load stats:\n%s", loadStats.String())
   139  
   140  			// Truncate half the bytes in the byte stream,
   141  			// and ensure that we can't restore. Then
   142  			// truncate only the final byte and ensure that
   143  			// we can't restore.
   144  			l := saveBuffer.Len()
   145  			halfReader := bytes.NewReader(saveBuffer.Bytes()[:l/2])
   146  			if _, err := state.Load(context.Background(), halfReader, loadObjectPtr.Interface()); err == nil {
   147  				t.Errorf("Load with half bytes succeeded unexpectedly.")
   148  			}
   149  			missingByteReader := bytes.NewReader(saveBuffer.Bytes()[:l-1])
   150  			if _, err := state.Load(context.Background(), missingByteReader, loadObjectPtr.Interface()); err == nil {
   151  				t.Errorf("Load with missing byte succeeded unexpectedly.")
   152  			}
   153  		})
   154  	}
   155  }
   156  
   157  // convert converts the slice to an []any.
   158  func convert(v any) (r []any) {
   159  	s := reflect.ValueOf(v) // Must be slice.
   160  	for i := 0; i < s.Len(); i++ {
   161  		r = append(r, s.Index(i).Interface())
   162  	}
   163  	return r
   164  }
   165  
   166  // flatten flattens multiple slices.
   167  func flatten(vs ...any) (r []any) {
   168  	for _, v := range vs {
   169  		r = append(r, convert(v)...)
   170  	}
   171  	return r
   172  }
   173  
   174  // filter maps from one slice to another.
   175  func filter(vs any, fn func(any) (any, bool)) (r []any) {
   176  	s := reflect.ValueOf(vs)
   177  	for i := 0; i < s.Len(); i++ {
   178  		v, ok := fn(s.Index(i).Interface())
   179  		if ok {
   180  			r = append(r, v)
   181  		}
   182  	}
   183  	return r
   184  }
   185  
   186  // combine combines objects in two slices as specified.
   187  func combine(v1, v2 any, fn func(_, _ any) any) (r []any) {
   188  	s1 := reflect.ValueOf(v1)
   189  	s2 := reflect.ValueOf(v2)
   190  	for i := 0; i < s1.Len(); i++ {
   191  		for j := 0; j < s2.Len(); j++ {
   192  			// Combine using the given function.
   193  			r = append(r, fn(s1.Index(i).Interface(), s2.Index(j).Interface()))
   194  		}
   195  	}
   196  	return r
   197  }
   198  
   199  // pointersTo is a filter function that returns pointers.
   200  func pointersTo(vs any) []any {
   201  	return filter(vs, func(o any) (any, bool) {
   202  		v := reflect.New(reflect.TypeOf(o))
   203  		v.Elem().Set(reflect.ValueOf(o))
   204  		return v.Interface(), true
   205  	})
   206  }
   207  
   208  // interfacesTo is a filter function that returns interface objects.
   209  func interfacesTo(vs any) []any {
   210  	return filter(vs, func(o any) (any, bool) {
   211  		var v [1]any
   212  		v[0] = o
   213  		return v, true
   214  	})
   215  }