github.com/aergoio/aergo@v1.3.1/contract/callback.go (about)

     1  // Copyright (C) 2014 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
     2  //
     3  // Use of this source code is governed by an MIT-style
     4  // license that can be found in the LICENSE file.
     5  
     6  package contract
     7  
     8  // You can't export a Go function to C and have definitions in the C
     9  // preamble in the same file, so we have to have callbackTrampoline in
    10  // its own file. Because we need a separate file anyway, the support
    11  // code for SQLite custom functions is in here.
    12  
    13  /*
    14  #include <stdlib.h>
    15  #include <sqlite3-binding.h>
    16  
    17  void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
    18  void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l);
    19  */
    20  import "C"
    21  
    22  import (
    23  	"errors"
    24  	"fmt"
    25  	"math"
    26  	"reflect"
    27  	"sync"
    28  	"unsafe"
    29  )
    30  
    31  //export callbackTrampoline
    32  func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
    33  	args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
    34  	fi := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*functionInfo)
    35  	fi.Call(ctx, args)
    36  }
    37  
    38  //export stepTrampoline
    39  func stepTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
    40  	args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
    41  	ai := lookupHandle(uintptr(C.sqlite3_user_data(ctx))).(*aggInfo)
    42  	ai.Step(ctx, args)
    43  }
    44  
    45  //export doneTrampoline
    46  func doneTrampoline(ctx *C.sqlite3_context) {
    47  	handle := uintptr(C.sqlite3_user_data(ctx))
    48  	ai := lookupHandle(handle).(*aggInfo)
    49  	ai.Done(ctx)
    50  }
    51  
    52  //export compareTrampoline
    53  func compareTrampoline(handlePtr uintptr, la C.int, a *C.char, lb C.int, b *C.char) C.int {
    54  	cmp := lookupHandle(handlePtr).(func(string, string) int)
    55  	return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb)))
    56  }
    57  
    58  //export commitHookTrampoline
    59  func commitHookTrampoline(handle uintptr) int {
    60  	callback := lookupHandle(handle).(func() int)
    61  	return callback()
    62  }
    63  
    64  //export rollbackHookTrampoline
    65  func rollbackHookTrampoline(handle uintptr) {
    66  	callback := lookupHandle(handle).(func())
    67  	callback()
    68  }
    69  
    70  //export updateHookTrampoline
    71  func updateHookTrampoline(handle uintptr, op int, db *C.char, table *C.char, rowid int64) {
    72  	callback := lookupHandle(handle).(func(int, string, string, int64))
    73  	callback(op, C.GoString(db), C.GoString(table), rowid)
    74  }
    75  
    76  //export authorizerTrampoline
    77  func authorizerTrampoline(handle uintptr, op int, arg1 *C.char, arg2 *C.char, arg3 *C.char) int {
    78  	callback := lookupHandle(handle).(func(int, string, string, string) int)
    79  	return callback(op, C.GoString(arg1), C.GoString(arg2), C.GoString(arg3))
    80  }
    81  
    82  // Use handles to avoid passing Go pointers to C.
    83  
    84  type handleVal struct {
    85  	db  *SQLiteConn
    86  	val interface{}
    87  }
    88  
    89  var handleLock sync.Mutex
    90  var handleVals = make(map[uintptr]handleVal)
    91  var handleIndex uintptr = 100
    92  
    93  func newHandle(db *SQLiteConn, v interface{}) uintptr {
    94  	handleLock.Lock()
    95  	defer handleLock.Unlock()
    96  	i := handleIndex
    97  	handleIndex++
    98  	handleVals[i] = handleVal{db, v}
    99  	return i
   100  }
   101  
   102  func lookupHandle(handle uintptr) interface{} {
   103  	handleLock.Lock()
   104  	defer handleLock.Unlock()
   105  	r, ok := handleVals[handle]
   106  	if !ok {
   107  		if handle >= 100 && handle < handleIndex {
   108  			panic("deleted handle")
   109  		} else {
   110  			panic("invalid handle")
   111  		}
   112  	}
   113  	return r.val
   114  }
   115  
   116  func deleteHandles(db *SQLiteConn) {
   117  	handleLock.Lock()
   118  	defer handleLock.Unlock()
   119  	for handle, val := range handleVals {
   120  		if val.db == db {
   121  			delete(handleVals, handle)
   122  		}
   123  	}
   124  }
   125  
   126  // This is only here so that tests can refer to it.
   127  type callbackArgRaw C.sqlite3_value
   128  
   129  type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error)
   130  
   131  type callbackArgCast struct {
   132  	f   callbackArgConverter
   133  	typ reflect.Type
   134  }
   135  
   136  func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) {
   137  	val, err := c.f(v)
   138  	if err != nil {
   139  		return reflect.Value{}, err
   140  	}
   141  	if !val.Type().ConvertibleTo(c.typ) {
   142  		return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ)
   143  	}
   144  	return val.Convert(c.typ), nil
   145  }
   146  
   147  func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) {
   148  	if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
   149  		return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
   150  	}
   151  	return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
   152  }
   153  
   154  func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) {
   155  	if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
   156  		return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
   157  	}
   158  	i := int64(C.sqlite3_value_int64(v))
   159  	val := false
   160  	if i != 0 {
   161  		val = true
   162  	}
   163  	return reflect.ValueOf(val), nil
   164  }
   165  
   166  func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) {
   167  	if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
   168  		return reflect.Value{}, fmt.Errorf("argument must be a FLOAT")
   169  	}
   170  	return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
   171  }
   172  
   173  func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) {
   174  	switch C.sqlite3_value_type(v) {
   175  	case C.SQLITE_BLOB:
   176  		l := C.sqlite3_value_bytes(v)
   177  		p := C.sqlite3_value_blob(v)
   178  		return reflect.ValueOf(C.GoBytes(p, l)), nil
   179  	case C.SQLITE_TEXT:
   180  		l := C.sqlite3_value_bytes(v)
   181  		c := unsafe.Pointer(C.sqlite3_value_text(v))
   182  		return reflect.ValueOf(C.GoBytes(c, l)), nil
   183  	default:
   184  		return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
   185  	}
   186  }
   187  
   188  func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
   189  	switch C.sqlite3_value_type(v) {
   190  	case C.SQLITE_BLOB:
   191  		l := C.sqlite3_value_bytes(v)
   192  		p := (*C.char)(C.sqlite3_value_blob(v))
   193  		return reflect.ValueOf(C.GoStringN(p, l)), nil
   194  	case C.SQLITE_TEXT:
   195  		c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
   196  		return reflect.ValueOf(C.GoString(c)), nil
   197  	default:
   198  		return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
   199  	}
   200  }
   201  
   202  func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) {
   203  	switch C.sqlite3_value_type(v) {
   204  	case C.SQLITE_INTEGER:
   205  		return callbackArgInt64(v)
   206  	case C.SQLITE_FLOAT:
   207  		return callbackArgFloat64(v)
   208  	case C.SQLITE_TEXT:
   209  		return callbackArgString(v)
   210  	case C.SQLITE_BLOB:
   211  		return callbackArgBytes(v)
   212  	case C.SQLITE_NULL:
   213  		// Interpret NULL as a nil byte slice.
   214  		var ret []byte
   215  		return reflect.ValueOf(ret), nil
   216  	default:
   217  		panic("unreachable")
   218  	}
   219  }
   220  
   221  func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
   222  	switch typ.Kind() {
   223  	case reflect.Interface:
   224  		if typ.NumMethod() != 0 {
   225  			return nil, errors.New("the only supported interface type is interface{}")
   226  		}
   227  		return callbackArgGeneric, nil
   228  	case reflect.Slice:
   229  		if typ.Elem().Kind() != reflect.Uint8 {
   230  			return nil, errors.New("the only supported slice type is []byte")
   231  		}
   232  		return callbackArgBytes, nil
   233  	case reflect.String:
   234  		return callbackArgString, nil
   235  	case reflect.Bool:
   236  		return callbackArgBool, nil
   237  	case reflect.Int64:
   238  		return callbackArgInt64, nil
   239  	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
   240  		c := callbackArgCast{callbackArgInt64, typ}
   241  		return c.Run, nil
   242  	case reflect.Float64:
   243  		return callbackArgFloat64, nil
   244  	case reflect.Float32:
   245  		c := callbackArgCast{callbackArgFloat64, typ}
   246  		return c.Run, nil
   247  	default:
   248  		return nil, fmt.Errorf("don't know how to convert to %s", typ)
   249  	}
   250  }
   251  
   252  func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) {
   253  	var args []reflect.Value
   254  
   255  	if len(argv) < len(converters) {
   256  		return nil, fmt.Errorf("function requires at least %d arguments", len(converters))
   257  	}
   258  
   259  	for i, arg := range argv[:len(converters)] {
   260  		v, err := converters[i](arg)
   261  		if err != nil {
   262  			return nil, err
   263  		}
   264  		args = append(args, v)
   265  	}
   266  
   267  	if variadic != nil {
   268  		for _, arg := range argv[len(converters):] {
   269  			v, err := variadic(arg)
   270  			if err != nil {
   271  				return nil, err
   272  			}
   273  			args = append(args, v)
   274  		}
   275  	}
   276  	return args, nil
   277  }
   278  
   279  type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error
   280  
   281  func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error {
   282  	switch v.Type().Kind() {
   283  	case reflect.Int64:
   284  	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
   285  		v = v.Convert(reflect.TypeOf(int64(0)))
   286  	case reflect.Bool:
   287  		b := v.Interface().(bool)
   288  		if b {
   289  			v = reflect.ValueOf(int64(1))
   290  		} else {
   291  			v = reflect.ValueOf(int64(0))
   292  		}
   293  	default:
   294  		return fmt.Errorf("cannot convert %s to INTEGER", v.Type())
   295  	}
   296  
   297  	C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64)))
   298  	return nil
   299  }
   300  
   301  func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error {
   302  	switch v.Type().Kind() {
   303  	case reflect.Float64:
   304  	case reflect.Float32:
   305  		v = v.Convert(reflect.TypeOf(float64(0)))
   306  	default:
   307  		return fmt.Errorf("cannot convert %s to FLOAT", v.Type())
   308  	}
   309  
   310  	C.sqlite3_result_double(ctx, C.double(v.Interface().(float64)))
   311  	return nil
   312  }
   313  
   314  func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error {
   315  	if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 {
   316  		return fmt.Errorf("cannot convert %s to BLOB", v.Type())
   317  	}
   318  	i := v.Interface()
   319  	if i == nil || len(i.([]byte)) == 0 {
   320  		C.sqlite3_result_null(ctx)
   321  	} else {
   322  		bs := i.([]byte)
   323  		C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs)))
   324  	}
   325  	return nil
   326  }
   327  
   328  func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error {
   329  	if v.Type().Kind() != reflect.String {
   330  		return fmt.Errorf("cannot convert %s to TEXT", v.Type())
   331  	}
   332  	C._sqlite3_result_text(ctx, C.CString(v.Interface().(string)))
   333  	return nil
   334  }
   335  
   336  func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error {
   337  	return nil
   338  }
   339  
   340  func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
   341  	switch typ.Kind() {
   342  	case reflect.Interface:
   343  		errorInterface := reflect.TypeOf((*error)(nil)).Elem()
   344  		if typ.Implements(errorInterface) {
   345  			return callbackRetNil, nil
   346  		}
   347  		fallthrough
   348  	case reflect.Slice:
   349  		if typ.Elem().Kind() != reflect.Uint8 {
   350  			return nil, errors.New("the only supported slice type is []byte")
   351  		}
   352  		return callbackRetBlob, nil
   353  	case reflect.String:
   354  		return callbackRetText, nil
   355  	case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
   356  		return callbackRetInteger, nil
   357  	case reflect.Float32, reflect.Float64:
   358  		return callbackRetFloat, nil
   359  	default:
   360  		return nil, fmt.Errorf("don't know how to convert to %s", typ)
   361  	}
   362  }
   363  
   364  func callbackError(ctx *C.sqlite3_context, err error) {
   365  	cstr := C.CString(err.Error())
   366  	defer C.free(unsafe.Pointer(cstr))
   367  	C.sqlite3_result_error(ctx, cstr, -1)
   368  }
   369  
   370  // Test support code. Tests are not allowed to import "C", so we can't
   371  // declare any functions that use C.sqlite3_value.
   372  func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter {
   373  	return func(*C.sqlite3_value) (reflect.Value, error) {
   374  		return v, err
   375  	}
   376  }