github.com/bananabytelabs/wazero@v0.0.0-20240105073314-54b22a776da8/internal/engine/compiler/engine_cache_test.go (about)

     1  package compiler
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"encoding/binary"
     7  	"errors"
     8  	"io"
     9  	"math"
    10  	"testing"
    11  	"testing/iotest"
    12  
    13  	"github.com/bananabytelabs/wazero/internal/asm"
    14  	"github.com/bananabytelabs/wazero/internal/filecache"
    15  	"github.com/bananabytelabs/wazero/internal/testing/require"
    16  	"github.com/bananabytelabs/wazero/internal/u32"
    17  	"github.com/bananabytelabs/wazero/internal/u64"
    18  	"github.com/bananabytelabs/wazero/internal/wasm"
    19  )
    20  
    21  var testVersion = ""
    22  
    23  func concat(ins ...[]byte) (ret []byte) {
    24  	for _, in := range ins {
    25  		ret = append(ret, in...)
    26  	}
    27  	return
    28  }
    29  
    30  func makeCodeSegment(bytes ...byte) asm.CodeSegment {
    31  	return *asm.NewCodeSegment(bytes)
    32  }
    33  
    34  func TestSerializeCompiledModule(t *testing.T) {
    35  	tests := []struct {
    36  		in  *compiledModule
    37  		exp []byte
    38  	}{
    39  		{
    40  			in: &compiledModule{
    41  				compiledCode: &compiledCode{
    42  					executable: makeCodeSegment(1, 2, 3, 4, 5),
    43  				},
    44  				functions: []compiledFunction{
    45  					{executableOffset: 0, stackPointerCeil: 12345},
    46  				},
    47  			},
    48  			exp: concat(
    49  				[]byte(wazeroMagic),
    50  				[]byte{byte(len(testVersion))},
    51  				[]byte(testVersion),
    52  				[]byte{0},             // ensure termination.
    53  				u32.LeBytes(1),        // number of functions.
    54  				u64.LeBytes(12345),    // stack pointer ceil.
    55  				u64.LeBytes(0),        // offset.
    56  				u64.LeBytes(5),        // length of code.
    57  				[]byte{1, 2, 3, 4, 5}, // code.
    58  			),
    59  		},
    60  		{
    61  			in: &compiledModule{
    62  				compiledCode: &compiledCode{
    63  					executable: makeCodeSegment(1, 2, 3, 4, 5),
    64  				},
    65  				functions: []compiledFunction{
    66  					{executableOffset: 0, stackPointerCeil: 12345},
    67  				},
    68  				ensureTermination: true,
    69  			},
    70  			exp: concat(
    71  				[]byte(wazeroMagic),
    72  				[]byte{byte(len(testVersion))},
    73  				[]byte(testVersion),
    74  				[]byte{1},             // ensure termination.
    75  				u32.LeBytes(1),        // number of functions.
    76  				u64.LeBytes(12345),    // stack pointer ceil.
    77  				u64.LeBytes(0),        // offset.
    78  				u64.LeBytes(5),        // length of code.
    79  				[]byte{1, 2, 3, 4, 5}, // code.
    80  			),
    81  		},
    82  		{
    83  			in: &compiledModule{
    84  				compiledCode: &compiledCode{
    85  					executable: makeCodeSegment(1, 2, 3, 4, 5, 1, 2, 3),
    86  				},
    87  				functions: []compiledFunction{
    88  					{executableOffset: 0, stackPointerCeil: 12345},
    89  					{executableOffset: 5, stackPointerCeil: 0xffffffff},
    90  				},
    91  				ensureTermination: true,
    92  			},
    93  			exp: concat(
    94  				[]byte(wazeroMagic),
    95  				[]byte{byte(len(testVersion))},
    96  				[]byte(testVersion),
    97  				[]byte{1},      // ensure termination.
    98  				u32.LeBytes(2), // number of functions.
    99  				// Function index = 0.
   100  				u64.LeBytes(12345), // stack pointer ceil.
   101  				u64.LeBytes(0),     // offset.
   102  				// Function index = 1.
   103  				u64.LeBytes(0xffffffff), // stack pointer ceil.
   104  				u64.LeBytes(5),          // offset.
   105  				// Executable.
   106  				u64.LeBytes(8),                 // length of code.
   107  				[]byte{1, 2, 3, 4, 5, 1, 2, 3}, // code.
   108  			),
   109  		},
   110  	}
   111  
   112  	for i, tc := range tests {
   113  		actual, err := io.ReadAll(serializeCompiledModule(testVersion, tc.in))
   114  		require.NoError(t, err, i)
   115  		require.Equal(t, tc.exp, actual, i)
   116  	}
   117  }
   118  
   119  func TestDeserializeCompiledModule(t *testing.T) {
   120  	tests := []struct {
   121  		name                  string
   122  		in                    []byte
   123  		importedFunctionCount uint32
   124  		expCompiledModule     *compiledModule
   125  		expStaleCache         bool
   126  		expErr                string
   127  	}{
   128  		{
   129  			name:   "invalid header",
   130  			in:     []byte{1},
   131  			expErr: "compilationcache: invalid header length: 1",
   132  		},
   133  		{
   134  			name: "version mismatch",
   135  			in: concat(
   136  				[]byte(wazeroMagic),
   137  				[]byte{byte(len("1233123.1.1"))},
   138  				[]byte("1233123.1.1"),
   139  				u32.LeBytes(1), // number of functions.
   140  			),
   141  			expStaleCache: true,
   142  		},
   143  		{
   144  			name: "version mismatch",
   145  			in: concat(
   146  				[]byte(wazeroMagic),
   147  				[]byte{byte(len("1"))},
   148  				[]byte("1"),
   149  				u32.LeBytes(1), // number of functions.
   150  			),
   151  			expStaleCache: true,
   152  		},
   153  		{
   154  			name: "one function",
   155  			in: concat(
   156  				[]byte(wazeroMagic),
   157  				[]byte{byte(len(testVersion))},
   158  				[]byte(testVersion),
   159  				[]byte{0},          // ensure termination.
   160  				u32.LeBytes(1),     // number of functions.
   161  				u64.LeBytes(12345), // stack pointer ceil.
   162  				u64.LeBytes(0),     // offset.
   163  				// Executable.
   164  				u64.LeBytes(5),        // size.
   165  				[]byte{1, 2, 3, 4, 5}, // machine code.
   166  			),
   167  			expCompiledModule: &compiledModule{
   168  				compiledCode: &compiledCode{
   169  					executable: makeCodeSegment(1, 2, 3, 4, 5),
   170  				},
   171  				functions: []compiledFunction{
   172  					{executableOffset: 0, stackPointerCeil: 12345, index: 0},
   173  				},
   174  			},
   175  			expStaleCache: false,
   176  			expErr:        "",
   177  		},
   178  		{
   179  			name: "one function with ensure termination",
   180  			in: concat(
   181  				[]byte(wazeroMagic),
   182  				[]byte{byte(len(testVersion))},
   183  				[]byte(testVersion),
   184  				[]byte{1},             // ensure termination.
   185  				u32.LeBytes(1),        // number of functions.
   186  				u64.LeBytes(12345),    // stack pointer ceil.
   187  				u64.LeBytes(0),        // offset.
   188  				u64.LeBytes(5),        // length of code.
   189  				[]byte{1, 2, 3, 4, 5}, // code.
   190  			),
   191  			expCompiledModule: &compiledModule{
   192  				compiledCode: &compiledCode{
   193  					executable: makeCodeSegment(1, 2, 3, 4, 5),
   194  				},
   195  				functions:         []compiledFunction{{executableOffset: 0, stackPointerCeil: 12345, index: 0}},
   196  				ensureTermination: true,
   197  			},
   198  			expStaleCache: false,
   199  			expErr:        "",
   200  		},
   201  		{
   202  			name: "two functions",
   203  			in: concat(
   204  				[]byte(wazeroMagic),
   205  				[]byte{byte(len(testVersion))},
   206  				[]byte(testVersion),
   207  				[]byte{0},      // ensure termination.
   208  				u32.LeBytes(2), // number of functions.
   209  				// Function index = 0.
   210  				u64.LeBytes(12345), // stack pointer ceil.
   211  				u64.LeBytes(0),     // offset.
   212  				// Function index = 1.
   213  				u64.LeBytes(0xffffffff), // stack pointer ceil.
   214  				u64.LeBytes(7),          // offset.
   215  				// Executable.
   216  				u64.LeBytes(10),                       // size.
   217  				[]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // machine code.
   218  			),
   219  			importedFunctionCount: 1,
   220  			expCompiledModule: &compiledModule{
   221  				compiledCode: &compiledCode{
   222  					executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
   223  				},
   224  				functions: []compiledFunction{
   225  					{executableOffset: 0, stackPointerCeil: 12345, index: 1},
   226  					{executableOffset: 7, stackPointerCeil: 0xffffffff, index: 2},
   227  				},
   228  			},
   229  			expStaleCache: false,
   230  			expErr:        "",
   231  		},
   232  		{
   233  			name: "reading stack pointer",
   234  			in: concat(
   235  				[]byte(wazeroMagic),
   236  				[]byte{byte(len(testVersion))},
   237  				[]byte(testVersion),
   238  				[]byte{0},      // ensure termination.
   239  				u32.LeBytes(2), // number of functions.
   240  				// Function index = 0.
   241  				u64.LeBytes(12345), // stack pointer ceil.
   242  				u64.LeBytes(5),     // offset.
   243  				// Function index = 1.
   244  			),
   245  			expErr: "compilationcache: error reading func[1] stack pointer ceil: EOF",
   246  		},
   247  		{
   248  			name: "reading executable offset",
   249  			in: concat(
   250  				[]byte(wazeroMagic),
   251  				[]byte{byte(len(testVersion))},
   252  				[]byte(testVersion),
   253  				[]byte{0},      // ensure termination.
   254  				u32.LeBytes(2), // number of functions.
   255  				// Function index = 0.
   256  				u64.LeBytes(12345), // stack pointer ceil.
   257  				u64.LeBytes(5),     // offset.
   258  				// Function index = 1.
   259  				u64.LeBytes(12345), // stack pointer ceil.
   260  			),
   261  			expErr: "compilationcache: error reading func[1] executable offset: EOF",
   262  		},
   263  		{
   264  			name: "mmapping",
   265  			in: concat(
   266  				[]byte(wazeroMagic),
   267  				[]byte{byte(len(testVersion))},
   268  				[]byte(testVersion),
   269  				[]byte{0},      // ensure termination.
   270  				u32.LeBytes(2), // number of functions.
   271  				// Function index = 0.
   272  				u64.LeBytes(12345), // stack pointer ceil.
   273  				u64.LeBytes(0),     // offset.
   274  				// Function index = 1.
   275  				u64.LeBytes(12345), // stack pointer ceil.
   276  				u64.LeBytes(5),     // offset.
   277  				// Executable.
   278  				u64.LeBytes(5), // size of the executable.
   279  				// Lack of machine code here.
   280  			),
   281  			expErr: "compilationcache: error reading executable (len=5): EOF",
   282  		},
   283  	}
   284  
   285  	for _, tc := range tests {
   286  		tc := tc
   287  		t.Run(tc.name, func(t *testing.T) {
   288  			cm, staleCache, err := deserializeCompiledModule(testVersion, io.NopCloser(bytes.NewReader(tc.in)),
   289  				&wasm.Module{ImportFunctionCount: tc.importedFunctionCount})
   290  
   291  			if tc.expCompiledModule != nil {
   292  				require.Equal(t, len(tc.expCompiledModule.functions), len(cm.functions))
   293  				for i := 0; i < len(cm.functions); i++ {
   294  					require.Equal(t, cm.compiledCode, cm.functions[i].parent)
   295  					tc.expCompiledModule.functions[i].parent = cm.compiledCode
   296  				}
   297  			}
   298  
   299  			if tc.expErr != "" {
   300  				require.EqualError(t, err, tc.expErr)
   301  			} else {
   302  				require.NoError(t, err)
   303  				require.Equal(t, tc.expCompiledModule, cm)
   304  			}
   305  
   306  			require.Equal(t, tc.expStaleCache, staleCache)
   307  		})
   308  	}
   309  }
   310  
   311  func TestEngine_getCompiledModuleFromCache(t *testing.T) {
   312  	valid := concat(
   313  		[]byte(wazeroMagic),
   314  		[]byte{byte(len(testVersion))},
   315  		[]byte(testVersion),
   316  		[]byte{0},      // ensure termination.
   317  		u32.LeBytes(2), // number of functions.
   318  		// Function index = 0.
   319  		u64.LeBytes(12345), // stack pointer ceil.
   320  		u64.LeBytes(0),     // offset.
   321  		// Function index = 1.
   322  		u64.LeBytes(0xffffffff), // stack pointer ceil.
   323  		u64.LeBytes(5),          // offset.
   324  		// executables.
   325  		u64.LeBytes(10),                       // length of code.
   326  		[]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // code.
   327  	)
   328  
   329  	tests := []struct {
   330  		name              string
   331  		ext               map[wasm.ModuleID][]byte
   332  		key               wasm.ModuleID
   333  		isHostMod         bool
   334  		expCompiledModule *compiledModule
   335  		expHit            bool
   336  		expErr            string
   337  		expDeleted        bool
   338  	}{
   339  		{name: "extern cache not given"},
   340  		{
   341  			name: "not hit",
   342  			ext:  map[wasm.ModuleID][]byte{},
   343  		},
   344  		{
   345  			name:      "host module",
   346  			ext:       map[wasm.ModuleID][]byte{{}: valid},
   347  			isHostMod: true,
   348  		},
   349  		{
   350  			name:   "error in Cache.Get",
   351  			ext:    map[wasm.ModuleID][]byte{{}: {}},
   352  			expErr: "compilationcache: error reading header: EOF",
   353  		},
   354  		{
   355  			name:   "error in deserialization",
   356  			ext:    map[wasm.ModuleID][]byte{{}: {1, 2, 3}},
   357  			expErr: "compilationcache: invalid header length: 3",
   358  		},
   359  		{
   360  			name: "stale cache",
   361  			ext: map[wasm.ModuleID][]byte{{}: concat(
   362  				[]byte(wazeroMagic),
   363  				[]byte{byte(len("1233123.1.1"))},
   364  				[]byte("1233123.1.1"),
   365  				u32.LeBytes(1), // number of functions.
   366  			)},
   367  			expDeleted: true,
   368  		},
   369  		{
   370  			name: "hit",
   371  			ext: map[wasm.ModuleID][]byte{
   372  				{}: valid,
   373  			},
   374  			expHit: true,
   375  			expCompiledModule: &compiledModule{
   376  				compiledCode: &compiledCode{
   377  					executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
   378  				},
   379  				functions: []compiledFunction{
   380  					{stackPointerCeil: 12345, executableOffset: 0, index: 0},
   381  					{stackPointerCeil: 0xffffffff, executableOffset: 5, index: 1},
   382  				},
   383  			},
   384  		},
   385  	}
   386  
   387  	for _, tc := range tests {
   388  		tc := tc
   389  		t.Run(tc.name, func(t *testing.T) {
   390  			m := &wasm.Module{ID: tc.key, IsHostModule: tc.isHostMod}
   391  			if exp := tc.expCompiledModule; exp != nil {
   392  				exp.source = m
   393  				for i := range tc.expCompiledModule.functions {
   394  					tc.expCompiledModule.functions[i].parent = exp.compiledCode
   395  				}
   396  			}
   397  
   398  			e := engine{}
   399  			if tc.ext != nil {
   400  				tmp := t.TempDir()
   401  				e.fileCache = filecache.New(tmp)
   402  				for key, value := range tc.ext {
   403  					err := e.fileCache.Add(key, bytes.NewReader(value))
   404  					require.NoError(t, err)
   405  				}
   406  			}
   407  
   408  			codes, hit, err := e.getCompiledModuleFromCache(m)
   409  			if tc.expErr != "" {
   410  				require.EqualError(t, err, tc.expErr)
   411  			} else {
   412  				require.NoError(t, err)
   413  			}
   414  
   415  			require.Equal(t, tc.expHit, hit)
   416  			require.Equal(t, tc.expCompiledModule, codes)
   417  
   418  			if tc.ext != nil && tc.expDeleted {
   419  				_, hit, err := e.fileCache.Get(tc.key)
   420  				require.NoError(t, err)
   421  				require.False(t, hit)
   422  			}
   423  		})
   424  	}
   425  }
   426  
   427  func TestEngine_addCompiledModuleToCache(t *testing.T) {
   428  	t.Run("not defined", func(t *testing.T) {
   429  		e := engine{}
   430  		err := e.addCompiledModuleToCache(nil, nil)
   431  		require.NoError(t, err)
   432  	})
   433  	t.Run("host module", func(t *testing.T) {
   434  		tc := filecache.New(t.TempDir())
   435  		e := engine{fileCache: tc}
   436  		cm := &compiledModule{
   437  			compiledCode: &compiledCode{
   438  				executable: makeCodeSegment(1, 2, 3),
   439  			},
   440  			functions: []compiledFunction{{stackPointerCeil: 123}},
   441  		}
   442  		m := &wasm.Module{ID: sha256.Sum256(nil), IsHostModule: true} // Host module!
   443  		err := e.addCompiledModuleToCache(m, cm)
   444  		require.NoError(t, err)
   445  		// Check the host module not cached.
   446  		_, hit, err := tc.Get(m.ID)
   447  		require.NoError(t, err)
   448  		require.False(t, hit)
   449  	})
   450  	t.Run("add", func(t *testing.T) {
   451  		tc := filecache.New(t.TempDir())
   452  		e := engine{fileCache: tc}
   453  		m := &wasm.Module{}
   454  		cm := &compiledModule{
   455  			compiledCode: &compiledCode{
   456  				executable: makeCodeSegment(1, 2, 3),
   457  			},
   458  			functions: []compiledFunction{{stackPointerCeil: 123}},
   459  		}
   460  		err := e.addCompiledModuleToCache(m, cm)
   461  		require.NoError(t, err)
   462  
   463  		content, ok, err := tc.Get(m.ID)
   464  		require.NoError(t, err)
   465  		require.True(t, ok)
   466  		actual, err := io.ReadAll(content)
   467  		require.NoError(t, err)
   468  		require.Equal(t, concat(
   469  			[]byte(wazeroMagic),
   470  			[]byte{byte(len(testVersion))},
   471  			[]byte(testVersion),
   472  			[]byte{0},
   473  			u32.LeBytes(1),   // number of functions.
   474  			u64.LeBytes(123), // stack pointer ceil.
   475  			u64.LeBytes(0),   // offset.
   476  			u64.LeBytes(3),   // size of executable.
   477  			[]byte{1, 2, 3},
   478  		), actual)
   479  		require.NoError(t, content.Close())
   480  	})
   481  }
   482  
   483  func Test_readUint64(t *testing.T) {
   484  	tests := []struct {
   485  		name  string
   486  		input uint64
   487  	}{
   488  		{
   489  			name:  "zero",
   490  			input: 0,
   491  		},
   492  		{
   493  			name:  "half",
   494  			input: math.MaxUint32,
   495  		},
   496  		{
   497  			name:  "max",
   498  			input: math.MaxUint64,
   499  		},
   500  	}
   501  
   502  	for _, tt := range tests {
   503  		tc := tt
   504  
   505  		t.Run(tc.name, func(t *testing.T) {
   506  			input := make([]byte, 8)
   507  			binary.LittleEndian.PutUint64(input, tc.input)
   508  
   509  			var b [8]byte
   510  			n, err := readUint64(bytes.NewReader(input), &b)
   511  			require.NoError(t, err)
   512  			require.Equal(t, tc.input, n)
   513  
   514  			// ensure the buffer was cleared
   515  			var expectedB [8]byte
   516  			require.Equal(t, expectedB, b)
   517  		})
   518  	}
   519  }
   520  
   521  func Test_readUint64_errors(t *testing.T) {
   522  	tests := []struct {
   523  		name        string
   524  		input       io.Reader
   525  		expectedErr string
   526  	}{
   527  		{
   528  			name:        "zero",
   529  			input:       bytes.NewReader([]byte{}),
   530  			expectedErr: "EOF",
   531  		},
   532  		{
   533  			name:        "not enough",
   534  			input:       bytes.NewReader([]byte{1, 2}),
   535  			expectedErr: "EOF",
   536  		},
   537  		{
   538  			name:        "error reading",
   539  			input:       iotest.ErrReader(errors.New("ice cream")),
   540  			expectedErr: "ice cream",
   541  		},
   542  	}
   543  
   544  	for _, tt := range tests {
   545  		tc := tt
   546  
   547  		t.Run(tc.name, func(t *testing.T) {
   548  			var b [8]byte
   549  			_, err := readUint64(tc.input, &b)
   550  			require.EqualError(t, err, tc.expectedErr)
   551  		})
   552  	}
   553  }