github.com/SagerNet/gvisor@v0.0.0-20210707092255-7731c139d75c/pkg/state/statefile/statefile_test.go (about)

     1  // Copyright 2018 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package statefile
    16  
    17  import (
    18  	"bytes"
    19  	crand "crypto/rand"
    20  	"encoding/base64"
    21  	"io"
    22  	"math/rand"
    23  	"runtime"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/SagerNet/gvisor/pkg/compressio"
    28  )
    29  
    30  func randomKey() ([]byte, error) {
    31  	r := make([]byte, base64.RawStdEncoding.DecodedLen(keySize))
    32  	if _, err := io.ReadFull(crand.Reader, r); err != nil {
    33  		return nil, err
    34  	}
    35  	key := make([]byte, keySize)
    36  	base64.RawStdEncoding.Encode(key, r)
    37  	return key, nil
    38  }
    39  
// testCase describes a single statefile round-trip scenario.
type testCase struct {
	name     string            // subtest name passed to t.Run.
	data     []byte            // payload written through the statefile Writer.
	metadata map[string]string // metadata to attach; checked after decode.
}
    45  
    46  func TestStatefile(t *testing.T) {
    47  	rand.Seed(time.Now().Unix())
    48  
    49  	cases := []testCase{
    50  		// Various data sizes.
    51  		{"nil", nil, nil},
    52  		{"empty", []byte(""), nil},
    53  		{"some", []byte("_"), nil},
    54  		{"one", []byte("0"), nil},
    55  		{"two", []byte("01"), nil},
    56  		{"three", []byte("012"), nil},
    57  		{"four", []byte("0123"), nil},
    58  		{"five", []byte("01234"), nil},
    59  		{"six", []byte("012356"), nil},
    60  		{"seven", []byte("0123567"), nil},
    61  		{"eight", []byte("01235678"), nil},
    62  
    63  		// Make sure we have one longer than the hash length.
    64  		{"longer than hash", []byte("012356asdjflkasjlk3jlk23j4lkjaso0d789f0aujw3lkjlkxsdf78asdful2kj3ljka78"), nil},
    65  
    66  		// Make sure we have one longer than the chunk size.
    67  		{"chunks", make([]byte, 3*compressionChunkSize), nil},
    68  		{"large", make([]byte, 30*compressionChunkSize), nil},
    69  
    70  		// Different metadata.
    71  		{"one metadata", []byte("data"), map[string]string{"foo": "bar"}},
    72  		{"two metadata", []byte("data"), map[string]string{"foo": "bar", "one": "two"}},
    73  	}
    74  
    75  	for _, c := range cases {
    76  		// Generate a key.
    77  		integrityKey, err := randomKey()
    78  		if err != nil {
    79  			t.Errorf("can't generate key: got %v, excepted nil", err)
    80  			continue
    81  		}
    82  
    83  		t.Run(c.name, func(t *testing.T) {
    84  			for _, key := range [][]byte{nil, integrityKey} {
    85  				t.Run("key="+string(key), func(t *testing.T) {
    86  					// Encoding happens via a buffer.
    87  					var bufEncoded bytes.Buffer
    88  					var bufDecoded bytes.Buffer
    89  
    90  					// Do all the writing.
    91  					w, err := NewWriter(&bufEncoded, key, c.metadata)
    92  					if err != nil {
    93  						t.Fatalf("error creating writer: got %v, expected nil", err)
    94  					}
    95  					if _, err := io.Copy(w, bytes.NewBuffer(c.data)); err != nil {
    96  						t.Fatalf("error during write: got %v, expected nil", err)
    97  					}
    98  
    99  					// Finish the sum.
   100  					if err := w.Close(); err != nil {
   101  						t.Fatalf("error during close: got %v, expected nil", err)
   102  					}
   103  
   104  					t.Logf("original data: %d bytes, encoded: %d bytes.",
   105  						len(c.data), len(bufEncoded.Bytes()))
   106  
   107  					// Do all the reading.
   108  					r, metadata, err := NewReader(bytes.NewReader(bufEncoded.Bytes()), key)
   109  					if err != nil {
   110  						t.Fatalf("error creating reader: got %v, expected nil", err)
   111  					}
   112  					if _, err := io.Copy(&bufDecoded, r); err != nil {
   113  						t.Fatalf("error during read: got %v, expected nil", err)
   114  					}
   115  
   116  					// Check that the data matches.
   117  					if !bytes.Equal(c.data, bufDecoded.Bytes()) {
   118  						t.Fatalf("data didn't match (%d vs %d bytes)", len(bufDecoded.Bytes()), len(c.data))
   119  					}
   120  
   121  					// Check that the metadata matches.
   122  					for k, v := range c.metadata {
   123  						nv, ok := metadata[k]
   124  						if !ok {
   125  							t.Fatalf("missing metadata: %s", k)
   126  						}
   127  						if v != nv {
   128  							t.Fatalf("mismatched metdata for %s: got %s, expected %s", k, nv, v)
   129  						}
   130  					}
   131  
   132  					// Change the data and verify that it fails.
   133  					if key != nil {
   134  						b := append([]byte(nil), bufEncoded.Bytes()...)
   135  						b[rand.Intn(len(b))]++
   136  						bufDecoded.Reset()
   137  						r, _, err = NewReader(bytes.NewReader(b), key)
   138  						if err == nil {
   139  							_, err = io.Copy(&bufDecoded, r)
   140  						}
   141  						if err == nil {
   142  							t.Error("got no error: expected error on data corruption")
   143  						}
   144  					}
   145  
   146  					// Change the key and verify that it fails.
   147  					newKey := integrityKey
   148  					if len(key) > 0 {
   149  						newKey = append([]byte{}, key...)
   150  						newKey[rand.Intn(len(newKey))]++
   151  					}
   152  					bufDecoded.Reset()
   153  					r, _, err = NewReader(bytes.NewReader(bufEncoded.Bytes()), newKey)
   154  					if err == nil {
   155  						_, err = io.Copy(&bufDecoded, r)
   156  					}
   157  					if err != compressio.ErrHashMismatch {
   158  						t.Errorf("got error: %v, expected ErrHashMismatch on key mismatch", err)
   159  					}
   160  				})
   161  			}
   162  		})
   163  	}
   164  }
   165  
