github.com/machinefi/w3bstream@v1.6.5-rc9.0.20240426031326-b8c7c4876e72/pkg/modules/vm/wasmtime/instance.go (about)

     1  package wasmtime
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"fmt"
     7  	"strconv"
     8  	"strings"
     9  	"sync/atomic"
    10  	"time"
    11  
    12  	"github.com/google/uuid"
    13  	"github.com/pkg/errors"
    14  	"github.com/reactivex/rxgo/v2"
    15  	"github.com/tidwall/gjson"
    16  
    17  	conflog "github.com/machinefi/w3bstream/pkg/depends/conf/log"
    18  	"github.com/machinefi/w3bstream/pkg/depends/kit/logr"
    19  	"github.com/machinefi/w3bstream/pkg/depends/kit/mq"
    20  	"github.com/machinefi/w3bstream/pkg/depends/x/contextx"
    21  	"github.com/machinefi/w3bstream/pkg/depends/x/mapx"
    22  	"github.com/machinefi/w3bstream/pkg/depends/x/misc/must"
    23  	"github.com/machinefi/w3bstream/pkg/enums"
    24  	"github.com/machinefi/w3bstream/pkg/types"
    25  	"github.com/machinefi/w3bstream/pkg/types/wasm"
    26  )
    27  
    28  const (
    29  	maxUint = ^uint32(0)
    30  	maxInt  = int(maxUint >> 1)
    31  	// TODO: add into config
    32  	maxMsgPerInstance = 5000
    33  )
    34  
    35  type Instance struct {
    36  	ctx         context.Context
    37  	id          types.SFID
    38  	rt          *Runtime
    39  	lk          *ExportFuncs
    40  	state       *atomic.Uint32
    41  	res         *mapx.Map[uint32, []byte]
    42  	evs         *mapx.Map[uint32, []byte]
    43  	kvs         wasm.KVStore
    44  	ch          chan rxgo.Item
    45  	source      []string
    46  	operators   []wasm.Operator
    47  	simpleOpMap map[string]string
    48  	windOps     []wasm.Operator
    49  	windOpMap   map[string]string
    50  	sink        wasm.Sink
    51  }
    52  
    53  func NewInstanceByCode(ctx context.Context, id types.SFID, code []byte, st enums.InstanceState) (i *Instance, err error) {
    54  	ctx, l := logr.Start(ctx, "modules.vm.wasmtime.NewInstanceByCode")
    55  	defer l.End()
    56  
    57  	res := mapx.New[uint32, []byte]()
    58  	evs := mapx.New[uint32, []byte]()
    59  	rt := NewRuntime()
    60  	lk, err := NewExportFuncs(contextx.WithContextCompose(
    61  		wasm.WithRuntimeResourceContext(res),
    62  		wasm.WithRuntimeEventTypesContext(evs),
    63  	)(ctx), rt)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	if err := rt.Link(lk, code); err != nil {
    68  		return nil, err
    69  	}
    70  	state := &atomic.Uint32{}
    71  	state.Store(uint32(st))
    72  
    73  	ins := &Instance{
    74  		rt:    rt,
    75  		lk:    lk,
    76  		id:    id,
    77  		state: state,
    78  		res:   res,
    79  		evs:   evs,
    80  		kvs:   wasm.MustKVStoreFromContext(ctx),
    81  		ch:    make(chan rxgo.Item),
    82  	}
    83  
    84  	flow, ok := wasm.FlowFromContext(ctx)
    85  	if ok {
    86  		ins.source = flow.Source.Strategies
    87  		ins.operators = flow.Operators
    88  		ins.simpleOpMap = make(map[string]string)
    89  		ins.windOpMap = make(map[string]string)
    90  		ins.windOps = make([]wasm.Operator, 0)
    91  		ins.sink = flow.Sink
    92  		go func() {
    93  			observable := ins.streamCompute(ins.ch)
    94  			ins.initSink(ins.ctx, observable)
    95  		}()
    96  	}
    97  
    98  	return ins, nil
    99  }
   100  
   101  var _ wasm.Instance = (*Instance)(nil)
   102  
   103  func (i *Instance) ID() string { return i.id.String() }
   104  
   105  func (i *Instance) Start(ctx context.Context) error {
   106  	ctx, l := logr.Start(ctx, "modules.vm.Instance.Start", "instance_id", i.ID())
   107  	defer l.End()
   108  
   109  	i.state.Store(uint32(enums.INSTANCE_STATE__STARTED))
   110  	return nil
   111  }
   112  
   113  func (i *Instance) Stop(ctx context.Context) error {
   114  	ctx, l := logr.Start(ctx, "modules.vm.Instance.Stop", "instance_id", i.ID())
   115  	defer l.End()
   116  
   117  	i.state.Store(uint32(enums.INSTANCE_STATE__STOPPED))
   118  	return nil
   119  }
   120  
   121  func (i *Instance) State() wasm.InstanceState { return wasm.InstanceState(i.state.Load()) }
   122  
   123  func (i *Instance) HandleEvent(ctx context.Context, fn, eventType string, data []byte) *wasm.EventHandleResult {
   124  	ctx, l := logr.Start(ctx, "modules.vm.wasmtime.Instance.HandleEvent")
   125  	defer l.End()
   126  
   127  	if i.State() != enums.INSTANCE_STATE__STARTED {
   128  		return &wasm.EventHandleResult{
   129  			InstanceID: i.id.String(),
   130  			Code:       wasm.ResultStatusCode_Failed,
   131  			ErrMsg:     "instance not running",
   132  		}
   133  	}
   134  
   135  	task := &Task{
   136  		EventID:   types.MustEventIDFromContext(ctx),
   137  		EventType: eventType,
   138  		Handler:   fn,
   139  		Payload:   data,
   140  		TaskState: mq.TASK_STATE__PENDING,
   141  		vm:        i,
   142  		retrieve:  make(chan *wasm.EventHandleResult),
   143  	}
   144  
   145  	return i.handle(ctx, task)
   146  }
   147  
   148  func (i *Instance) streamCompute(ch chan rxgo.Item) rxgo.Observable {
   149  	obs := rxgo.FromChannel(ch)
   150  	for index, op := range i.operators {
   151  		switch {
   152  		case op.OpType == enums.FLOW_OPERATOR__FILTER:
   153  			filterNum := index
   154  			i.simpleOpMap[fmt.Sprintf("%s_%d", enums.FLOW_OPERATOR__FILTER, filterNum)] = op.WasmFunc
   155  
   156  			obs = obs.Filter(func(inter interface{}) bool {
   157  				start := time.Now()
   158  				res := false
   159  				task := inter.(*Task)
   160  				task.Handler = i.simpleOpMap[fmt.Sprintf("%s_%d", enums.FLOW_OPERATOR__FILTER, filterNum)]
   161  
   162  				rb, ok := i.runOp(task)
   163  				if !ok {
   164  					conflog.Std().Error(errors.New(fmt.Sprintf("%s result not found", op.WasmFunc)))
   165  					return res
   166  				}
   167  
   168  				result := strings.ToLower(string(rb))
   169  				if result == "true" {
   170  					res = true
   171  				} else if result == "false" {
   172  					res = false
   173  				} else {
   174  					conflog.Std().Warn(errors.New("the value does not support"))
   175  				}
   176  				duration := time.Since(start)
   177  				conflog.Std().Info(fmt.Sprintf("%s template cost %s", task.Handler, duration.String()))
   178  				return res
   179  			})
   180  		case op.OpType == enums.FLOW_OPERATOR__MAP:
   181  			mapNum := index
   182  			i.simpleOpMap[fmt.Sprintf("%s_%d", enums.FLOW_OPERATOR__MAP, mapNum)] = op.WasmFunc
   183  
   184  			obs = obs.Map(func(ctx context.Context, inter interface{}) (interface{}, error) {
   185  				start := time.Now()
   186  				task := inter.(*Task)
   187  				task.Handler = i.simpleOpMap[fmt.Sprintf("%s_%d", enums.FLOW_OPERATOR__MAP, mapNum)]
   188  
   189  				rb, ok := i.runOp(task)
   190  				if !ok {
   191  					conflog.Std().Error(errors.New(fmt.Sprintf("%s result not found", op.WasmFunc)))
   192  					return nil, errors.New(fmt.Sprintf("%s result not found", op.WasmFunc))
   193  				}
   194  
   195  				task.Payload = rb
   196  				duration := time.Since(start)
   197  				conflog.Std().Info(fmt.Sprintf("%s template cost %s", task.Handler, duration.String()))
   198  				return task, nil
   199  			})
   200  		case op.OpType == enums.FLOW_OPERATOR__WINDOW:
   201  			obs = obs.WindowWithTime(rxgo.WithDuration(60 * time.Second))
   202  		case op.OpType > enums.FLOW_OPERATOR__WINDOW:
   203  			i.windOps = append(i.windOps, op)
   204  		}
   205  	}
   206  
   207  	return obs
   208  }
   209  
   210  func (i *Instance) initSink(ctx context.Context, observable rxgo.Observable) {
   211  	c := observable.Observe()
   212  	for item := range c {
   213  
   214  		switch item.V.(type) {
   215  		case rxgo.GroupedObservable: // group operator
   216  			go func() {
   217  				obs := item.V.(rxgo.GroupedObservable)
   218  				// add other op like reduce
   219  				for it := range obs.Observe() {
   220  					i.sinkData(ctx, it)
   221  				}
   222  			}()
   223  		case *rxgo.ObservableImpl: // window operator
   224  			var (
   225  				obs   = item.V
   226  				index = 0
   227  				op    = wasm.Operator{}
   228  			)
   229  
   230  			for index, op = range i.windOps {
   231  				switch op.OpType {
   232  				// last op
   233  				case enums.FLOW_OPERATOR__REDUCE:
   234  					reduceNum := index
   235  					i.windOpMap[fmt.Sprintf("%s_%d", enums.FLOW_OPERATOR__REDUCE, reduceNum)] = op.WasmFunc
   236  
   237  					obs = obs.(*rxgo.ObservableImpl).Reduce(func(ctx context.Context, inter1 interface{}, inter2 interface{}) (interface{}, error) {
   238  						start := time.Now()
   239  						var task1, task2 *Task
   240  						task2 = inter2.(*Task)
   241  						task2.Handler = i.windOpMap[fmt.Sprintf("%s_%d", enums.FLOW_OPERATOR__REDUCE, reduceNum)]
   242  
   243  						tasks := make([]*Task, 0)
   244  						if inter1 != nil {
   245  							task1 = inter1.(*Task)
   246  						}
   247  						tasks = append(tasks, task1)
   248  						tasks = append(tasks, task2)
   249  
   250  						rb, ok := i.runOp(tasks...)
   251  						if !ok {
   252  							conflog.Std().Error(errors.New(fmt.Sprintf("%s result not found", op.WasmFunc)))
   253  							return nil, errors.New(fmt.Sprintf("%s result not found", op.WasmFunc))
   254  						}
   255  
   256  						task2.Payload = rb
   257  						duration := time.Since(start)
   258  						conflog.Std().Info(fmt.Sprintf("%s template cost %s", task2.Handler, duration.String()))
   259  						return task2, nil
   260  					})
   261  				case enums.FLOW_OPERATOR__GROUP:
   262  					groupNum := index
   263  					i.windOpMap[fmt.Sprintf("%s_%d", enums.FLOW_OPERATOR__GROUP, groupNum)] = op.WasmFunc
   264  
   265  					obs = obs.(*rxgo.ObservableImpl).GroupByDynamic(func(item rxgo.Item) string {
   266  						start := time.Now()
   267  						task := item.V.(*Task)
   268  						task.Handler = i.windOpMap[fmt.Sprintf("%s_%d", enums.FLOW_OPERATOR__GROUP, groupNum)]
   269  
   270  						rb, ok := i.runOp(task)
   271  						if !ok {
   272  							conflog.Std().Error(errors.New(fmt.Sprintf("%s result not found", op.WasmFunc)))
   273  							return "error"
   274  						}
   275  
   276  						groupKey := string(rb)
   277  						duration := time.Since(start)
   278  						conflog.Std().Info(fmt.Sprintf("%s template cost %s", task.Handler, duration.String()))
   279  						return groupKey
   280  					}, rxgo.WithBufferedChannel(2), rxgo.WithErrorStrategy(rxgo.ContinueOnError))
   281  					goto skip
   282  				}
   283  			}
   284  
   285  		skip:
   286  			switch obs.(type) {
   287  			case rxgo.OptionalSingle:
   288  				for it := range obs.(rxgo.OptionalSingle).Observe() {
   289  					i.sinkData(ctx, it)
   290  				}
   291  			case *rxgo.ObservableImpl:
   292  				for it := range obs.(*rxgo.ObservableImpl).Observe() {
   293  					// check group or common
   294  					switch it.V.(type) {
   295  					case rxgo.GroupedObservable:
   296  						go func() {
   297  							grpObs := it.V
   298  							op := wasm.Operator{}
   299  							// add other op like reduce
   300  							// there are other ops after group op, should add here
   301  							if index < len(i.windOps)-1 {
   302  								for j := index; j < len(i.windOps); j++ {
   303  									op = i.windOps[j]
   304  									switch op.OpType {
   305  									case enums.FLOW_OPERATOR__REDUCE:
   306  										reduceNum := j
   307  										i.windOpMap[fmt.Sprintf("%s_%d", enums.FLOW_OPERATOR__REDUCE, reduceNum)] = op.WasmFunc
   308  
   309  										grpObs = grpObs.(rxgo.GroupedObservable).Reduce(func(ctx context.Context, inter1 interface{}, inter2 interface{}) (interface{}, error) {
   310  											start := time.Now()
   311  											var task1, task2 *Task
   312  											task2 = inter2.(*Task)
   313  											task2.Handler = i.windOpMap[fmt.Sprintf("%s_%d", enums.FLOW_OPERATOR__REDUCE, reduceNum)]
   314  
   315  											tasks := make([]*Task, 0)
   316  											if inter1 != nil {
   317  												task1 = inter1.(*Task)
   318  											}
   319  											tasks = append(tasks, task1)
   320  											tasks = append(tasks, task2)
   321  
   322  											rb, ok := i.runOp(tasks...)
   323  											if !ok {
   324  												conflog.Std().Error(errors.New(fmt.Sprintf("%s result not found", op.WasmFunc)))
   325  												return nil, errors.New(fmt.Sprintf("%s result not found", op.WasmFunc))
   326  											}
   327  
   328  											task2.Payload = rb
   329  											duration := time.Since(start)
   330  											conflog.Std().Info(fmt.Sprintf("%s template cost %s", task2.Handler, duration.String()))
   331  											return task2, nil
   332  										})
   333  									}
   334  								}
   335  							}
   336  							switch grpObs.(type) {
   337  							case rxgo.OptionalSingle:
   338  								for it := range grpObs.(rxgo.OptionalSingle).Observe() {
   339  									i.sinkData(ctx, it)
   340  								}
   341  							case *rxgo.ObservableImpl:
   342  								for it := range grpObs.(*rxgo.ObservableImpl).Observe() {
   343  									i.sinkData(ctx, it)
   344  								}
   345  							default:
   346  								i.sinkData(ctx, it)
   347  							}
   348  						}()
   349  					default:
   350  						i.sinkData(ctx, it)
   351  					}
   352  				}
   353  			}
   354  		default:
   355  			i.sinkData(ctx, item)
   356  		}
   357  	}
   358  }
   359  
   360  func (i *Instance) sinkData(ctx context.Context, item rxgo.Item) {
   361  	rowByte := item.V.(*Task).Payload
   362  
   363  	switch i.sink.SinkType {
   364  	case enums.FLOW_SINK__RMDB:
   365  		db, err := sql.Open(i.sink.SinkInfo.DBInfo.DBType, i.sink.SinkInfo.DBInfo.Endpoint)
   366  		if err != nil {
   367  			conflog.Std().Error(err)
   368  		}
   369  		err = db.Ping()
   370  		if err != nil {
   371  			conflog.Std().Error(err)
   372  		}
   373  
   374  		sqlStringPrefix := fmt.Sprintf("INSERT INTO %s (", i.sink.SinkInfo.DBInfo.Table)
   375  		sqlStringSuffix := fmt.Sprintf(") VALUES (")
   376  		params := make([]interface{}, 0)
   377  		for index, c := range i.sink.SinkInfo.DBInfo.Columns {
   378  			params = append(params, gjson.GetBytes(rowByte, c).String())
   379  			sqlStringPrefix = sqlStringPrefix + c + ","
   380  			sqlStringSuffix = sqlStringSuffix + "$" + strconv.Itoa(index+1) + ","
   381  		}
   382  		sqlString := fmt.Sprintf("%s%s);", sqlStringPrefix[:len(sqlStringPrefix)-1], sqlStringSuffix[:len(sqlStringSuffix)-1])
   383  
   384  		_, err = db.ExecContext(context.Background(), sqlString, params...)
   385  		if err != nil {
   386  			conflog.Std().Error(err)
   387  		}
   388  	case enums.FLOW_SINK__BLOCKCHAIN:
   389  
   390  	default:
   391  
   392  	}
   393  }
   394  
   395  func (i *Instance) runOp(task ...*Task) ([]byte, bool) {
   396  	var (
   397  		ctx     context.Context
   398  		handler string
   399  
   400  		rids = make([]interface{}, 0)
   401  	)
   402  
   403  	for _, t := range task {
   404  		// if task is nil,  set rid is 0
   405  		var rid uint32 = 0
   406  		if t != nil {
   407  			rid = i.AddResource([]byte(t.EventType), t.Payload)
   408  			// ctx = t.ctx
   409  			handler = t.Handler
   410  		}
   411  
   412  		rids = append(rids, int32(rid))
   413  	}
   414  	defer func() {
   415  		for _, rid := range rids {
   416  			i.RmvResource(uint32(rid.(int32)))
   417  		}
   418  	}()
   419  
   420  	start := time.Now()
   421  	code := i.handleByRid(ctx, handler, rids...).Code
   422  	duration := time.Since(start)
   423  	conflog.Std().Info(fmt.Sprintf("%s wasm cost %s", handler, duration.String()))
   424  
   425  	conflog.Std().Info(fmt.Sprintf("%s wasm code %d", handler, code))
   426  
   427  	if code < 0 {
   428  		conflog.Std().Error(errors.New(fmt.Sprintf("%s wasm code run error", handler)))
   429  		return nil, false
   430  	}
   431  
   432  	return i.GetResource(uint32(code))
   433  }
   434  
   435  func (i *Instance) handleByRid(ctx context.Context, handlerName string, rids ...interface{}) *wasm.EventHandleResult {
   436  	l := types.MustLoggerFromContext(ctx)
   437  
   438  	_, l = l.Start(ctx, "instance.handleByRid")
   439  	defer l.End()
   440  
   441  	if err := i.rt.Instantiate(ctx); err != nil {
   442  		return &wasm.EventHandleResult{
   443  			InstanceID: i.id.String(),
   444  			ErrMsg:     err.Error(),
   445  			Code:       wasm.ResultStatusCode_Failed,
   446  		}
   447  	}
   448  	defer i.rt.Deinstantiate(ctx)
   449  
   450  	result, err := i.rt.Call(ctx, handlerName, rids...)
   451  	if err != nil {
   452  		l.Error(err)
   453  		return &wasm.EventHandleResult{
   454  			InstanceID: i.id.String(),
   455  			ErrMsg:     err.Error(),
   456  			Code:       wasm.ResultStatusCode_Failed,
   457  		}
   458  	}
   459  
   460  	return &wasm.EventHandleResult{
   461  		InstanceID: i.id.String(),
   462  		Code:       wasm.ResultStatusCode(result.(int32)),
   463  	}
   464  }
   465  
   466  func (i *Instance) handle(ctx context.Context, task *Task) *wasm.EventHandleResult {
   467  	ctx, l := logr.Start(ctx, "modules.vm.wasmtime.Instance.handle",
   468  		"event_id", task.EventID,
   469  		"instance_id", i.id,
   470  	)
   471  	defer l.End()
   472  
   473  	l.Info("start processing task")
   474  	rid, ok := i.lk.EntryContext(ctx, task.EventID, []byte(task.EventType), task.Payload)
   475  	if !ok {
   476  		return &wasm.EventHandleResult{
   477  			InstanceID: i.id.String(),
   478  			ErrMsg:     "InstanceBusy",
   479  			Code:       wasm.ResultStatusCode_Failed,
   480  		}
   481  	}
   482  	defer func() {
   483  		must.BeTrue(i.lk.LeaveContext(ctx, task.EventID, rid))
   484  	}()
   485  
   486  	if err := i.rt.Instantiate(ctx); err != nil {
   487  		return &wasm.EventHandleResult{
   488  			InstanceID: i.id.String(),
   489  			ErrMsg:     err.Error(),
   490  			Code:       wasm.ResultStatusCode_Failed,
   491  		}
   492  	}
   493  	defer i.rt.Deinstantiate(ctx)
   494  
   495  	before, err := i.rt.store.GetFuel()
   496  	if err == nil {
   497  		defer func() {
   498  			after, err := i.rt.store.GetFuel()
   499  			if err == nil {
   500  				i.lk.HostLog(conflog.InfoLevel, fmt.Sprintf("consumed fuel: %d", before-after))
   501  			}
   502  		}()
   503  	}
   504  
   505  	// TODO support wasm return data(not only code) for HTTP responding
   506  	result, err := i.rt.Call(ctx, task.Handler, int32(rid))
   507  	l.WithValues("result", result, "error", err).Debug("call wasm runtime completed.")
   508  	if err != nil {
   509  		i.lk.HostLog(conflog.ErrorLevel, errors.Wrapf(err, "wasm call completed with error. code: %v", result))
   510  		return &wasm.EventHandleResult{
   511  			InstanceID: i.id.String(),
   512  			ErrMsg:     err.Error(),
   513  			Code:       wasm.ResultStatusCode_Failed,
   514  		}
   515  	}
   516  	i.lk.HostLog(conflog.InfoLevel, fmt.Sprintf("wasm call completed. code: %v", result))
   517  	return &wasm.EventHandleResult{
   518  		InstanceID: i.id.String(),
   519  		Code:       wasm.ResultStatusCode(result.(int32)),
   520  	}
   521  }
   522  
   523  func (i *Instance) AddResource(eventType, data []byte) uint32 {
   524  	var id = int32(uuid.New().ID() % uint32(maxInt))
   525  	i.res.Store(uint32(id), data)
   526  	i.evs.Store(uint32(id), eventType)
   527  	return uint32(id)
   528  }
   529  
   530  func (i *Instance) GetResource(id uint32) ([]byte, bool) {
   531  	return i.res.Load(id)
   532  }
   533  
   534  func (i *Instance) RmvResource(id uint32) {
   535  	i.res.Remove(id)
   536  	i.evs.Remove(id)
   537  }