github.com/wasilibs/wazerox@v0.0.0-20240124024944-4923be63ab5f/internal/wasm/module_instance.go (about)

     1  package wasm
     2  
import (
	"context"
	"errors"
	"fmt"
	"sync"

	"github.com/wasilibs/wazerox/api"
	"github.com/wasilibs/wazerox/sys"
)
    11  
    12  // FailIfClosed returns a sys.ExitError if CloseWithExitCode was called.
    13  func (m *ModuleInstance) FailIfClosed() (err error) {
    14  	if closed := m.Closed.Load(); closed != 0 {
    15  		switch closed & exitCodeFlagMask {
    16  		case exitCodeFlagResourceClosed:
    17  		case exitCodeFlagResourceNotClosed:
    18  			// This happens when this module is closed asynchronously in CloseModuleOnCanceledOrTimeout,
    19  			// and the closure of resources have been deferred here.
    20  			_ = m.ensureResourcesClosed(context.Background())
    21  		}
    22  		return sys.NewExitError(uint32(closed >> 32)) // Unpack the high order bits as the exit code.
    23  	}
    24  	return nil
    25  }
    26  
    27  // CloseModuleOnCanceledOrTimeout take a context `ctx`, which might be a Cancel or Timeout context,
    28  // and spawns the Goroutine to check the context is canceled ot deadline exceeded. If it reaches
    29  // one of the conditions, it sets the appropriate exit code.
    30  //
    31  // Callers of this function must invoke the returned context.CancelFunc to release the spawned Goroutine.
    32  func (m *ModuleInstance) CloseModuleOnCanceledOrTimeout(ctx context.Context) context.CancelFunc {
    33  	// Creating an empty channel in this case is a bit more efficient than
    34  	// creating a context.Context and canceling it with the same effect. We
    35  	// really just need to be notified when to stop listening to the users
    36  	// context. Closing the channel will unblock the select in the goroutine
    37  	// causing it to return an stop listening to ctx.Done().
    38  	cancelChan := make(chan struct{})
    39  	go m.closeModuleOnCanceledOrTimeout(ctx, cancelChan)
    40  	return func() { close(cancelChan) }
    41  }
    42  
    43  // closeModuleOnCanceledOrTimeout is extracted from CloseModuleOnCanceledOrTimeout for testing.
    44  func (m *ModuleInstance) closeModuleOnCanceledOrTimeout(ctx context.Context, cancelChan <-chan struct{}) {
    45  	select {
    46  	case <-ctx.Done():
    47  		select {
    48  		case <-cancelChan:
    49  			// In some cases by the time this goroutine is scheduled, the caller
    50  			// has already closed both the context and the cancelChan. In this
    51  			// case go will randomize which branch of the outer select to enter
    52  			// and we don't want to close the module.
    53  		default:
    54  			// This is the same logic as CloseWithCtxErr except this calls closeWithExitCodeWithoutClosingResource
    55  			// so that we can defer the resource closure in FailIfClosed.
    56  			switch {
    57  			case errors.Is(ctx.Err(), context.Canceled):
    58  				// TODO: figure out how to report error here.
    59  				_ = m.closeWithExitCodeWithoutClosingResource(sys.ExitCodeContextCanceled)
    60  			case errors.Is(ctx.Err(), context.DeadlineExceeded):
    61  				// TODO: figure out how to report error here.
    62  				_ = m.closeWithExitCodeWithoutClosingResource(sys.ExitCodeDeadlineExceeded)
    63  			}
    64  		}
    65  	case <-cancelChan:
    66  	}
    67  }
    68  
    69  // CloseWithCtxErr closes the module with an exit code based on the type of
    70  // error reported by the context.
    71  //
    72  // If the context's error is unknown or nil, the module does not close.
    73  func (m *ModuleInstance) CloseWithCtxErr(ctx context.Context) {
    74  	switch {
    75  	case errors.Is(ctx.Err(), context.Canceled):
    76  		// TODO: figure out how to report error here.
    77  		_ = m.CloseWithExitCode(ctx, sys.ExitCodeContextCanceled)
    78  	case errors.Is(ctx.Err(), context.DeadlineExceeded):
    79  		// TODO: figure out how to report error here.
    80  		_ = m.CloseWithExitCode(ctx, sys.ExitCodeDeadlineExceeded)
    81  	}
    82  }
    83  
    84  // Name implements the same method as documented on api.Module
    85  func (m *ModuleInstance) Name() string {
    86  	return m.ModuleName
    87  }
    88  
    89  // String implements the same method as documented on api.Module
    90  func (m *ModuleInstance) String() string {
    91  	return fmt.Sprintf("Module[%s]", m.Name())
    92  }
    93  
    94  // Close implements the same method as documented on api.Module.
    95  func (m *ModuleInstance) Close(ctx context.Context) (err error) {
    96  	return m.CloseWithExitCode(ctx, 0)
    97  }
    98  
    99  // CloseWithExitCode implements the same method as documented on api.Module.
   100  func (m *ModuleInstance) CloseWithExitCode(ctx context.Context, exitCode uint32) (err error) {
   101  	if !m.setExitCode(exitCode, exitCodeFlagResourceClosed) {
   102  		return nil // not an error to have already closed
   103  	}
   104  	_ = m.s.deleteModule(m)
   105  	return m.ensureResourcesClosed(ctx)
   106  }
   107  
   108  // IsClosed implements the same method as documented on api.Module.
   109  func (m *ModuleInstance) IsClosed() bool {
   110  	return m.Closed.Load() != 0
   111  }
   112  
   113  func (m *ModuleInstance) closeWithExitCodeWithoutClosingResource(exitCode uint32) (err error) {
   114  	if !m.setExitCode(exitCode, exitCodeFlagResourceNotClosed) {
   115  		return nil // not an error to have already closed
   116  	}
   117  	_ = m.s.deleteModule(m)
   118  	return nil
   119  }
   120  
   121  // closeWithExitCode is the same as CloseWithExitCode besides this doesn't delete it from Store.moduleList.
   122  func (m *ModuleInstance) closeWithExitCode(ctx context.Context, exitCode uint32) (err error) {
   123  	if !m.setExitCode(exitCode, exitCodeFlagResourceClosed) {
   124  		return nil // not an error to have already closed
   125  	}
   126  	// TODO(anuraaga): A shared memory probably needs to be closed somewhere other than the module level
   127  	if m.MemoryInstance != nil {
   128  		if err := m.MemoryInstance.Close(); err != nil {
   129  			return err
   130  		}
   131  	}
   132  	return m.ensureResourcesClosed(ctx)
   133  }
   134  
   135  type exitCodeFlag = uint64
   136  
   137  const exitCodeFlagMask = 0xff
   138  
   139  const (
   140  	// exitCodeFlagResourceClosed indicates that the module was closed and resources were already closed.
   141  	exitCodeFlagResourceClosed = 1 << iota
   142  	// exitCodeFlagResourceNotClosed indicates that the module was closed while resources are not closed yet.
   143  	exitCodeFlagResourceNotClosed
   144  )
   145  
   146  func (m *ModuleInstance) setExitCode(exitCode uint32, flag exitCodeFlag) bool {
   147  	closed := flag | uint64(exitCode)<<32 // Store exitCode as high-order bits.
   148  	return m.Closed.CompareAndSwap(0, closed)
   149  }
   150  
   151  // ensureResourcesClosed ensures that resources assigned to ModuleInstance is released.
   152  // Only one call will happen per module, due to external atomic guards on Closed.
   153  func (m *ModuleInstance) ensureResourcesClosed(ctx context.Context) (err error) {
   154  	if closeNotifier := m.CloseNotifier; closeNotifier != nil { // experimental
   155  		closeNotifier.CloseNotify(ctx, uint32(m.Closed.Load()>>32))
   156  		m.CloseNotifier = nil
   157  	}
   158  
   159  	if sysCtx := m.Sys; sysCtx != nil { // nil if from HostModuleBuilder
   160  		if err = sysCtx.FS().Close(); err != nil {
   161  			return err
   162  		}
   163  		m.Sys = nil
   164  	}
   165  
   166  	if m.CodeCloser == nil {
   167  		return
   168  	}
   169  	if e := m.CodeCloser.Close(ctx); e != nil && err == nil {
   170  		err = e
   171  	}
   172  	m.CodeCloser = nil
   173  	return
   174  }
   175  
   176  // Memory implements the same method as documented on api.Module.
   177  func (m *ModuleInstance) Memory() api.Memory {
   178  	return m.MemoryInstance
   179  }
   180  
   181  // ExportedMemory implements the same method as documented on api.Module.
   182  func (m *ModuleInstance) ExportedMemory(name string) api.Memory {
   183  	_, err := m.getExport(name, ExternTypeMemory)
   184  	if err != nil {
   185  		return nil
   186  	}
   187  	// We Assume that we have at most one memory.
   188  	return m.MemoryInstance
   189  }
   190  
   191  // ExportedMemoryDefinitions implements the same method as documented on
   192  // api.Module.
   193  func (m *ModuleInstance) ExportedMemoryDefinitions() map[string]api.MemoryDefinition {
   194  	// Special case as we currently only support one memory.
   195  	if mem := m.MemoryInstance; mem != nil {
   196  		// Now, find out if it is exported
   197  		for name, exp := range m.Exports {
   198  			if exp.Type == ExternTypeMemory {
   199  				return map[string]api.MemoryDefinition{name: mem.definition}
   200  			}
   201  		}
   202  	}
   203  	return map[string]api.MemoryDefinition{}
   204  }
   205  
   206  // ExportedFunction implements the same method as documented on api.Module.
   207  func (m *ModuleInstance) ExportedFunction(name string) api.Function {
   208  	exp, err := m.getExport(name, ExternTypeFunc)
   209  	if err != nil {
   210  		return nil
   211  	}
   212  	return m.Engine.NewFunction(exp.Index)
   213  }
   214  
   215  // ExportedFunctionDefinitions implements the same method as documented on
   216  // api.Module.
   217  func (m *ModuleInstance) ExportedFunctionDefinitions() map[string]api.FunctionDefinition {
   218  	result := map[string]api.FunctionDefinition{}
   219  	for name, exp := range m.Exports {
   220  		if exp.Type == ExternTypeFunc {
   221  			result[name] = m.Source.FunctionDefinition(exp.Index)
   222  		}
   223  	}
   224  	return result
   225  }
   226  
   227  // GlobalVal is an internal hack to get the lower 64 bits of a global.
   228  func (m *ModuleInstance) GlobalVal(idx Index) uint64 {
   229  	return m.Globals[idx].Val
   230  }
   231  
   232  // ExportedGlobal implements the same method as documented on api.Module.
   233  func (m *ModuleInstance) ExportedGlobal(name string) api.Global {
   234  	exp, err := m.getExport(name, ExternTypeGlobal)
   235  	if err != nil {
   236  		return nil
   237  	}
   238  	g := m.Globals[exp.Index]
   239  	if g.Type.Mutable {
   240  		return mutableGlobal{g: g}
   241  	}
   242  	return constantGlobal{g: g}
   243  }
   244  
   245  // NumGlobal implements experimental.InternalModule.
   246  func (m *ModuleInstance) NumGlobal() int {
   247  	return len(m.Globals)
   248  }
   249  
   250  // Global implements experimental.InternalModule.
   251  func (m *ModuleInstance) Global(idx int) api.Global {
   252  	return constantGlobal{g: m.Globals[idx]}
   253  }