github.com/moontrade/wavm-go@v0.3.2-0.20220316110326-d229dd66ad65/worker/loader.go (about)

     1  package worker
     2  
     3  import "C"
     4  import (
     5  	"errors"
     6  	"fmt"
     7  	"github.com/moontrade/wavm-go"
     8  	"sync"
     9  	"unsafe"
    10  )
    11  
    12  func DefaultEngine() *wavm.Engine {
    13  	return wavm.NewEngineWithConfig(wavm.NewConfigAll().SetMultiMemory(false))
    14  }
    15  
    16  type Loader struct {
    17  	engineFactory       func() *wavm.Engine
    18  	counter             int64
    19  	fd_write_type       *wavm.FuncType
    20  	clock_time_get_type *wavm.FuncType
    21  	args_sizes_get_type *wavm.FuncType
    22  	args_get_type       *wavm.FuncType
    23  	setTimeout_type     *wavm.FuncType
    24  	modImports          []wavm.Import
    25  	modExports          []wavm.Export
    26  	imports             []*wavm.Extern
    27  	exports             []*wavm.Extern
    28  	compartment         *wavm.Compartment
    29  	mu                  sync.Mutex
    30  }
    31  
    32  func NewLoader(engineFactory func() *wavm.Engine) *Loader {
    33  	if engineFactory == nil {
    34  		engineFactory = DefaultEngine
    35  	}
    36  	return &Loader{
    37  		engineFactory: engineFactory,
    38  		fd_write_type: wavm.FuncType_4_1(
    39  			wavm.ValTypeI32(), wavm.ValTypeI32(), wavm.ValTypeI32(), wavm.ValTypeI32(),
    40  			wavm.ValTypeI32(),
    41  		),
    42  		clock_time_get_type: wavm.FuncType_3_1(
    43  			wavm.ValTypeI32(), wavm.ValTypeI64(), wavm.ValTypeI32(),
    44  			wavm.ValTypeI32(),
    45  		),
    46  		args_sizes_get_type: wavm.FuncType_2_1(
    47  			wavm.ValTypeI32(), wavm.ValTypeI32(),
    48  			wavm.ValTypeI32(),
    49  		),
    50  		args_get_type: wavm.FuncType_2_1(
    51  			wavm.ValTypeI32(), wavm.ValTypeI32(),
    52  			wavm.ValTypeI32(),
    53  		),
    54  		setTimeout_type: wavm.FuncType_1_0(wavm.ValTypeI64()),
    55  		modImports:      make([]wavm.Import, 0, 32),
    56  		modExports:      make([]wavm.Export, 0, 32),
    57  		imports:         make([]*wavm.Extern, 0, 32),
    58  	}
    59  }
    60  
    61  func (wl *Loader) Close() error {
    62  	wl.fd_write_type.Delete()
    63  	wl.clock_time_get_type.Delete()
    64  	wl.args_sizes_get_type.Delete()
    65  	wl.args_get_type.Delete()
    66  	wl.setTimeout_type.Delete()
    67  	return nil
    68  }
    69  
    70  func (wl *Loader) findModuleExport(name string) *wavm.Export {
    71  	for i := 0; i < len(wl.modExports); i++ {
    72  		export := &wl.modExports[i]
    73  		if export.NameUnsafe() == name {
    74  			return export
    75  		}
    76  	}
    77  	return nil
    78  }
    79  
    80  func (wl *Loader) Load(precompiled, trace bool, binary []byte, maxTableElems, maxMemoryPages int32) (*Worker, error) {
    81  	wl.mu.Lock()
    82  	defer wl.mu.Unlock()
    83  
    84  	// Init Engine
    85  	engine := wl.engineFactory()
    86  
    87  	//compartment := wl.compartment
    88  	//if compartment == nil {
    89  	//	wl.compartment = engine.NewCompartment("")
    90  	//	compartment = wl.compartment
    91  	//}
    92  	compartment := engine.NewCompartment("")
    93  
    94  	// Init store
    95  	store := compartment.NewStore("")
    96  	// Load module
    97  	var module *wavm.Module
    98  	if precompiled {
    99  		module = engine.NewPrecompiledModule(binary)
   100  	} else {
   101  		module = engine.NewModule(binary)
   102  	}
   103  
   104  	wl.counter++
   105  	// Init worker
   106  	worker := &Worker{
   107  		id:          wl.counter,
   108  		engine:      engine,
   109  		compartment: compartment,
   110  		store:       store,
   111  		funcCall:    (*C.void)(wavm.WASMFuncCall),
   112  	}
   113  
   114  	// Module imports
   115  	wl.modImports = wl.modImports[:0]
   116  	wl.modImports = module.Imports(wl.modImports)
   117  	for _, imp := range wl.modImports {
   118  		if trace {
   119  			println("import", imp.ModuleUnsafe(), "name", imp.NameUnsafe())
   120  		}
   121  	}
   122  
   123  	// Module exports
   124  	wl.modExports = wl.modExports[:0]
   125  	wl.modExports = module.Exports(wl.modExports)
   126  	for _, export := range wl.modExports {
   127  		if trace {
   128  			println("export", export.NameUnsafe())
   129  		}
   130  	}
   131  
   132  	var (
   133  		// Create import funcs
   134  		fd_write = compartment.NewFunc(
   135  			wl.fd_write_type,
   136  			moontrade_fd_write(),
   137  			//(wavm.FuncCallback)(C.moontrade_fd_write),
   138  			//"fd_write",
   139  			"",
   140  		)
   141  		clock_time_get = compartment.NewFunc(
   142  			wl.clock_time_get_type,
   143  			moontrade_clock_time_get(),
   144  			//(wavm.FuncCallback)(C.moontrade_clock_time_get),
   145  			//"clock_time_get",
   146  			"",
   147  		)
   148  		args_sizes_get = compartment.NewFunc(
   149  			wl.args_sizes_get_type,
   150  			moontrade_args_sizes_get(),
   151  			//(wavm.FuncCallback)(C.moontrade_args_sizes_get),
   152  			//"args_sizes_get",
   153  			"",
   154  		)
   155  		args_get = compartment.NewFunc(
   156  			wl.args_get_type,
   157  			moontrade_args_get(),
   158  			//(wavm.FuncCallback)(C.moontrade_args_get),
   159  			//"args_get",
   160  			"",
   161  		)
   162  		setTimeout = compartment.NewFunc(
   163  			wl.setTimeout_type,
   164  			moontrade_set_timeout(),
   165  			//(wavm.FuncCallback)(C.moontrade_set_timeout),
   166  			//"setTimeout",
   167  			"",
   168  		)
   169  	)
   170  
   171  	// Instance imports
   172  	wl.imports = wl.imports[:0]
   173  	wl.imports = append(wl.imports, fd_write.AsExtern())
   174  	wl.imports = append(wl.imports, clock_time_get.AsExtern())
   175  	wl.imports = append(wl.imports, args_sizes_get.AsExtern())
   176  	wl.imports = append(wl.imports, args_get.AsExtern())
   177  	wl.imports = append(wl.imports, setTimeout.AsExtern())
   178  
   179  	// New instance
   180  	var trap *wavm.Trap
   181  	instance := store.NewInstanceWithQuota(module, wl.imports, &trap, maxTableElems, maxMemoryPages, false, "")
   182  	// Error?
   183  	if trap != nil {
   184  		// Clean up.
   185  		for _, imp := range wl.imports {
   186  			imp.AsFunc().Delete()
   187  		}
   188  		err := fmt.Errorf(trap.String())
   189  		trap.Delete()
   190  		return nil, err
   191  	}
   192  
   193  	fd_write.Delete()
   194  	clock_time_get.Delete()
   195  	args_sizes_get.Delete()
   196  	args_get.Delete()
   197  	setTimeout.Delete()
   198  	// Cleanup wasm_func_t
   199  	//for _, imp := range wl.imports {
   200  	//	fn := imp.AsFunc()
   201  	//	fn.Delete()
   202  	//	//imp.AsFunc().Delete()
   203  	//}
   204  
   205  	// Instance exports
   206  
   207  	wl.exports = wl.exports[:0]
   208  	wl.exports = instance.Exports(wl.exports)
   209  
   210  	// Assert
   211  	if len(wl.exports) != len(wl.modExports) {
   212  		instance.Delete()
   213  		module.Delete()
   214  		store.Delete()
   215  		compartment.Delete()
   216  		engine.Delete()
   217  		return nil, errors.New("instance exports and module exports don't match")
   218  	}
   219  
   220  	// Map exports to worker
   221  	for i, export := range wl.exports {
   222  		modExport := &wl.modExports[i]
   223  
   224  		switch export.AsKind() {
   225  		case wavm.ExternFunc:
   226  			fn := export.AsFunc()
   227  			if trace {
   228  				println("func export", modExport.NameUnsafe())
   229  			}
   230  			switch modExport.NameUnsafe() {
   231  			case "_start":
   232  				worker.start = fn
   233  			case "resume":
   234  				worker.resume = fn
   235  			case "alloc":
   236  				worker.alloc = fn
   237  			case "realloc":
   238  				worker.realloc = fn
   239  			case "free":
   240  				worker.free = fn
   241  			case "stub":
   242  				worker.stub = fn
   243  			}
   244  		case wavm.ExternTable:
   245  		case wavm.ExternMemory:
   246  			worker.memory = export.AsMemory()
   247  		case wavm.ExternGlobal:
   248  		}
   249  	}
   250  
   251  	// Clean up module
   252  	module.Delete()
   253  	// Clean up instance
   254  	instance.Delete()
   255  
   256  	// Init memory
   257  	data := worker.memory.Data()
   258  	pages := worker.memory.Pages()
   259  	size := worker.memory.Size()
   260  	//C.moontrade_memory_set(memory)
   261  	if trace {
   262  		fmt.Println("data", uintptr(unsafe.Pointer(data)), "pages", uint(pages), "size", size)
   263  	}
   264  
   265  	worker.init()
   266  
   267  	return worker, nil
   268  }