// benchmarkDataSize is the total payload (100 MiB) pushed through each
// benchmark iteration; it is also reported via b.SetBytes for MB/s output.
const benchmarkDataSize = 100 * 1024 * 1024
   167  
   168  func benchmark(b *testing.B, size int, write bool, compressible bool) {
   169  	b.StopTimer()
   170  	b.SetBytes(benchmarkDataSize)
   171  
   172  	// Generate source data.
   173  	var source []byte
   174  	if compressible {
   175  		// For compressible data, we use essentially all zeros.
   176  		source = make([]byte, benchmarkDataSize)
   177  	} else {
   178  		// For non-compressible data, we use random base64 data (to
   179  		// make it marginally compressible, a ratio of 75%).
   180  		var sourceBuf bytes.Buffer
   181  		bufW := base64.NewEncoder(base64.RawStdEncoding, &sourceBuf)
   182  		bufR := rand.New(rand.NewSource(0))
   183  		if _, err := io.CopyN(bufW, bufR, benchmarkDataSize); err != nil {
   184  			b.Fatalf("unable to seed random data: %v", err)
   185  		}
   186  		source = sourceBuf.Bytes()
   187  	}
   188  
   189  	// Generate a random key for integrity check.
   190  	key, err := randomKey()
   191  	if err != nil {
   192  		b.Fatalf("error generating key: %v", err)
   193  	}
   194  
   195  	// Define our benchmark functions. Prior to running the readState
   196  	// function here, you must execute the writeState function at least
   197  	// once (done below).
   198  	var stateBuf bytes.Buffer
   199  	writeState := func() {
   200  		stateBuf.Reset()
   201  		w, err := NewWriter(&stateBuf, key, nil)
   202  		if err != nil {
   203  			b.Fatalf("error creating writer: %v", err)
   204  		}
   205  		for done := 0; done < len(source); {
   206  			chunk := size // limit size.
   207  			if done+chunk > len(source) {
   208  				chunk = len(source) - done
   209  			}
   210  			n, err := w.Write(source[done : done+chunk])
   211  			done += n
   212  			if n == 0 && err != nil {
   213  				b.Fatalf("error during write: %v", err)
   214  			}
   215  		}
   216  		if err := w.Close(); err != nil {
   217  			b.Fatalf("error closing writer: %v", err)
   218  		}
   219  	}
   220  	readState := func() {
   221  		tmpBuf := bytes.NewBuffer(stateBuf.Bytes())
   222  		r, _, err := NewReader(tmpBuf, key)
   223  		if err != nil {
   224  			b.Fatalf("error creating reader: %v", err)
   225  		}
   226  		for done := 0; done < len(source); {
   227  			chunk := size // limit size.
   228  			if done+chunk > len(source) {
   229  				chunk = len(source) - done
   230  			}
   231  			n, err := r.Read(source[done : done+chunk])
   232  			done += n
   233  			if n == 0 && err != nil {
   234  				b.Fatalf("error during read: %v", err)
   235  			}
   236  		}
   237  	}
   238  	// Generate the state once without timing to ensure that buffers have
   239  	// been appropriately allocated.
   240  	writeState()
   241  	if write {
   242  		b.StartTimer()
   243  		for i := 0; i < b.N; i++ {
   244  			writeState()
   245  		}
   246  		b.StopTimer()
   247  	} else {
   248  		b.StartTimer()
   249  		for i := 0; i < b.N; i++ {
   250  			readState()
   251  		}
   252  		b.StopTimer()
   253  	}
   254  }
   255  
// The benchmarks below cover the write/read × chunk-size (4 KiB / 1 MiB) ×
// compressibility matrix by delegating to benchmark.

func BenchmarkWrite4KCompressible(b *testing.B) {
	benchmark(b, 4096, true, true)
}

func BenchmarkWrite4KNoncompressible(b *testing.B) {
	benchmark(b, 4096, true, false)
}

func BenchmarkWrite1MCompressible(b *testing.B) {
	benchmark(b, 1024*1024, true, true)
}

func BenchmarkWrite1MNoncompressible(b *testing.B) {
	benchmark(b, 1024*1024, true, false)
}

func BenchmarkRead4KCompressible(b *testing.B) {
	benchmark(b, 4096, false, true)
}

func BenchmarkRead4KNoncompressible(b *testing.B) {
	benchmark(b, 4096, false, false)
}

func BenchmarkRead1MCompressible(b *testing.B) {
	benchmark(b, 1024*1024, false, true)
}

func BenchmarkRead1MNoncompressible(b *testing.B) {
	benchmark(b, 1024*1024, false, false)
}
   287  
// init pins GOMAXPROCS to the CPU count so benchmarks use all cores.
// NOTE(review): this has been the runtime default since Go 1.5, so this is
// likely redundant — confirm before removing.
func init() {
	runtime.GOMAXPROCS(runtime.NumCPU())
}