wa-lang.org/wazero@v1.0.2/internal/engine/compiler/engine_cache_test.go (about)

     1  package compiler
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"math"
    10  	"testing"
    11  	"testing/iotest"
    12  
    13  	"wa-lang.org/wazero/internal/testing/require"
    14  	"wa-lang.org/wazero/internal/u32"
    15  	"wa-lang.org/wazero/internal/u64"
    16  	"wa-lang.org/wazero/internal/wasm"
    17  )
    18  
// testVersion is the version string these tests pass to serializeCodes and
// deserializeCodes and embed in the cache headers they build.
// NOTE(review): nothing visible in this file assigns it, so it presumably
// stays the empty string — confirm no other file in the package sets it.
var testVersion string
    20  
    21  func concat(ins ...[]byte) (ret []byte) {
    22  	for _, in := range ins {
    23  		ret = append(ret, in...)
    24  	}
    25  	return
    26  }
    27  
    28  func TestSerializeCodes(t *testing.T) {
    29  	tests := []struct {
    30  		in  []*code
    31  		exp []byte
    32  	}{
    33  		{
    34  			in: []*code{{stackPointerCeil: 12345, codeSegment: []byte{1, 2, 3, 4, 5}}},
    35  			exp: concat(
    36  				[]byte(wazeroMagic),
    37  				[]byte{byte(len(testVersion))},
    38  				[]byte(testVersion),
    39  				u32.LeBytes(1),        // number of functions.
    40  				u64.LeBytes(12345),    // stack pointer ceil.
    41  				u64.LeBytes(5),        // length of code.
    42  				[]byte{1, 2, 3, 4, 5}, // code.
    43  			),
    44  		},
    45  		{
    46  			in: []*code{
    47  				{stackPointerCeil: 12345, codeSegment: []byte{1, 2, 3, 4, 5}},
    48  				{stackPointerCeil: 0xffffffff, codeSegment: []byte{1, 2, 3}},
    49  			},
    50  			exp: concat(
    51  				[]byte(wazeroMagic),
    52  				[]byte{byte(len(testVersion))},
    53  				[]byte(testVersion),
    54  				u32.LeBytes(2), // number of functions.
    55  				// Function index = 0.
    56  				u64.LeBytes(12345),    // stack pointer ceil.
    57  				u64.LeBytes(5),        // length of code.
    58  				[]byte{1, 2, 3, 4, 5}, // code.
    59  				// Function index = 1.
    60  				u64.LeBytes(0xffffffff), // stack pointer ceil.
    61  				u64.LeBytes(3),          // length of code.
    62  				[]byte{1, 2, 3},         // code.
    63  			),
    64  		},
    65  	}
    66  
    67  	for i, tc := range tests {
    68  		actual, err := io.ReadAll(serializeCodes(testVersion, tc.in))
    69  		require.NoError(t, err, i)
    70  		require.Equal(t, tc.exp, actual, i)
    71  	}
    72  }
    73  
    74  func TestDeserializeCodes(t *testing.T) {
    75  	tests := []struct {
    76  		name          string
    77  		in            []byte
    78  		expCodes      []*code
    79  		expStaleCache bool
    80  		expErr        string
    81  	}{
    82  		{
    83  			name:   "invalid header",
    84  			in:     []byte{1},
    85  			expErr: "compilationcache: invalid header length: 1",
    86  		},
    87  		{
    88  			name: "version mismatch",
    89  			in: concat(
    90  				[]byte(wazeroMagic),
    91  				[]byte{byte(len("1233123.1.1"))},
    92  				[]byte("1233123.1.1"),
    93  				u32.LeBytes(1), // number of functions.
    94  			),
    95  			expStaleCache: true,
    96  		},
    97  		{
    98  			name: "version mismatch",
    99  			in: concat(
   100  				[]byte(wazeroMagic),
   101  				[]byte{byte(len("1"))},
   102  				[]byte("1"),
   103  				u32.LeBytes(1), // number of functions.
   104  			),
   105  			expStaleCache: true,
   106  		},
   107  		{
   108  			name: "one function",
   109  			in: concat(
   110  				[]byte(wazeroMagic),
   111  				[]byte{byte(len(testVersion))},
   112  				[]byte(testVersion),
   113  				u32.LeBytes(1),        // number of functions.
   114  				u64.LeBytes(12345),    // stack pointer ceil.
   115  				u64.LeBytes(5),        // length of code.
   116  				[]byte{1, 2, 3, 4, 5}, // code.
   117  			),
   118  			expCodes: []*code{
   119  				{stackPointerCeil: 12345, codeSegment: []byte{1, 2, 3, 4, 5}},
   120  			},
   121  			expStaleCache: false,
   122  			expErr:        "",
   123  		},
   124  		{
   125  			name: "two functions",
   126  			in: concat(
   127  				[]byte(wazeroMagic),
   128  				[]byte{byte(len(testVersion))},
   129  				[]byte(testVersion),
   130  				u32.LeBytes(2), // number of functions.
   131  				// Function index = 0.
   132  				u64.LeBytes(12345),    // stack pointer ceil.
   133  				u64.LeBytes(5),        // length of code.
   134  				[]byte{1, 2, 3, 4, 5}, // code.
   135  				// Function index = 1.
   136  				u64.LeBytes(0xffffffff), // stack pointer ceil.
   137  				u64.LeBytes(3),          // length of code.
   138  				[]byte{1, 2, 3},         // code.
   139  			),
   140  			expCodes: []*code{
   141  				{stackPointerCeil: 12345, codeSegment: []byte{1, 2, 3, 4, 5}},
   142  				{stackPointerCeil: 0xffffffff, codeSegment: []byte{1, 2, 3}},
   143  			},
   144  			expStaleCache: false,
   145  			expErr:        "",
   146  		},
   147  		{
   148  			name: "reading stack pointer",
   149  			in: concat(
   150  				[]byte(wazeroMagic),
   151  				[]byte{byte(len(testVersion))},
   152  				[]byte(testVersion),
   153  				u32.LeBytes(2), // number of functions.
   154  				// Function index = 0.
   155  				u64.LeBytes(12345),    // stack pointer ceil.
   156  				u64.LeBytes(5),        // length of code.
   157  				[]byte{1, 2, 3, 4, 5}, // code.
   158  				// Function index = 1.
   159  			),
   160  			expErr: "compilationcache: error reading func[1] stack pointer ceil: EOF",
   161  		},
   162  		{
   163  			name: "reading native code size",
   164  			in: concat(
   165  				[]byte(wazeroMagic),
   166  				[]byte{byte(len(testVersion))},
   167  				[]byte(testVersion),
   168  				u32.LeBytes(2), // number of functions.
   169  				// Function index = 0.
   170  				u64.LeBytes(12345),    // stack pointer ceil.
   171  				u64.LeBytes(5),        // length of code.
   172  				[]byte{1, 2, 3, 4, 5}, // code.
   173  				// Function index = 1.
   174  				u64.LeBytes(12345), // stack pointer ceil.
   175  			),
   176  			expErr: "compilationcache: error reading func[1] reading native code size: EOF",
   177  		},
   178  		{
   179  			name: "mmapping",
   180  			in: concat(
   181  				[]byte(wazeroMagic),
   182  				[]byte{byte(len(testVersion))},
   183  				[]byte(testVersion),
   184  				u32.LeBytes(2), // number of functions.
   185  				// Function index = 0.
   186  				u64.LeBytes(12345),    // stack pointer ceil.
   187  				u64.LeBytes(5),        // length of code.
   188  				[]byte{1, 2, 3, 4, 5}, // code.
   189  				// Function index = 1.
   190  				u64.LeBytes(12345), // stack pointer ceil.
   191  				u64.LeBytes(5),     // length of code.
   192  				// Lack of code here.
   193  			),
   194  			expErr: "compilationcache: error mmapping func[1] code (len=5): EOF",
   195  		},
   196  	}
   197  
   198  	for _, tc := range tests {
   199  		tc := tc
   200  		t.Run(tc.name, func(t *testing.T) {
   201  			codes, staleCache, err := deserializeCodes(testVersion, bytes.NewReader(tc.in))
   202  			if tc.expErr != "" {
   203  				require.EqualError(t, err, tc.expErr)
   204  			} else {
   205  				require.NoError(t, err)
   206  			}
   207  
   208  			require.Equal(t, tc.expCodes, codes)
   209  			require.Equal(t, tc.expStaleCache, staleCache)
   210  		})
   211  	}
   212  }
   213  
   214  func TestEngine_getCodesFromCache(t *testing.T) {
   215  	tests := []struct {
   216  		name       string
   217  		ext        *testCache
   218  		key        wasm.ModuleID
   219  		expCodes   []*code
   220  		expHit     bool
   221  		expErr     string
   222  		expDeleted bool
   223  	}{
   224  		{name: "extern cache not given"},
   225  		{
   226  			name: "not hit",
   227  			ext:  &testCache{caches: map[wasm.ModuleID][]byte{}},
   228  		},
   229  		{
   230  			name:   "error in Cache.Get",
   231  			ext:    &testCache{caches: map[wasm.ModuleID][]byte{{}: {}}},
   232  			expErr: "some error from extern cache",
   233  		},
   234  		{
   235  			name:   "error in deserialization",
   236  			ext:    &testCache{caches: map[wasm.ModuleID][]byte{{}: {1, 2, 3}}},
   237  			expErr: "compilationcache: invalid header length: 3",
   238  		},
   239  		{
   240  			name: "stale cache",
   241  			ext: &testCache{caches: map[wasm.ModuleID][]byte{{}: concat(
   242  				[]byte(wazeroMagic),
   243  				[]byte{byte(len("1233123.1.1"))},
   244  				[]byte("1233123.1.1"),
   245  				u32.LeBytes(1), // number of functions.
   246  			)}},
   247  			expDeleted: true,
   248  		},
   249  		{
   250  			name: "hit",
   251  			ext: &testCache{caches: map[wasm.ModuleID][]byte{
   252  				{}: concat(
   253  					[]byte(wazeroMagic),
   254  					[]byte{byte(len(testVersion))},
   255  					[]byte(testVersion),
   256  					u32.LeBytes(2), // number of functions.
   257  					// Function index = 0.
   258  					u64.LeBytes(12345),    // stack pointer ceil.
   259  					u64.LeBytes(5),        // length of code.
   260  					[]byte{1, 2, 3, 4, 5}, // code.
   261  					// Function index = 1.
   262  					u64.LeBytes(0xffffffff), // stack pointer ceil.
   263  					u64.LeBytes(3),          // length of code.
   264  					[]byte{1, 2, 3},         // code.
   265  				),
   266  			}},
   267  			expHit: true,
   268  			expCodes: []*code{
   269  				{stackPointerCeil: 12345, codeSegment: []byte{1, 2, 3, 4, 5}, indexInModule: 0},
   270  				{stackPointerCeil: 0xffffffff, codeSegment: []byte{1, 2, 3}, indexInModule: 1},
   271  			},
   272  		},
   273  	}
   274  
   275  	for _, tc := range tests {
   276  		tc := tc
   277  		t.Run(tc.name, func(t *testing.T) {
   278  			m := &wasm.Module{ID: tc.key}
   279  			for _, expC := range tc.expCodes {
   280  				expC.sourceModule = m
   281  			}
   282  
   283  			e := engine{}
   284  			if tc.ext != nil {
   285  				e.Cache = tc.ext
   286  			}
   287  
   288  			codes, hit, err := e.getCodesFromCache(m)
   289  			if tc.expErr != "" {
   290  				require.EqualError(t, err, tc.expErr)
   291  			} else {
   292  				require.NoError(t, err)
   293  			}
   294  
   295  			require.Equal(t, tc.expHit, hit)
   296  			require.Equal(t, tc.expCodes, codes)
   297  
   298  			if tc.expDeleted {
   299  				require.Equal(t, tc.ext.deleted, tc.key)
   300  			}
   301  		})
   302  	}
   303  }
   304  
   305  func TestEngine_addCodesToCache(t *testing.T) {
   306  	t.Run("not defined", func(t *testing.T) {
   307  		e := engine{}
   308  		err := e.addCodesToCache(nil, nil)
   309  		require.NoError(t, err)
   310  	})
   311  	t.Run("add", func(t *testing.T) {
   312  		ext := &testCache{caches: map[wasm.ModuleID][]byte{}}
   313  		e := engine{Cache: ext}
   314  		m := &wasm.Module{}
   315  		codes := []*code{{stackPointerCeil: 123, codeSegment: []byte{1, 2, 3}}}
   316  		err := e.addCodesToCache(m, codes)
   317  		require.NoError(t, err)
   318  
   319  		content, ok := ext.caches[m.ID]
   320  		require.True(t, ok)
   321  		require.Equal(t, concat(
   322  			[]byte(wazeroMagic),
   323  			[]byte{byte(len(testVersion))},
   324  			[]byte(testVersion),
   325  			u32.LeBytes(1),   // number of functions.
   326  			u64.LeBytes(123), // stack pointer ceil.
   327  			u64.LeBytes(3),   // length of code.
   328  			[]byte{1, 2, 3},  // code.
   329  		), content)
   330  	})
   331  }
   332  
   333  func Test_readUint64(t *testing.T) {
   334  	tests := []struct {
   335  		name  string
   336  		input uint64
   337  	}{
   338  		{
   339  			name:  "zero",
   340  			input: 0,
   341  		},
   342  		{
   343  			name:  "half",
   344  			input: math.MaxUint32,
   345  		},
   346  		{
   347  			name:  "max",
   348  			input: math.MaxUint64,
   349  		},
   350  	}
   351  
   352  	for _, tt := range tests {
   353  		tc := tt
   354  
   355  		t.Run(tc.name, func(t *testing.T) {
   356  			input := make([]byte, 8)
   357  			binary.LittleEndian.PutUint64(input, tc.input)
   358  
   359  			var b [8]byte
   360  			n, err := readUint64(bytes.NewReader(input), &b)
   361  			require.NoError(t, err)
   362  			require.Equal(t, tc.input, n)
   363  
   364  			// ensure the buffer was cleared
   365  			var expectedB [8]byte
   366  			require.Equal(t, expectedB, b)
   367  		})
   368  	}
   369  }
   370  
   371  func Test_readUint64_errors(t *testing.T) {
   372  	tests := []struct {
   373  		name        string
   374  		input       io.Reader
   375  		expectedErr string
   376  	}{
   377  		{
   378  			name:        "zero",
   379  			input:       bytes.NewReader([]byte{}),
   380  			expectedErr: "EOF",
   381  		},
   382  		{
   383  			name:        "not enough",
   384  			input:       bytes.NewReader([]byte{1, 2}),
   385  			expectedErr: "EOF",
   386  		},
   387  		{
   388  			name:        "error reading",
   389  			input:       iotest.ErrReader(errors.New("ice cream")),
   390  			expectedErr: "ice cream",
   391  		},
   392  	}
   393  
   394  	for _, tt := range tests {
   395  		tc := tt
   396  
   397  		t.Run(tc.name, func(t *testing.T) {
   398  			var b [8]byte
   399  			_, err := readUint64(tc.input, &b)
   400  			require.EqualError(t, err, tc.expectedErr)
   401  		})
   402  	}
   403  }
   404  
