github.com/tetratelabs/wazero@v1.2.1/internal/engine/compiler/engine_cache_test.go (about)

     1  package compiler
     2  
     3  import (
     4  	"bytes"
     5  	"crypto/sha256"
     6  	"encoding/binary"
     7  	"errors"
     8  	"io"
     9  	"math"
    10  	"testing"
    11  	"testing/iotest"
    12  
    13  	"github.com/tetratelabs/wazero/internal/asm"
    14  	"github.com/tetratelabs/wazero/internal/filecache"
    15  	"github.com/tetratelabs/wazero/internal/testing/require"
    16  	"github.com/tetratelabs/wazero/internal/u32"
    17  	"github.com/tetratelabs/wazero/internal/u64"
    18  	"github.com/tetratelabs/wazero/internal/wasm"
    19  )
    20  
    21  var testVersion = ""
    22  
    23  func concat(ins ...[]byte) (ret []byte) {
    24  	for _, in := range ins {
    25  		ret = append(ret, in...)
    26  	}
    27  	return
    28  }
    29  
    30  func makeCodeSegment(bytes ...byte) asm.CodeSegment {
    31  	return *asm.NewCodeSegment(bytes)
    32  }
    33  
    34  func TestSerializeCompiledModule(t *testing.T) {
    35  	tests := []struct {
    36  		in  *compiledModule
    37  		exp []byte
    38  	}{
    39  		{
    40  			in: &compiledModule{
    41  				executable: makeCodeSegment(1, 2, 3, 4, 5),
    42  				functions: []compiledFunction{
    43  					{executableOffset: 0, stackPointerCeil: 12345},
    44  				},
    45  			},
    46  			exp: concat(
    47  				[]byte(wazeroMagic),
    48  				[]byte{byte(len(testVersion))},
    49  				[]byte(testVersion),
    50  				[]byte{0},             // ensure termination.
    51  				u32.LeBytes(1),        // number of functions.
    52  				u64.LeBytes(12345),    // stack pointer ceil.
    53  				u64.LeBytes(0),        // offset.
    54  				u64.LeBytes(5),        // length of code.
    55  				[]byte{1, 2, 3, 4, 5}, // code.
    56  			),
    57  		},
    58  		{
    59  			in: &compiledModule{
    60  				ensureTermination: true,
    61  				executable:        makeCodeSegment(1, 2, 3, 4, 5),
    62  				functions: []compiledFunction{
    63  					{executableOffset: 0, stackPointerCeil: 12345},
    64  				},
    65  			},
    66  			exp: concat(
    67  				[]byte(wazeroMagic),
    68  				[]byte{byte(len(testVersion))},
    69  				[]byte(testVersion),
    70  				[]byte{1},             // ensure termination.
    71  				u32.LeBytes(1),        // number of functions.
    72  				u64.LeBytes(12345),    // stack pointer ceil.
    73  				u64.LeBytes(0),        // offset.
    74  				u64.LeBytes(5),        // length of code.
    75  				[]byte{1, 2, 3, 4, 5}, // code.
    76  			),
    77  		},
    78  		{
    79  			in: &compiledModule{
    80  				ensureTermination: true,
    81  				executable:        makeCodeSegment(1, 2, 3, 4, 5, 1, 2, 3),
    82  				functions: []compiledFunction{
    83  					{executableOffset: 0, stackPointerCeil: 12345},
    84  					{executableOffset: 5, stackPointerCeil: 0xffffffff},
    85  				},
    86  			},
    87  			exp: concat(
    88  				[]byte(wazeroMagic),
    89  				[]byte{byte(len(testVersion))},
    90  				[]byte(testVersion),
    91  				[]byte{1},      // ensure termination.
    92  				u32.LeBytes(2), // number of functions.
    93  				// Function index = 0.
    94  				u64.LeBytes(12345), // stack pointer ceil.
    95  				u64.LeBytes(0),     // offset.
    96  				// Function index = 1.
    97  				u64.LeBytes(0xffffffff), // stack pointer ceil.
    98  				u64.LeBytes(5),          // offset.
    99  				// Executable.
   100  				u64.LeBytes(8),                 // length of code.
   101  				[]byte{1, 2, 3, 4, 5, 1, 2, 3}, // code.
   102  			),
   103  		},
   104  	}
   105  
   106  	for i, tc := range tests {
   107  		actual, err := io.ReadAll(serializeCompiledModule(testVersion, tc.in))
   108  		require.NoError(t, err, i)
   109  		require.Equal(t, tc.exp, actual, i)
   110  	}
   111  }
   112  
   113  func TestDeserializeCompiledModule(t *testing.T) {
   114  	tests := []struct {
   115  		name                  string
   116  		in                    []byte
   117  		importedFunctionCount uint32
   118  		expCompiledModule     *compiledModule
   119  		expStaleCache         bool
   120  		expErr                string
   121  	}{
   122  		{
   123  			name:   "invalid header",
   124  			in:     []byte{1},
   125  			expErr: "compilationcache: invalid header length: 1",
   126  		},
   127  		{
   128  			name: "version mismatch",
   129  			in: concat(
   130  				[]byte(wazeroMagic),
   131  				[]byte{byte(len("1233123.1.1"))},
   132  				[]byte("1233123.1.1"),
   133  				u32.LeBytes(1), // number of functions.
   134  			),
   135  			expStaleCache: true,
   136  		},
   137  		{
   138  			name: "version mismatch",
   139  			in: concat(
   140  				[]byte(wazeroMagic),
   141  				[]byte{byte(len("1"))},
   142  				[]byte("1"),
   143  				u32.LeBytes(1), // number of functions.
   144  			),
   145  			expStaleCache: true,
   146  		},
   147  		{
   148  			name: "one function",
   149  			in: concat(
   150  				[]byte(wazeroMagic),
   151  				[]byte{byte(len(testVersion))},
   152  				[]byte(testVersion),
   153  				[]byte{0},          // ensure termination.
   154  				u32.LeBytes(1),     // number of functions.
   155  				u64.LeBytes(12345), // stack pointer ceil.
   156  				u64.LeBytes(0),     // offset.
   157  				// Executable.
   158  				u64.LeBytes(5),        // size.
   159  				[]byte{1, 2, 3, 4, 5}, // machine code.
   160  			),
   161  			expCompiledModule: &compiledModule{
   162  				executable: makeCodeSegment(1, 2, 3, 4, 5),
   163  				functions: []compiledFunction{
   164  					{executableOffset: 0, stackPointerCeil: 12345, index: 0},
   165  				},
   166  			},
   167  			expStaleCache: false,
   168  			expErr:        "",
   169  		},
   170  		{
   171  			name: "one function with ensure termination",
   172  			in: concat(
   173  				[]byte(wazeroMagic),
   174  				[]byte{byte(len(testVersion))},
   175  				[]byte(testVersion),
   176  				[]byte{1},             // ensure termination.
   177  				u32.LeBytes(1),        // number of functions.
   178  				u64.LeBytes(12345),    // stack pointer ceil.
   179  				u64.LeBytes(0),        // offset.
   180  				u64.LeBytes(5),        // length of code.
   181  				[]byte{1, 2, 3, 4, 5}, // code.
   182  			),
   183  			expCompiledModule: &compiledModule{
   184  				ensureTermination: true,
   185  				executable:        makeCodeSegment(1, 2, 3, 4, 5),
   186  				functions:         []compiledFunction{{executableOffset: 0, stackPointerCeil: 12345, index: 0}},
   187  			},
   188  			expStaleCache: false,
   189  			expErr:        "",
   190  		},
   191  		{
   192  			name: "two functions",
   193  			in: concat(
   194  				[]byte(wazeroMagic),
   195  				[]byte{byte(len(testVersion))},
   196  				[]byte(testVersion),
   197  				[]byte{0},      // ensure termination.
   198  				u32.LeBytes(2), // number of functions.
   199  				// Function index = 0.
   200  				u64.LeBytes(12345), // stack pointer ceil.
   201  				u64.LeBytes(0),     // offset.
   202  				// Function index = 1.
   203  				u64.LeBytes(0xffffffff), // stack pointer ceil.
   204  				u64.LeBytes(7),          // offset.
   205  				// Executable.
   206  				u64.LeBytes(10),                       // size.
   207  				[]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // machine code.
   208  			),
   209  			importedFunctionCount: 1,
   210  			expCompiledModule: &compiledModule{
   211  				executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
   212  				functions: []compiledFunction{
   213  					{executableOffset: 0, stackPointerCeil: 12345, index: 1},
   214  					{executableOffset: 7, stackPointerCeil: 0xffffffff, index: 2},
   215  				},
   216  			},
   217  			expStaleCache: false,
   218  			expErr:        "",
   219  		},
   220  		{
   221  			name: "reading stack pointer",
   222  			in: concat(
   223  				[]byte(wazeroMagic),
   224  				[]byte{byte(len(testVersion))},
   225  				[]byte(testVersion),
   226  				[]byte{0},      // ensure termination.
   227  				u32.LeBytes(2), // number of functions.
   228  				// Function index = 0.
   229  				u64.LeBytes(12345), // stack pointer ceil.
   230  				u64.LeBytes(5),     // offset.
   231  				// Function index = 1.
   232  			),
   233  			expErr: "compilationcache: error reading func[1] stack pointer ceil: EOF",
   234  		},
   235  		{
   236  			name: "reading executable offset",
   237  			in: concat(
   238  				[]byte(wazeroMagic),
   239  				[]byte{byte(len(testVersion))},
   240  				[]byte(testVersion),
   241  				[]byte{0},      // ensure termination.
   242  				u32.LeBytes(2), // number of functions.
   243  				// Function index = 0.
   244  				u64.LeBytes(12345), // stack pointer ceil.
   245  				u64.LeBytes(5),     // offset.
   246  				// Function index = 1.
   247  				u64.LeBytes(12345), // stack pointer ceil.
   248  			),
   249  			expErr: "compilationcache: error reading func[1] executable offset: EOF",
   250  		},
   251  		{
   252  			name: "mmapping",
   253  			in: concat(
   254  				[]byte(wazeroMagic),
   255  				[]byte{byte(len(testVersion))},
   256  				[]byte(testVersion),
   257  				[]byte{0},      // ensure termination.
   258  				u32.LeBytes(2), // number of functions.
   259  				// Function index = 0.
   260  				u64.LeBytes(12345), // stack pointer ceil.
   261  				u64.LeBytes(0),     // offset.
   262  				// Function index = 1.
   263  				u64.LeBytes(12345), // stack pointer ceil.
   264  				u64.LeBytes(5),     // offset.
   265  				// Executable.
   266  				u64.LeBytes(5), // size of the executable.
   267  				// Lack of machine code here.
   268  			),
   269  			expErr: "compilationcache: error reading executable (len=5): EOF",
   270  		},
   271  	}
   272  
   273  	for _, tc := range tests {
   274  		tc := tc
   275  		t.Run(tc.name, func(t *testing.T) {
   276  			cm, staleCache, err := deserializeCompiledModule(testVersion, io.NopCloser(bytes.NewReader(tc.in)),
   277  				&wasm.Module{ImportFunctionCount: tc.importedFunctionCount})
   278  
   279  			if tc.expCompiledModule != nil {
   280  				require.Equal(t, len(tc.expCompiledModule.functions), len(cm.functions))
   281  				for i := 0; i < len(cm.functions); i++ {
   282  					require.Equal(t, cm, cm.functions[i].parent)
   283  					tc.expCompiledModule.functions[i].parent = cm
   284  				}
   285  			}
   286  
   287  			if tc.expErr != "" {
   288  				require.EqualError(t, err, tc.expErr)
   289  			} else {
   290  				require.NoError(t, err)
   291  				require.Equal(t, tc.expCompiledModule, cm)
   292  			}
   293  
   294  			require.Equal(t, tc.expStaleCache, staleCache)
   295  		})
   296  	}
   297  }
   298  
   299  func TestEngine_getCompiledModuleFromCache(t *testing.T) {
   300  	valid := concat(
   301  		[]byte(wazeroMagic),
   302  		[]byte{byte(len(testVersion))},
   303  		[]byte(testVersion),
   304  		[]byte{0},      // ensure termination.
   305  		u32.LeBytes(2), // number of functions.
   306  		// Function index = 0.
   307  		u64.LeBytes(12345), // stack pointer ceil.
   308  		u64.LeBytes(0),     // offset.
   309  		// Function index = 1.
   310  		u64.LeBytes(0xffffffff), // stack pointer ceil.
   311  		u64.LeBytes(5),          // offset.
   312  		// executables.
   313  		u64.LeBytes(10),                       // length of code.
   314  		[]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, // code.
   315  	)
   316  
   317  	tests := []struct {
   318  		name              string
   319  		ext               map[wasm.ModuleID][]byte
   320  		key               wasm.ModuleID
   321  		isHostMod         bool
   322  		expCompiledModule *compiledModule
   323  		expHit            bool
   324  		expErr            string
   325  		expDeleted        bool
   326  	}{
   327  		{name: "extern cache not given"},
   328  		{
   329  			name: "not hit",
   330  			ext:  map[wasm.ModuleID][]byte{},
   331  		},
   332  		{
   333  			name:      "host module",
   334  			ext:       map[wasm.ModuleID][]byte{{}: valid},
   335  			isHostMod: true,
   336  		},
   337  		{
   338  			name:   "error in Cache.Get",
   339  			ext:    map[wasm.ModuleID][]byte{{}: {}},
   340  			expErr: "compilationcache: error reading header: EOF",
   341  		},
   342  		{
   343  			name:   "error in deserialization",
   344  			ext:    map[wasm.ModuleID][]byte{{}: {1, 2, 3}},
   345  			expErr: "compilationcache: invalid header length: 3",
   346  		},
   347  		{
   348  			name: "stale cache",
   349  			ext: map[wasm.ModuleID][]byte{{}: concat(
   350  				[]byte(wazeroMagic),
   351  				[]byte{byte(len("1233123.1.1"))},
   352  				[]byte("1233123.1.1"),
   353  				u32.LeBytes(1), // number of functions.
   354  			)},
   355  			expDeleted: true,
   356  		},
   357  		{
   358  			name: "hit",
   359  			ext: map[wasm.ModuleID][]byte{
   360  				{}: valid,
   361  			},
   362  			expHit: true,
   363  			expCompiledModule: &compiledModule{
   364  				executable: makeCodeSegment(1, 2, 3, 4, 5, 6, 7, 8, 9, 10),
   365  				functions: []compiledFunction{
   366  					{stackPointerCeil: 12345, executableOffset: 0, index: 0},
   367  					{stackPointerCeil: 0xffffffff, executableOffset: 5, index: 1},
   368  				},
   369  				source:            nil,
   370  				ensureTermination: false,
   371  			},
   372  		},
   373  	}
   374  
   375  	for _, tc := range tests {
   376  		tc := tc
   377  		t.Run(tc.name, func(t *testing.T) {
   378  			m := &wasm.Module{ID: tc.key, IsHostModule: tc.isHostMod}
   379  			if exp := tc.expCompiledModule; exp != nil {
   380  				exp.source = m
   381  				for i := range tc.expCompiledModule.functions {
   382  					tc.expCompiledModule.functions[i].parent = exp
   383  				}
   384  			}
   385  
   386  			e := engine{}
   387  			if tc.ext != nil {
   388  				tmp := t.TempDir()
   389  				e.fileCache = filecache.New(tmp)
   390  				for key, value := range tc.ext {
   391  					err := e.fileCache.Add(key, bytes.NewReader(value))
   392  					require.NoError(t, err)
   393  				}
   394  			}
   395  
   396  			codes, hit, err := e.getCompiledModuleFromCache(m)
   397  			if tc.expErr != "" {
   398  				require.EqualError(t, err, tc.expErr)
   399  			} else {
   400  				require.NoError(t, err)
   401  			}
   402  
   403  			require.Equal(t, tc.expHit, hit)
   404  			require.Equal(t, tc.expCompiledModule, codes)
   405  
   406  			if tc.ext != nil && tc.expDeleted {
   407  				_, hit, err := e.fileCache.Get(tc.key)
   408  				require.NoError(t, err)
   409  				require.False(t, hit)
   410  			}
   411  		})
   412  	}
   413  }
   414  
   415  func TestEngine_addCompiledModuleToCache(t *testing.T) {
   416  	t.Run("not defined", func(t *testing.T) {
   417  		e := engine{}
   418  		err := e.addCompiledModuleToCache(nil, nil)
   419  		require.NoError(t, err)
   420  	})
   421  	t.Run("host module", func(t *testing.T) {
   422  		tc := filecache.New(t.TempDir())
   423  		e := engine{fileCache: tc}
   424  		cm := &compiledModule{
   425  			executable: makeCodeSegment(1, 2, 3),
   426  			functions:  []compiledFunction{{stackPointerCeil: 123}},
   427  		}
   428  		m := &wasm.Module{ID: sha256.Sum256(nil), IsHostModule: true} // Host module!
   429  		err := e.addCompiledModuleToCache(m, cm)
   430  		require.NoError(t, err)
   431  		// Check the host module not cached.
   432  		_, hit, err := tc.Get(m.ID)
   433  		require.NoError(t, err)
   434  		require.False(t, hit)
   435  	})
   436  	t.Run("add", func(t *testing.T) {
   437  		tc := filecache.New(t.TempDir())
   438  		e := engine{fileCache: tc}
   439  		m := &wasm.Module{}
   440  		cm := &compiledModule{
   441  			executable: makeCodeSegment(1, 2, 3),
   442  			functions:  []compiledFunction{{stackPointerCeil: 123}},
   443  		}
   444  		err := e.addCompiledModuleToCache(m, cm)
   445  		require.NoError(t, err)
   446  
   447  		content, ok, err := tc.Get(m.ID)
   448  		require.NoError(t, err)
   449  		require.True(t, ok)
   450  		actual, err := io.ReadAll(content)
   451  		require.NoError(t, err)
   452  		require.Equal(t, concat(
   453  			[]byte(wazeroMagic),
   454  			[]byte{byte(len(testVersion))},
   455  			[]byte(testVersion),
   456  			[]byte{0},
   457  			u32.LeBytes(1),   // number of functions.
   458  			u64.LeBytes(123), // stack pointer ceil.
   459  			u64.LeBytes(0),   // offset.
   460  			u64.LeBytes(3),   // size of executable.
   461  			[]byte{1, 2, 3},
   462  		), actual)
   463  		require.NoError(t, content.Close())
   464  	})
   465  }
   466  
   467  func Test_readUint64(t *testing.T) {
   468  	tests := []struct {
   469  		name  string
   470  		input uint64
   471  	}{
   472  		{
   473  			name:  "zero",
   474  			input: 0,
   475  		},
   476  		{
   477  			name:  "half",
   478  			input: math.MaxUint32,
   479  		},
   480  		{
   481  			name:  "max",
   482  			input: math.MaxUint64,
   483  		},
   484  	}
   485  
   486  	for _, tt := range tests {
   487  		tc := tt
   488  
   489  		t.Run(tc.name, func(t *testing.T) {
   490  			input := make([]byte, 8)
   491  			binary.LittleEndian.PutUint64(input, tc.input)
   492  
   493  			var b [8]byte
   494  			n, err := readUint64(bytes.NewReader(input), &b)
   495  			require.NoError(t, err)
   496  			require.Equal(t, tc.input, n)
   497  
   498  			// ensure the buffer was cleared
   499  			var expectedB [8]byte
   500  			require.Equal(t, expectedB, b)
   501  		})
   502  	}
   503  }
   504  
   505  func Test_readUint64_errors(t *testing.T) {
   506  	tests := []struct {
   507  		name        string
   508  		input       io.Reader
   509  		expectedErr string
   510  	}{
   511  		{
   512  			name:        "zero",
   513  			input:       bytes.NewReader([]byte{}),
   514  			expectedErr: "EOF",
   515  		},
   516  		{
   517  			name:        "not enough",
   518  			input:       bytes.NewReader([]byte{1, 2}),
   519  			expectedErr: "EOF",
   520  		},
   521  		{
   522  			name:        "error reading",
   523  			input:       iotest.ErrReader(errors.New("ice cream")),
   524  			expectedErr: "ice cream",
   525  		},
   526  	}
   527  
   528  	for _, tt := range tests {
   529  		tc := tt
   530  
   531  		t.Run(tc.name, func(t *testing.T) {
   532  			var b [8]byte
   533  			_, err := readUint64(tc.input, &b)
   534  			require.EqualError(t, err, tc.expectedErr)
   535  		})
   536  	}
   537  }