wa-lang.org/wazero@v1.0.2/internal/wasm/host.go (about)

     1  package wasm
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"sort"
     7  
     8  	"wa-lang.org/wazero/api"
     9  	"wa-lang.org/wazero/internal/wasmdebug"
    10  )
    11  
    12  type ProxyFuncExporter interface {
    13  	ExportProxyFunc(*ProxyFunc)
    14  }
    15  
    16  // ProxyFunc is a function defined both in wasm and go. This is used to
    17  // optimize the Go signature or obviate calls based on what can be done
    18  // mechanically in wasm.
    19  type ProxyFunc struct {
    20  	// Proxy must be a wasm func
    21  	Proxy *HostFunc
    22  	// Proxied should be a go func.
    23  	Proxied *HostFunc
    24  
    25  	// CallBodyPos is the position in Code.Body of the caller to replace the
    26  	// real funcIdx of the proxied.
    27  	CallBodyPos int
    28  }
    29  
    30  func (p *ProxyFunc) Name() string {
    31  	return p.Proxied.Name
    32  }
    33  
    34  type HostFuncExporter interface {
    35  	ExportHostFunc(*HostFunc)
    36  }
    37  
    38  // HostFunc is a function with an inlined type, used for NewHostModule.
    39  // Any corresponding FunctionType will be reused or added to the Module.
    40  type HostFunc struct {
    41  	// ExportNames is equivalent to  the same method on api.FunctionDefinition.
    42  	ExportNames []string
    43  
    44  	// Name is equivalent to  the same method on api.FunctionDefinition.
    45  	Name string
    46  
    47  	// ParamTypes is equivalent to  the same method on api.FunctionDefinition.
    48  	ParamTypes []ValueType
    49  
    50  	// ParamNames is equivalent to  the same method on api.FunctionDefinition.
    51  	ParamNames []string
    52  
    53  	// ResultTypes is equivalent to  the same method on api.FunctionDefinition.
    54  	ResultTypes []ValueType
    55  
    56  	// Code is the equivalent function in the SectionIDCode.
    57  	Code *Code
    58  }
    59  
    60  // MustGoReflectFunc calls WithGoReflectFunc or panics on error.
    61  func (f *HostFunc) MustGoReflectFunc(fn interface{}) *HostFunc {
    62  	if ret, err := f.WithGoReflectFunc(fn); err != nil {
    63  		panic(err)
    64  	} else {
    65  		return ret
    66  	}
    67  }
    68  
    69  // WithGoFunc returns a copy of the function, replacing its Code.GoFunc.
    70  func (f *HostFunc) WithGoFunc(fn api.GoFunc) *HostFunc {
    71  	ret := *f
    72  	ret.Code = &Code{IsHostFunction: true, GoFunc: fn}
    73  	return &ret
    74  }
    75  
    76  // WithGoModuleFunc returns a copy of the function, replacing its Code.GoFunc.
    77  func (f *HostFunc) WithGoModuleFunc(fn api.GoModuleFunc) *HostFunc {
    78  	ret := *f
    79  	ret.Code = &Code{IsHostFunction: true, GoFunc: fn}
    80  	return &ret
    81  }
    82  
    83  // WithGoReflectFunc returns a copy of the function, replacing its Code.GoFunc.
    84  func (f *HostFunc) WithGoReflectFunc(fn interface{}) (*HostFunc, error) {
    85  	ret := *f
    86  	var err error
    87  	ret.ParamTypes, ret.ResultTypes, ret.Code, err = parseGoReflectFunc(fn)
    88  	return &ret, err
    89  }
    90  
    91  // WithWasm returns a copy of the function, replacing its Code.Body.
    92  func (f *HostFunc) WithWasm(body []byte) *HostFunc {
    93  	ret := *f
    94  	ret.Code = &Code{IsHostFunction: true, Body: body}
    95  	if f.Code != nil {
    96  		ret.Code.LocalTypes = f.Code.LocalTypes
    97  	}
    98  	return &ret
    99  }
   100  
   101  // NewHostModule is defined internally for use in WASI tests and to keep the code size in the root directory small.
   102  func NewHostModule(
   103  	moduleName string,
   104  	nameToGoFunc map[string]interface{},
   105  	funcToNames map[string][]string,
   106  	enabledFeatures api.CoreFeatures,
   107  ) (m *Module, err error) {
   108  	if moduleName != "" {
   109  		m = &Module{NameSection: &NameSection{ModuleName: moduleName}}
   110  	} else {
   111  		m = &Module{}
   112  	}
   113  
   114  	if exportCount := uint32(len(nameToGoFunc)); exportCount > 0 {
   115  		m.ExportSection = make([]*Export, 0, exportCount)
   116  		if err = addFuncs(m, nameToGoFunc, funcToNames, enabledFeatures); err != nil {
   117  			return
   118  		}
   119  	}
   120  
   121  	// Assigns the ModuleID by calculating sha256 on inputs as host modules do not have `wasm` to hash.
   122  	m.AssignModuleID([]byte(fmt.Sprintf("%s:%v:%v", moduleName, nameToGoFunc, enabledFeatures)))
   123  	m.BuildFunctionDefinitions()
   124  	return
   125  }
   126  
   127  // maxProxiedFuncIdx is the maximum index where leb128 encoding matches the bit
   128  // of an unsigned literal byte. Using this simplifies host function index
   129  // substitution.
   130  //
   131  // Note: this is 127, not 255 because when the MSB is set, leb128 encoding
   132  // doesn't match the literal byte.
   133  const maxProxiedFuncIdx = 127
   134  
   135  func addFuncs(
   136  	m *Module,
   137  	nameToGoFunc map[string]interface{},
   138  	funcToNames map[string][]string,
   139  	enabledFeatures api.CoreFeatures,
   140  ) (err error) {
   141  	if m.NameSection == nil {
   142  		m.NameSection = &NameSection{}
   143  	}
   144  	moduleName := m.NameSection.ModuleName
   145  	nameToFunc := make(map[string]*HostFunc, len(nameToGoFunc))
   146  	sortedExportNames := make([]string, len(nameToFunc))
   147  	for k := range nameToGoFunc {
   148  		sortedExportNames = append(sortedExportNames, k)
   149  	}
   150  
   151  	// Sort names for consistent iteration
   152  	sort.Strings(sortedExportNames)
   153  
   154  	funcNames := make([]string, len(nameToFunc))
   155  	for _, k := range sortedExportNames {
   156  		v := nameToGoFunc[k]
   157  		if hf, ok := v.(*HostFunc); ok {
   158  			nameToFunc[hf.Name] = hf
   159  			funcNames = append(funcNames, hf.Name)
   160  		} else if pf, ok := v.(*ProxyFunc); ok {
   161  			// First, add the proxied function which also gives us the real
   162  			// position in the function index namespace, We will need this
   163  			// later. We've kept code simpler by limiting the max index to
   164  			// what is encodable in a single byte. This is ok as we don't have
   165  			// any current use cases for hundreds of proxy functions.
   166  			proxiedIdx := len(funcNames)
   167  			if proxiedIdx > maxProxiedFuncIdx {
   168  				return errors.New("TODO: proxied funcidx larger than one byte")
   169  			}
   170  			nameToFunc[pf.Proxied.Name] = pf.Proxied
   171  			funcNames = append(funcNames, pf.Proxied.Name)
   172  
   173  			// Now that we have the real index of the proxied function,
   174  			// substitute that for the zero placeholder in the proxy's code
   175  			// body. This placeholder is at index CallBodyPos in the slice.
   176  			proxyBody := make([]byte, len(pf.Proxy.Code.Body))
   177  			copy(proxyBody, pf.Proxy.Code.Body)
   178  			proxyBody[pf.CallBodyPos] = byte(proxiedIdx)
   179  			proxy := pf.Proxy.WithWasm(proxyBody)
   180  
   181  			nameToFunc[proxy.Name] = proxy
   182  			funcNames = append(funcNames, proxy.Name)
   183  		} else { // reflection
   184  			params, results, code, ftErr := parseGoReflectFunc(v)
   185  			if ftErr != nil {
   186  				return fmt.Errorf("func[%s.%s] %w", moduleName, k, ftErr)
   187  			}
   188  			hf = &HostFunc{
   189  				ExportNames: []string{k},
   190  				Name:        k,
   191  				ParamTypes:  params,
   192  				ResultTypes: results,
   193  				Code:        code,
   194  			}
   195  			if names := funcToNames[k]; names != nil {
   196  				namesLen := len(names)
   197  				if namesLen > 1 && namesLen-1 != len(params) {
   198  					return fmt.Errorf("func[%s.%s] has %d params, but %d param names", moduleName, k, namesLen-1, len(params))
   199  				}
   200  				hf.Name = names[0]
   201  				hf.ParamNames = names[1:]
   202  			}
   203  			nameToFunc[k] = hf
   204  			funcNames = append(funcNames, k)
   205  		}
   206  	}
   207  
   208  	funcCount := uint32(len(nameToFunc))
   209  	m.NameSection.FunctionNames = make([]*NameAssoc, 0, funcCount)
   210  	m.FunctionSection = make([]Index, 0, funcCount)
   211  	m.CodeSection = make([]*Code, 0, funcCount)
   212  	m.FunctionDefinitionSection = make([]*FunctionDefinition, 0, funcCount)
   213  
   214  	idx := Index(0)
   215  	for _, name := range funcNames {
   216  		hf := nameToFunc[name]
   217  		debugName := wasmdebug.FuncName(moduleName, name, idx)
   218  		typeIdx, typeErr := m.maybeAddType(hf.ParamTypes, hf.ResultTypes, enabledFeatures)
   219  		if typeErr != nil {
   220  			return fmt.Errorf("func[%s] %v", debugName, typeErr)
   221  		}
   222  		m.FunctionSection = append(m.FunctionSection, typeIdx)
   223  		m.CodeSection = append(m.CodeSection, hf.Code)
   224  		for _, export := range hf.ExportNames {
   225  			m.ExportSection = append(m.ExportSection, &Export{Type: ExternTypeFunc, Name: export, Index: idx})
   226  		}
   227  		m.NameSection.FunctionNames = append(m.NameSection.FunctionNames, &NameAssoc{Index: idx, Name: hf.Name})
   228  		if len(hf.ParamNames) > 0 {
   229  			localNames := &NameMapAssoc{Index: idx}
   230  			for i, n := range hf.ParamNames {
   231  				localNames.NameMap = append(localNames.NameMap, &NameAssoc{Index: Index(i), Name: n})
   232  			}
   233  			m.NameSection.LocalNames = append(m.NameSection.LocalNames, localNames)
   234  		}
   235  		idx++
   236  	}
   237  	return nil
   238  }
   239  
   240  func (m *Module) maybeAddType(params, results []ValueType, enabledFeatures api.CoreFeatures) (Index, error) {
   241  	if len(results) > 1 {
   242  		// Guard >1.0 feature multi-value
   243  		if err := enabledFeatures.RequireEnabled(api.CoreFeatureMultiValue); err != nil {
   244  			return 0, fmt.Errorf("multiple result types invalid as %v", err)
   245  		}
   246  	}
   247  	for i, t := range m.TypeSection {
   248  		if t.EqualsSignature(params, results) {
   249  			return Index(i), nil
   250  		}
   251  	}
   252  
   253  	result := m.SectionElementCount(SectionIDType)
   254  	toAdd := &FunctionType{Params: params, Results: results}
   255  	m.TypeSection = append(m.TypeSection, toAdd)
   256  	return result, nil
   257  }