wa-lang.org/wazero@v1.0.2/imports/proxywasm/_proxytest/root.go (about)

     1  // Copyright 2020-2021 Tetrate
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  // http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package proxytest
    16  
    17  import (
    18  	"fmt"
    19  	"log"
    20  	"strings"
    21  
    22  	"wa-lang.org/wazero/imports/proxywasm/internal"
    23  	"wa-lang.org/wazero/imports/proxywasm/types"
    24  )
    25  
type (
	// rootHostEmulator emulates the root-context (plugin-wide) side of a
	// proxy-wasm host for tests. It records logs and the tick period, and
	// stores shared queues, shared data, metrics, HTTP callouts, foreign
	// functions and the plugin/VM configuration bytes.
	rootHostEmulator struct {
		// activeCalloutID is the callout whose response is currently being
		// dispatched by CallOnHttpCallResponse; the response accessors read it.
		activeCalloutID  uint32
		// logs holds messages passed to ProxyLog, indexed by log level.
		logs             [internal.LogLevelMax][]string
		// tickPeriod is the last value passed to ProxySetTickPeriodMilliseconds.
		tickPeriod       uint32
		// foreignFunctions maps a function name to the handler that
		// ProxyCallForeignFunction invokes.
		foreignFunctions map[string]func([]byte) []byte

		// Shared queues: queue ID -> pending items, and queue name -> queue ID.
		queues        map[uint32][][]byte
		queueNameID   map[string]uint32
		// sharedDataKVS holds shared-data entries (value and CAS) by key.
		sharedDataKVS map[string]*sharedData

		httpContextIDToCalloutInfos map[uint32][]HttpCalloutAttribute // key: contextID
		httpCalloutIDToContextID    map[uint32]uint32                 // key: calloutID
		// httpCalloutResponse holds a response injected by CallOnHttpCallResponse
		// while it is being dispatched. NOTE: an identical anonymous struct type
		// is spelled out where this map is created and written; keep them in sync.
		httpCalloutResponse         map[uint32]struct {               // key: calloutID
			headers  [][2]string
			trailers [][2]string
			body     []byte
		}

		// Metrics, keyed by the ID assigned in ProxyDefineMetric.
		metricIDToType  map[uint32]internal.MetricType
		metricNameToID  map[string]uint32
		metricIDToValue map[uint32]uint64

		// Raw configuration buffers. The buffer accessor serves them, and
		// StartVM/StartPlugin pass their lengths to the SDK.
		pluginConfiguration, vmConfiguration []byte
	}

	// HttpCalloutAttribute records one outgoing HTTP callout made through
	// ProxyHttpCall.
	HttpCalloutAttribute struct {
		CalloutID uint32
		Upstream  string
		Headers   [][2]string
		Trailers  [][2]string
		Body      []byte
	}

	// sharedData is a single shared-data entry: the stored value and its
	// current CAS token.
	sharedData struct {
		data []byte
		cas  uint32
	}
)
    65  
    66  func newRootHostEmulator(pluginConfiguration, vmConfiguration []byte) *rootHostEmulator {
    67  	host := &rootHostEmulator{
    68  		foreignFunctions:            map[string]func([]byte) []byte{},
    69  		queues:                      map[uint32][][]byte{},
    70  		queueNameID:                 map[string]uint32{},
    71  		sharedDataKVS:               map[string]*sharedData{},
    72  		metricIDToValue:             map[uint32]uint64{},
    73  		metricIDToType:              map[uint32]internal.MetricType{},
    74  		metricNameToID:              map[string]uint32{},
    75  		httpContextIDToCalloutInfos: map[uint32][]HttpCalloutAttribute{},
    76  		httpCalloutIDToContextID:    map[uint32]uint32{},
    77  		httpCalloutResponse: map[uint32]struct {
    78  			headers  [][2]string
    79  			trailers [][2]string
    80  			body     []byte
    81  		}{},
    82  
    83  		pluginConfiguration: pluginConfiguration,
    84  		vmConfiguration:     vmConfiguration,
    85  	}
    86  	return host
    87  }
    88  
    89  // impl internal.ProxyWasmHost
    90  func (r *rootHostEmulator) ProxyLog(logLevel internal.LogLevel, messageData *byte, messageSize int) internal.Status {
    91  	str := internal.RawBytePtrToString(messageData, messageSize)
    92  
    93  	log.Printf("proxy_%s_log: %s", logLevel, str)
    94  	r.logs[logLevel] = append(r.logs[logLevel], str)
    95  	return internal.StatusOK
    96  }
    97  
    98  // impl internal.ProxyWasmHost
    99  func (r *rootHostEmulator) ProxySetTickPeriodMilliseconds(period uint32) internal.Status {
   100  	r.tickPeriod = period
   101  	return internal.StatusOK
   102  }
   103  
   104  // impl internal.ProxyWasmHost
   105  func (r *rootHostEmulator) ProxyRegisterSharedQueue(nameData *byte, nameSize int, returnID *uint32) internal.Status {
   106  	name := internal.RawBytePtrToString(nameData, nameSize)
   107  	if id, ok := r.queueNameID[name]; ok {
   108  		*returnID = id
   109  		return internal.StatusOK
   110  	}
   111  
   112  	id := uint32(len(r.queues))
   113  	r.queues[id] = [][]byte{}
   114  	r.queueNameID[name] = id
   115  	*returnID = id
   116  	return internal.StatusOK
   117  }
   118  
   119  // impl internal.ProxyWasmHost
   120  func (r *rootHostEmulator) ProxyDequeueSharedQueue(queueID uint32, returnValueData **byte, returnValueSize *int) internal.Status {
   121  	queue, ok := r.queues[queueID]
   122  	if !ok {
   123  		log.Printf("queue %d is not found", queueID)
   124  		return internal.StatusNotFound
   125  	} else if len(queue) == 0 {
   126  		log.Printf("queue %d is empty", queueID)
   127  		return internal.StatusEmpty
   128  	}
   129  
   130  	data := queue[0]
   131  	*returnValueData = &data[0]
   132  	*returnValueSize = len(data)
   133  	r.queues[queueID] = queue[1:]
   134  	return internal.StatusOK
   135  }
   136  
   137  // impl internal.ProxyWasmHost
   138  func (r *rootHostEmulator) ProxyEnqueueSharedQueue(queueID uint32, valueData *byte, valueSize int) internal.Status {
   139  	queue, ok := r.queues[queueID]
   140  	if !ok {
   141  		log.Printf("queue %d is not found", queueID)
   142  		return internal.StatusNotFound
   143  	}
   144  
   145  	r.queues[queueID] = append(queue, internal.RawBytePtrToByteSlice(valueData, valueSize))
   146  	internal.ProxyOnQueueReady(PluginContextID, queueID)
   147  	return internal.StatusOK
   148  }
   149  
   150  // impl internal.ProxyWasmHost
   151  func (r *rootHostEmulator) ProxyGetSharedData(keyData *byte, keySize int,
   152  	returnValueData **byte, returnValueSize *int, returnCas *uint32) internal.Status {
   153  	key := internal.RawBytePtrToString(keyData, keySize)
   154  
   155  	value, ok := r.sharedDataKVS[key]
   156  	if !ok {
   157  		return internal.StatusNotFound
   158  	}
   159  
   160  	*returnValueSize = len(value.data)
   161  	if len(value.data) > 0 {
   162  		*returnValueData = &value.data[0]
   163  	}
   164  	*returnCas = value.cas
   165  	return internal.StatusOK
   166  }
   167  
   168  // impl internal.ProxyWasmHost
   169  func (r *rootHostEmulator) ProxySetSharedData(keyData *byte, keySize int,
   170  	valueData *byte, valueSize int, cas uint32) internal.Status {
   171  	// Copy data provided by plugin to keep ownership within host. Otherwise, when
   172  	// plugin deallocates the memory could be modified.
   173  	key := strings.Clone(internal.RawBytePtrToString(keyData, keySize))
   174  	v := internal.RawBytePtrToByteSlice(valueData, valueSize)
   175  	value := make([]byte, len(v))
   176  	copy(value, v)
   177  
   178  	prev, ok := r.sharedDataKVS[key]
   179  	if !ok {
   180  		r.sharedDataKVS[key] = &sharedData{
   181  			data: value,
   182  			cas:  cas + 1,
   183  		}
   184  		return internal.StatusOK
   185  	}
   186  
   187  	if prev.cas != cas {
   188  		return internal.StatusCasMismatch
   189  	}
   190  
   191  	r.sharedDataKVS[key].cas = cas + 1
   192  	r.sharedDataKVS[key].data = value
   193  	return internal.StatusOK
   194  }
   195  
   196  // impl internal.ProxyWasmHost
   197  func (r *rootHostEmulator) ProxyDefineMetric(metricType internal.MetricType,
   198  	metricNameData *byte, metricNameSize int, returnMetricIDPtr *uint32) internal.Status {
   199  	name := internal.RawBytePtrToString(metricNameData, metricNameSize)
   200  	id, ok := r.metricNameToID[name]
   201  	if !ok {
   202  		id = uint32(len(r.metricNameToID))
   203  		r.metricNameToID[name] = id
   204  		r.metricIDToValue[id] = 0
   205  		r.metricIDToType[id] = metricType
   206  	}
   207  	*returnMetricIDPtr = id
   208  	return internal.StatusOK
   209  }
   210  
   211  // impl internal.ProxyWasmHost
   212  func (r *rootHostEmulator) ProxyIncrementMetric(metricID uint32, offset int64) internal.Status {
   213  	val, ok := r.metricIDToValue[metricID]
   214  	if !ok {
   215  		return internal.StatusBadArgument
   216  	}
   217  
   218  	r.metricIDToValue[metricID] = val + uint64(offset)
   219  	return internal.StatusOK
   220  }
   221  
   222  // impl internal.ProxyWasmHost
   223  func (r *rootHostEmulator) ProxyRecordMetric(metricID uint32, value uint64) internal.Status {
   224  	_, ok := r.metricIDToValue[metricID]
   225  	if !ok {
   226  		return internal.StatusBadArgument
   227  	}
   228  	r.metricIDToValue[metricID] = value
   229  	return internal.StatusOK
   230  }
   231  
   232  // impl internal.ProxyWasmHost
   233  func (r *rootHostEmulator) ProxyGetMetric(metricID uint32, returnMetricValue *uint64) internal.Status {
   234  	value, ok := r.metricIDToValue[metricID]
   235  	if !ok {
   236  		return internal.StatusBadArgument
   237  	}
   238  	*returnMetricValue = value
   239  	return internal.StatusOK
   240  }
   241  
   242  // impl internal.ProxyWasmHost
   243  func (r *rootHostEmulator) ProxyHttpCall(upstreamData *byte, upstreamSize int, headerData *byte, headerSize int, bodyData *byte,
   244  	bodySize int, trailersData *byte, trailersSize int, timeout uint32, calloutIDPtr *uint32) internal.Status {
   245  	upstream := internal.RawBytePtrToString(upstreamData, upstreamSize)
   246  	body := internal.RawBytePtrToString(bodyData, bodySize)
   247  	headers := deserializeRawBytePtrToMap(headerData, headerSize)
   248  	trailers := deserializeRawBytePtrToMap(trailersData, trailersSize)
   249  
   250  	log.Printf("[http callout to %s] timeout: %d", upstream, timeout)
   251  	log.Printf("[http callout to %s] headers: %v", upstream, headers)
   252  	log.Printf("[http callout to %s] body: %s", upstream, body)
   253  	log.Printf("[http callout to %s] trailers: %v", upstream, trailers)
   254  
   255  	calloutID := uint32(len(r.httpCalloutIDToContextID))
   256  	contextID := internal.VMStateGetActiveContextID()
   257  	r.httpCalloutIDToContextID[calloutID] = contextID
   258  	r.httpContextIDToCalloutInfos[contextID] = append(r.httpContextIDToCalloutInfos[contextID], HttpCalloutAttribute{
   259  		CalloutID: calloutID,
   260  		Upstream:  upstream,
   261  		Headers:   headers,
   262  		Trailers:  trailers,
   263  		Body:      []byte(body),
   264  	})
   265  
   266  	*calloutIDPtr = calloutID
   267  	return internal.StatusOK
   268  }
   269  
   270  // impl internal.ProxyWasmHost
   271  func (r *rootHostEmulator) RegisterForeignFunction(name string, f func([]byte) []byte) {
   272  	r.foreignFunctions[name] = f
   273  }
   274  
   275  // impl internal.ProxyWasmHost
   276  func (r *rootHostEmulator) ProxyCallForeignFunction(funcNamePtr *byte, funcNameSize int, paramPtr *byte, paramSize int, returnData **byte, returnSize *int) internal.Status {
   277  	funcName := internal.RawBytePtrToString(funcNamePtr, funcNameSize)
   278  	param := internal.RawBytePtrToByteSlice(paramPtr, paramSize)
   279  
   280  	log.Printf("[foreign call] funcname: %s", funcName)
   281  	log.Printf("[foreign call] param: %s", param)
   282  
   283  	f, ok := r.foreignFunctions[funcName]
   284  	if !ok {
   285  		log.Fatalf("%s not registered as a foreign function", funcName)
   286  	}
   287  	ret := f(param)
   288  	*returnData = &ret[0]
   289  	*returnSize = len(ret)
   290  
   291  	return internal.StatusOK
   292  }
   293  
   294  // // impl internal.ProxyWasmHost: delegated from hostEmulator
   295  func (r *rootHostEmulator) rootHostEmulatorProxyGetHeaderMapPairs(mapType internal.MapType, returnValueData **byte, returnValueSize *int) internal.Status {
   296  	res, ok := r.httpCalloutResponse[r.activeCalloutID]
   297  	if !ok {
   298  		log.Fatalf("callout response unregistered for %d", r.activeCalloutID)
   299  	}
   300  
   301  	var raw []byte
   302  	switch mapType {
   303  	case internal.MapTypeHttpCallResponseHeaders:
   304  		raw = internal.SerializeMap(res.headers)
   305  	case internal.MapTypeHttpCallResponseTrailers:
   306  		raw = internal.SerializeMap(res.trailers)
   307  	default:
   308  		panic("unreachable: maybe a bug in this host emulation or SDK")
   309  	}
   310  
   311  	*returnValueData = &raw[0]
   312  	*returnValueSize = len(raw)
   313  	return internal.StatusOK
   314  }
   315  
   316  // // impl internal.ProxyWasmHost: delegated from hostEmulator
   317  func (r *rootHostEmulator) rootHostEmulatorProxyGetMapValue(mapType internal.MapType, keyData *byte,
   318  	keySize int, returnValueData **byte, returnValueSize *int) internal.Status {
   319  	res, ok := r.httpCalloutResponse[r.activeCalloutID]
   320  	if !ok {
   321  		log.Fatalf("callout response unregistered for %d", r.activeCalloutID)
   322  	}
   323  
   324  	var hs [][2]string
   325  	switch mapType {
   326  	case internal.MapTypeHttpCallResponseHeaders:
   327  		hs = res.headers
   328  	case internal.MapTypeHttpCallResponseTrailers:
   329  		hs = res.trailers
   330  	default:
   331  		panic("unimplemented")
   332  	}
   333  
   334  	key := strings.ToLower(internal.RawBytePtrToString(keyData, keySize))
   335  
   336  	for _, h := range hs {
   337  		if h[0] == key {
   338  			v := []byte(h[1])
   339  			*returnValueData = &v[0]
   340  			*returnValueSize = len(v)
   341  			return internal.StatusOK
   342  		}
   343  	}
   344  
   345  	return internal.StatusNotFound
   346  }
   347  
   348  // // impl internal.ProxyWasmHost: delegated from hostEmulator
   349  func (r *rootHostEmulator) rootHostEmulatorProxyGetBufferBytes(bt internal.BufferType, start int, maxSize int,
   350  	returnBufferData **byte, returnBufferSize *int) internal.Status {
   351  	var buf []byte
   352  	switch bt {
   353  	case internal.BufferTypePluginConfiguration:
   354  		buf = r.pluginConfiguration
   355  	case internal.BufferTypeVMConfiguration:
   356  		buf = r.vmConfiguration
   357  	case internal.BufferTypeHttpCallResponseBody:
   358  		activeID := internal.VMStateGetActiveContextID()
   359  		res, ok := r.httpCalloutResponse[r.activeCalloutID]
   360  		if !ok {
   361  			log.Fatalf("callout response unregistered for %d", activeID)
   362  		}
   363  		buf = res.body
   364  	default:
   365  		panic("unreachable: maybe a bug in this host emulation or SDK")
   366  	}
   367  
   368  	if len(buf) == 0 {
   369  		return internal.StatusNotFound
   370  	} else if start >= len(buf) {
   371  		log.Printf("start index out of range: %d (start) >= %d ", start, len(buf))
   372  		return internal.StatusBadArgument
   373  	}
   374  
   375  	*returnBufferData = &buf[start]
   376  	if maxSize > len(buf)-start {
   377  		*returnBufferSize = len(buf) - start
   378  	} else {
   379  		*returnBufferSize = maxSize
   380  	}
   381  	return internal.StatusOK
   382  }
   383  
   384  // impl HostEmulator
   385  func (r *rootHostEmulator) GetTraceLogs() []string {
   386  	return r.getLogs(internal.LogLevelTrace)
   387  }
   388  
   389  // impl HostEmulator
   390  func (r *rootHostEmulator) GetDebugLogs() []string {
   391  	return r.getLogs(internal.LogLevelDebug)
   392  }
   393  
   394  // impl HostEmulator
   395  func (r *rootHostEmulator) GetInfoLogs() []string {
   396  	return r.getLogs(internal.LogLevelInfo)
   397  }
   398  
   399  // impl HostEmulator
   400  func (r *rootHostEmulator) GetWarnLogs() []string {
   401  	return r.getLogs(internal.LogLevelWarn)
   402  }
   403  
   404  // impl HostEmulator
   405  func (r *rootHostEmulator) GetErrorLogs() []string {
   406  	return r.getLogs(internal.LogLevelError)
   407  }
   408  
   409  // impl HostEmulator
   410  func (r *rootHostEmulator) GetCriticalLogs() []string {
   411  	return r.getLogs(internal.LogLevelCritical)
   412  }
   413  
   414  func (r *rootHostEmulator) getLogs(level internal.LogLevel) []string {
   415  	return r.logs[level]
   416  }
   417  
   418  // impl HostEmulator
   419  func (r *rootHostEmulator) GetTickPeriod() uint32 {
   420  	return r.tickPeriod
   421  }
   422  
   423  // impl HostEmulator
   424  func (r *rootHostEmulator) Tick() {
   425  	internal.ProxyOnTick(PluginContextID)
   426  }
   427  
   428  // impl HostEmulator
   429  func (r *rootHostEmulator) GetQueueSize(queueID uint32) int {
   430  	return len(r.queues[queueID])
   431  }
   432  
   433  // impl HostEmulator
   434  func (r *rootHostEmulator) GetCalloutAttributesFromContext(contextID uint32) []HttpCalloutAttribute {
   435  	infos := r.httpContextIDToCalloutInfos[contextID]
   436  	return infos
   437  }
   438  
   439  // impl HostEmulator
   440  func (r *rootHostEmulator) StartVM() types.OnVMStartStatus {
   441  	return internal.ProxyOnVMStart(PluginContextID, len(r.vmConfiguration))
   442  }
   443  
   444  // impl HostEmulator
   445  func (r *rootHostEmulator) StartPlugin() types.OnPluginStartStatus {
   446  	return internal.ProxyOnConfigure(PluginContextID, len(r.pluginConfiguration))
   447  }
   448  
   449  // impl HostEmulator
   450  func (r *rootHostEmulator) CallOnHttpCallResponse(calloutID uint32, headers, trailers [][2]string, body []byte) {
   451  	r.httpCalloutResponse[calloutID] = struct {
   452  		headers, trailers [][2]string
   453  		body              []byte
   454  	}{headers: cloneWithLowerCaseMapKeys(headers), trailers: cloneWithLowerCaseMapKeys(trailers), body: body}
   455  
   456  	// PluginContextID, calloutID uint32, numHeaders, bodySize, numTrailers in
   457  	r.activeCalloutID = calloutID
   458  	defer func() {
   459  		r.activeCalloutID = 0
   460  		delete(r.httpCalloutResponse, calloutID)
   461  		delete(r.httpCalloutIDToContextID, calloutID)
   462  	}()
   463  	internal.ProxyOnHttpCallResponse(PluginContextID, calloutID, len(headers), len(body), len(trailers))
   464  }
   465  
   466  // impl HostEmulator
   467  func (r *rootHostEmulator) FinishVM() bool {
   468  	return internal.ProxyOnDone(PluginContextID)
   469  }
   470  
   471  func (r *rootHostEmulator) GetCounterMetric(name string) (uint64, error) {
   472  	id, ok := r.metricNameToID[name]
   473  	if !ok {
   474  		return 0, fmt.Errorf("%s not found", name)
   475  	}
   476  
   477  	t, ok := r.metricIDToType[id]
   478  	if !ok {
   479  		return 0, fmt.Errorf("%s not found", name)
   480  	}
   481  
   482  	if t != internal.MetricTypeCounter {
   483  		return 0, fmt.Errorf(
   484  			"%s is not %v metric type but %v", name, internal.MetricTypeCounter, t)
   485  	}
   486  
   487  	v, ok := r.metricIDToValue[id]
   488  	if !ok {
   489  		return 0, fmt.Errorf("%s not found", name)
   490  	}
   491  	return v, nil
   492  }
   493  
   494  func (r *rootHostEmulator) GetGaugeMetric(name string) (uint64, error) {
   495  	id, ok := r.metricNameToID[name]
   496  	if !ok {
   497  		return 0, fmt.Errorf("%s not found", name)
   498  	}
   499  
   500  	t, ok := r.metricIDToType[id]
   501  	if !ok {
   502  		return 0, fmt.Errorf("%s not found", name)
   503  	}
   504  
   505  	if t != internal.MetricTypeGauge {
   506  		return 0, fmt.Errorf(
   507  			"%s is not %v metric type but %v", name, internal.MetricTypeGauge, t)
   508  	}
   509  
   510  	v, ok := r.metricIDToValue[id]
   511  	if !ok {
   512  		return 0, fmt.Errorf("%s not found", name)
   513  	}
   514  	return v, nil
   515  }
   516  
   517  func (r *rootHostEmulator) GetHistogramMetric(name string) (uint64, error) {
   518  	id, ok := r.metricNameToID[name]
   519  	if !ok {
   520  		return 0, fmt.Errorf("%s not found", name)
   521  	}
   522  
   523  	t, ok := r.metricIDToType[id]
   524  	if !ok {
   525  		return 0, fmt.Errorf("%s not found", name)
   526  	}
   527  
   528  	if t != internal.MetricTypeHistogram {
   529  		return 0, fmt.Errorf(
   530  			"%s is not %v metric type but %v", name, internal.MetricTypeHistogram, t)
   531  	}
   532  
   533  	v, ok := r.metricIDToValue[id]
   534  	if !ok {
   535  		return 0, fmt.Errorf("%s not found", name)
   536  	}
   537  	return v, nil
   538  }