// testCache implements compilationcache.Cache
//
// It is an in-memory stand-in used by the tests in this file.
type testCache struct {
	// caches maps a module ID to the raw bytes stored for it. Get reports an
	// error for an entry whose content is empty.
	caches  map[wasm.ModuleID][]byte
	// deleted holds the key most recently passed to Delete; it is the zero
	// value until Delete is called.
	deleted wasm.ModuleID
}
   410  
   411  // Get implements compilationcache.Cache Get
   412  func (tc *testCache) Get(key wasm.ModuleID) (content io.ReadCloser, ok bool, err error) {
   413  	var raw []byte
   414  	raw, ok = tc.caches[key]
   415  	if !ok {
   416  		return
   417  	}
   418  
   419  	if len(raw) == 0 {
   420  		ok = false
   421  		err = fmt.Errorf("some error from extern cache")
   422  		return
   423  	}
   424  
   425  	content = io.NopCloser(bytes.NewReader(raw))
   426  	return
   427  }
   428  
   429  // Add implements compilationcache.Cache Add
   430  func (tc *testCache) Add(key wasm.ModuleID, content io.Reader) (err error) {
   431  	raw, err := io.ReadAll(content)
   432  	if err != nil {
   433  		return err
   434  	}
   435  	tc.caches[key] = raw
   436  	return
   437  }
   438  
   439  // Delete implements compilationcache.Cache Delete
   440  func (tc *testCache) Delete(key wasm.ModuleID) (err error) {
   441  	tc.deleted = key
   442  	return
   443  }