github.com/gmemcc/yaegi@v0.12.1-0.20221128122509-aa99124c5d16/interp/conv.go (about)

     1  package interp
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"github.com/jinzhu/copier"
     8  	"github.com/spf13/cast"
     9  	"go/constant"
    10  	"reflect"
    11  	"strconv"
    12  	"strings"
    13  )
    14  
    15  func canIconv(typ *itype, expected *itype) bool {
    16  	if typ.assignableTo(expected) {
    17  		return true
    18  	}
    19  	if typ.rtype.Kind() == reflect.String {
    20  		return true
    21  	}
    22  	ertype := expected.rtype
    23  	_, err := rconv(reflect.New(typ.rtype).Elem(), ertype)
    24  	return err == nil
    25  }
    26  
    27  func canIconvBool(t *itype) bool {
    28  	typ := t.TypeOf()
    29  	return typ.Kind() == reflect.Bool || isNumber(typ) || isString(typ) || isInterface(t)
    30  }
    31  
    32  func canRconvBool(t reflect.Type) bool {
    33  	return t.Kind() == reflect.Bool || t.Kind() == reflect.Interface || isNumber(t) || isString(t)
    34  }
    35  
    36  func rconv(src reflect.Value, expectedType reflect.Type) (reflect.Value, error) {
    37  	if !src.IsValid() {
    38  		return src, nil
    39  	}
    40  	srcType := src.Type()
    41  	if srcType == expectedType {
    42  		return src, nil
    43  	}
    44  	if srcType.Kind() == expectedType.Kind() && srcType.Kind() != reflect.Struct &&
    45  		(srcType.PkgPath() != expectedType.PkgPath() || srcType.Name() != expectedType.Name()) {
    46  		// type def from existing type
    47  		return src.Convert(expectedType), nil
    48  	}
    49  	value := src.Interface()
    50  	switch expectedType.Kind() {
    51  	case reflect.Bool:
    52  		casted, err := cast.ToBoolE(value)
    53  		if err == nil {
    54  			return reflect.ValueOf(casted), nil
    55  		} else {
    56  			return src, err
    57  		}
    58  	case reflect.Int:
    59  		casted, err := cast.ToIntE(value)
    60  		if err == nil {
    61  			return reflect.ValueOf(casted), nil
    62  		} else {
    63  			return src, err
    64  		}
    65  	case reflect.Int8:
    66  		casted, err := cast.ToInt8E(value)
    67  		if err == nil {
    68  			return reflect.ValueOf(casted), nil
    69  		} else {
    70  			return src, err
    71  		}
    72  	case reflect.Int16:
    73  		casted, err := cast.ToInt16E(value)
    74  		if err == nil {
    75  			return reflect.ValueOf(casted), nil
    76  		} else {
    77  			return src, err
    78  		}
    79  	case reflect.Int32:
    80  		casted, err := cast.ToInt32E(value)
    81  		if err == nil {
    82  			return reflect.ValueOf(casted), nil
    83  		} else {
    84  			return src, err
    85  		}
    86  	case reflect.Int64:
    87  		casted, err := cast.ToInt64E(value)
    88  		if err == nil {
    89  			return reflect.ValueOf(casted), nil
    90  		} else {
    91  			return src, err
    92  		}
    93  	case reflect.Uint:
    94  		casted, err := cast.ToUintE(value)
    95  		if err == nil {
    96  			return reflect.ValueOf(casted), nil
    97  		} else {
    98  			return src, err
    99  		}
   100  	case reflect.Uint8:
   101  		casted, err := cast.ToUint8E(value)
   102  		if err == nil {
   103  			return reflect.ValueOf(casted), nil
   104  		} else {
   105  			return src, err
   106  		}
   107  	case reflect.Uint16:
   108  		casted, err := cast.ToUint16E(value)
   109  		if err == nil {
   110  			return reflect.ValueOf(casted), nil
   111  		} else {
   112  			return src, err
   113  		}
   114  	case reflect.Uint32:
   115  		casted, err := cast.ToUint32E(value)
   116  		if err == nil {
   117  			return reflect.ValueOf(casted), nil
   118  		} else {
   119  			return src, err
   120  		}
   121  	case reflect.Uint64:
   122  		casted, err := cast.ToUint64E(value)
   123  		if err == nil {
   124  			return reflect.ValueOf(casted), nil
   125  		} else {
   126  			return src, err
   127  		}
   128  	case reflect.Float32:
   129  		casted, err := cast.ToFloat32E(value)
   130  		if err == nil {
   131  			return reflect.ValueOf(casted), nil
   132  		} else {
   133  			return src, err
   134  		}
   135  	case reflect.Float64:
   136  		casted, err := cast.ToFloat64E(value)
   137  		if err == nil {
   138  			return reflect.ValueOf(casted), nil
   139  		} else {
   140  			return src, err
   141  		}
   142  	case reflect.String:
   143  		switch reflect.ValueOf(value).Kind() {
   144  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   145  			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   146  			casted := fmt.Sprintf("%d", value)
   147  			return reflect.ValueOf(casted), nil
   148  		case reflect.Map:
   149  			bytes, err := json.Marshal(value)
   150  			return reflect.ValueOf(string(bytes)), err
   151  		default:
   152  			casted, err := cast.ToStringE(value)
   153  			if err == nil {
   154  				return reflect.ValueOf(casted), nil
   155  			} else {
   156  				return src, err
   157  			}
   158  		}
   159  	case reflect.Struct:
   160  		castedPtrValue := reflect.New(expectedType)
   161  		indirect := reflect.Indirect(src)
   162  		kind := indirect.Kind()
   163  		switch kind {
   164  		case reflect.String:
   165  			// assume value is in json format
   166  			err := json.Unmarshal([]byte(indirect.String()), castedPtrValue.Interface())
   167  			if err == nil {
   168  				return castedPtrValue.Elem(), nil
   169  			} else {
   170  				return src, err
   171  			}
   172  		case reflect.Struct:
   173  			err := copier.Copy(castedPtrValue.Interface(), value)
   174  			if err == nil {
   175  				return castedPtrValue.Elem(), nil
   176  			} else {
   177  				return src, err
   178  			}
   179  		case reflect.Map:
   180  			var bytes []byte
   181  			var err error
   182  			bytes, err = json.Marshal(indirect.Interface())
   183  			if err == nil {
   184  				err = json.Unmarshal(bytes, castedPtrValue.Interface())
   185  				if err == nil {
   186  					return castedPtrValue.Elem(), nil
   187  				}
   188  			}
   189  			return src, err
   190  		case reflect.Interface:
   191  			return indirect.Elem().Convert(expectedType), nil
   192  		default:
   193  			return src, errors.New(fmt.Sprintf(""))
   194  		}
   195  	case reflect.Map:
   196  		indirect := rconvToConcrete(reflect.Indirect(src))
   197  		kind := indirect.Kind()
   198  		switch kind {
   199  		case reflect.String:
   200  			castedPtrValue := reflect.New(expectedType)
   201  			str := indirect.String()
   202  			if str == "" {
   203  				return castedPtrValue.Elem(), nil
   204  			}
   205  			err := json.Unmarshal([]byte(str), castedPtrValue.Interface())
   206  			if err == nil {
   207  				return castedPtrValue.Elem(), nil
   208  			} else {
   209  				return src, err
   210  			}
   211  		case reflect.Map:
   212  			break
   213  		case reflect.Invalid:
   214  			return reflect.Zero(expectedType), nil
   215  		case reflect.Struct:
   216  			srcBytes, err := json.Marshal(src.Interface())
   217  			if err != nil {
   218  				return src, err
   219  			} else {
   220  				castedPtrValue := reflect.New(expectedType)
   221  				err := json.Unmarshal(srcBytes, castedPtrValue.Interface())
   222  				if err == nil {
   223  					return castedPtrValue.Elem(), nil
   224  				} else {
   225  					return src, err
   226  				}
   227  			}
   228  		default:
   229  			return src, nil
   230  		}
   231  		ktype := expectedType.Key()
   232  		vtype := expectedType.Elem()
   233  		if ktype == indirect.Type().Key() && vtype == indirect.Type().Elem() {
   234  			return indirect, nil
   235  		}
   236  		castedValue := reflect.MakeMapWithSize(expectedType, 0)
   237  		keys := indirect.MapKeys()
   238  		for i := 0; i < len(keys); i++ {
   239  			k := keys[i]
   240  			v := indirect.MapIndex(k)
   241  			var kcasted, vcasted reflect.Value
   242  			var err error
   243  			kcasted, err = rconv(k, ktype)
   244  			if err != nil {
   245  				return src, err
   246  			}
   247  			vcasted, err = rconv(v, vtype)
   248  			if err != nil {
   249  				return src, err
   250  			}
   251  			castedValue.SetMapIndex(kcasted, vcasted)
   252  		}
   253  		return castedValue, nil
   254  	case reflect.Slice:
   255  		src = rconvToConcrete(src)
   256  		srcType = src.Type()
   257  		if srcType.Kind() == reflect.String {
   258  			return rconv(reflect.ValueOf([]uint8(src.String())), expectedType)
   259  		} else if srcType.Kind() != reflect.Slice {
   260  			return src, nil
   261  		}
   262  		vtype := expectedType.Elem()
   263  		castedValue := reflect.MakeSlice(expectedType, src.Len(), src.Cap())
   264  		for i := 0; i < src.Len(); i++ {
   265  			vcasted, err := rconv(src.Index(i), vtype)
   266  			if err == nil {
   267  				castedValue.Index(i).Set(vcasted)
   268  			} else {
   269  				return src, err
   270  			}
   271  		}
   272  		return castedValue, nil
   273  	case reflect.Ptr:
   274  		castedValue, err := rconv(src, expectedType.Elem())
   275  		casted := castedValue.Interface()
   276  		if err == nil {
   277  			castedPtrVal := reflect.New(reflect.TypeOf(casted))
   278  			castedPtrVal.Elem().Set(castedValue)
   279  			return castedPtrVal, nil
   280  		} else {
   281  			return src, err
   282  		}
   283  	default:
   284  		return src, nil
   285  	}
   286  }
   287  
   288  func rconvAndSet(dvalue reflect.Value, svalue reflect.Value) error {
   289  	tleft := dvalue.Type()
   290  	tright := svalue.Type()
   291  	if tright.AssignableTo(tleft) {
   292  		dvalue.Set(svalue)
   293  	} else {
   294  		vright, err := rconv(svalue, tleft)
   295  		if err == nil {
   296  			dvalue.Set(vright)
   297  		} else {
   298  			return err
   299  		}
   300  	}
   301  	return nil
   302  }
   303  
   304  func rconvNumber(value reflect.Value) reflect.Value {
   305  	if !value.IsValid() || value.IsZero() {
   306  		return reflect.ValueOf(0)
   307  	}
   308  	if value.Kind() == reflect.Interface || value.Kind() == reflect.Ptr {
   309  		return rconvNumber(value.Elem())
   310  	}
   311  	if isString(value.Type()) {
   312  		val := value.Interface().(string)
   313  		var num interface{}
   314  		var err error
   315  		if strings.Index(val, ".") > -1 {
   316  			num, err = strconv.ParseFloat(val, 64)
   317  			if err != nil {
   318  				return value
   319  			} else {
   320  				return reflect.ValueOf(num)
   321  			}
   322  		} else {
   323  			num, err = strconv.ParseUint(val, 0, 64)
   324  			num, err = strconv.ParseInt(val, 0, 64)
   325  			if err != nil {
   326  				return value
   327  			} else {
   328  				return reflect.ValueOf(num)
   329  			}
   330  		}
   331  	} else {
   332  		return value
   333  	}
   334  }
   335  
   336  func rconvConst(val constant.Value, kind constant.Kind) constant.Value {
   337  	v := constToInterface(val)
   338  	switch kind {
   339  	case constant.Bool:
   340  		return constant.MakeBool(cast.ToBool(v))
   341  	case constant.String:
   342  		return constant.MakeString(cast.ToString(v))
   343  	case constant.Int:
   344  		return constant.MakeInt64(cast.ToInt64(v))
   345  	case constant.Float:
   346  		return constant.MakeFloat64(cast.ToFloat64(v))
   347  	}
   348  	return nil
   349  }
   350  
   351  func constToInterface(value constant.Value) interface{} {
   352  	switch value.Kind() {
   353  	case constant.Bool:
   354  		return constant.BoolVal(value)
   355  	case constant.String:
   356  		return constant.StringVal(value)
   357  	case constant.Int:
   358  		v, _ := constant.Int64Val(value)
   359  		return v
   360  	case constant.Float:
   361  		v, _ := constant.Float64Val(value)
   362  		return v
   363  	default:
   364  		return nil
   365  	}
   366  }
   367  
   368  func rconvToConcrete(value reflect.Value) reflect.Value {
   369  	if value.Kind() == reflect.Interface {
   370  		return rconvToConcrete(value.Elem())
   371  	} else {
   372  		return value
   373  	}
   374  }
   375  
   376  func rconvConstNumber(val constant.Value) (c constant.Value) {
   377  	v := constToInterface(val)
   378  	switch reflect.ValueOf(v).Kind() {
   379  	case reflect.Bool:
   380  		c = constant.MakeInt64(cast.ToInt64(v))
   381  	case reflect.String:
   382  		vstr := v.(string)
   383  		if strings.Index(vstr, ".") > -1 {
   384  			var num float64
   385  			var err error
   386  			num, err = cast.ToFloat64E(vstr)
   387  			if err != nil {
   388  				panic(err)
   389  			}
   390  			c = constant.MakeFloat64(num)
   391  		} else {
   392  			var num int64
   393  			var err error
   394  			num, err = cast.ToInt64E(vstr)
   395  			if err != nil {
   396  				panic(err)
   397  			}
   398  			c = constant.MakeInt64(num)
   399  		}
   400  	case reflect.Int64:
   401  		c = constant.MakeInt64(cast.ToInt64(v))
   402  	case reflect.Float64:
   403  		c = constant.MakeFloat64(cast.ToFloat64(v))
   404  	}
   405  	return
   406  }
   407  
   408  func rconvConstInt(value constant.Value) constant.Value {
   409  	v := rconvConstNumber(value)
   410  	if v.Kind() == constant.Float {
   411  		return constant.MakeInt64(cast.ToInt64(constToInterface(v)))
   412  	}
   413  	return v
   414  }
   415  
   416  func rconvConstBool(value constant.Value) constant.Value {
   417  	return constant.MakeBool(cast.ToBool(constToInterface(value)))
   418  }
   419  
   420  func rconvToString(val reflect.Value) string {
   421  	if !val.IsValid() {
   422  		return ""
   423  	}
   424  	return cast.ToString(val.Interface())
   425  }
   426  
   427  func rconvToBool(val reflect.Value) bool {
   428  	if !val.IsValid() {
   429  		return false
   430  	}
   431  	return cast.ToBool(val.Interface())
   432  }
   433  
   434  func rcompare(val0, val1 reflect.Value, op string) (value bool, err error) {
   435  	if val0.Kind() == reflect.String || (val0.Kind() == reflect.Interface || val0.Kind() == reflect.Ptr) && val0.Elem().Kind() == reflect.String {
   436  		value, err = compareString(rconvToString(val0), rconvToString(val1), op)
   437  	} else {
   438  		val0 = rconvNumber(val0)
   439  		val1 = rconvNumber(val1)
   440  		typ0 := val0.Type()
   441  		typ1 := val1.Type()
   442  		if isNumber(typ0) && isNumber(typ1) {
   443  			switch {
   444  			case isUint(typ0):
   445  				switch {
   446  				case isUint(typ1):
   447  					value, err = compareUint(val0.Uint(), val1.Uint(), op)
   448  				case isInt(typ1):
   449  					value, err = compareUint(val0.Uint(), uint64(val1.Int()), op)
   450  				case isFloat(typ1):
   451  					value, err = compareUint(val0.Uint(), uint64(val1.Float()), op)
   452  				}
   453  			case isInt(typ0):
   454  				switch {
   455  				case isUint(typ1):
   456  					value, err = compareInt(val0.Int(), int64(val1.Uint()), op)
   457  				case isInt(typ1):
   458  					value, err = compareInt(val0.Int(), val1.Int(), op)
   459  				case isFloat(typ1):
   460  					value, err = compareInt(val0.Int(), int64(val1.Float()), op)
   461  				}
   462  			case isFloat(typ0):
   463  				switch {
   464  				case isUint(typ1):
   465  					value, err = compareFloat(val0.Float(), float64(val1.Uint()), op)
   466  				case isInt(typ1):
   467  					value, err = compareFloat(val0.Float(), float64(val1.Int()), op)
   468  				case isFloat(typ1):
   469  					value, err = compareFloat(val0.Float(), float64(val1.Float()), op)
   470  				}
   471  			}
   472  		} else {
   473  			err = fmt.Errorf("type %s doesn't support %s operator", typ0.String(), op)
   474  		}
   475  	}
   476  	return
   477  }
   478  
   479  func compareString(v0 string, v1 string, op string) (bool, error) {
   480  	switch op {
   481  	case "==":
   482  		return v0 == v1, nil
   483  	case ">":
   484  		return v0 > v1, nil
   485  	case ">=":
   486  		return v0 >= v1, nil
   487  	case "<":
   488  		return v0 < v1, nil
   489  	case "<=":
   490  		return v0 <= v1, nil
   491  	case "!=":
   492  		return v0 != v1, nil
   493  	default:
   494  		return false, fmt.Errorf("unknown comparison operator %s", op)
   495  	}
   496  }
   497  
   498  func compareInt(v0 int64, v1 int64, op string) (bool, error) {
   499  	switch op {
   500  	case "==":
   501  		return v0 == v1, nil
   502  	case ">":
   503  		return v0 > v1, nil
   504  	case ">=":
   505  		return v0 >= v1, nil
   506  	case "<":
   507  		return v0 < v1, nil
   508  	case "<=":
   509  		return v0 <= v1, nil
   510  	case "!=":
   511  		return v0 != v1, nil
   512  	default:
   513  		return false, fmt.Errorf("unknown comparison operator %s", op)
   514  	}
   515  }
   516  
   517  func compareUint(v0 uint64, v1 uint64, op string) (bool, error) {
   518  	switch op {
   519  	case "==":
   520  		return v0 == v1, nil
   521  	case ">":
   522  		return v0 > v1, nil
   523  	case ">=":
   524  		return v0 >= v1, nil
   525  	case "<":
   526  		return v0 < v1, nil
   527  	case "<=":
   528  		return v0 <= v1, nil
   529  	case "!=":
   530  		return v0 != v1, nil
   531  	default:
   532  		return false, fmt.Errorf("unknown comparison operator %s", op)
   533  	}
   534  }
   535  
   536  func compareFloat(v0 float64, v1 float64, op string) (bool, error) {
   537  	switch op {
   538  	case "==":
   539  		return v0 == v1, nil
   540  	case ">":
   541  		return v0 > v1, nil
   542  	case ">=":
   543  		return v0 >= v1, nil
   544  	case "<":
   545  		return v0 < v1, nil
   546  	case "<=":
   547  		return v0 <= v1, nil
   548  	case "!=":
   549  		return v0 != v1, nil
   550  	default:
   551  		return false, fmt.Errorf("unknown comparison operator %s", op)
   552  	}
   553  }