github.com/dashpay/godash@v0.0.0-20160726055534-e038a21e0e3d/btcjson/cmdparse.go (about)

     1  // Copyright (c) 2014 The btcsuite developers
     2  // Copyright (c) 2016 The Dash developers
     3  // Use of this source code is governed by an ISC
     4  // license that can be found in the LICENSE file.
     5  
     6  package btcjson
     7  
     8  import (
     9  	"encoding/json"
    10  	"fmt"
    11  	"reflect"
    12  	"strconv"
    13  	"strings"
    14  )
    15  
    16  // makeParams creates a slice of interface values for the given struct.
    17  func makeParams(rt reflect.Type, rv reflect.Value) []interface{} {
    18  	numFields := rt.NumField()
    19  	params := make([]interface{}, 0, numFields)
    20  	for i := 0; i < numFields; i++ {
    21  		rtf := rt.Field(i)
    22  		rvf := rv.Field(i)
    23  		if rtf.Type.Kind() == reflect.Ptr {
    24  			if rvf.IsNil() {
    25  				break
    26  			}
    27  			rvf.Elem()
    28  		}
    29  		params = append(params, rvf.Interface())
    30  	}
    31  
    32  	return params
    33  }
    34  
    35  // MarshalCmd marshals the passed command to a JSON-RPC request byte slice that
    36  // is suitable for transmission to an RPC server.  The provided command type
    37  // must be a registered type.  All commands provided by this package are
    38  // registered by default.
    39  func MarshalCmd(id interface{}, cmd interface{}) ([]byte, error) {
    40  	// Look up the cmd type and error out if not registered.
    41  	rt := reflect.TypeOf(cmd)
    42  	registerLock.RLock()
    43  	method, ok := concreteTypeToMethod[rt]
    44  	registerLock.RUnlock()
    45  	if !ok {
    46  		str := fmt.Sprintf("%q is not registered", method)
    47  		return nil, makeError(ErrUnregisteredMethod, str)
    48  	}
    49  
    50  	// The provided command must not be nil.
    51  	rv := reflect.ValueOf(cmd)
    52  	if rv.IsNil() {
    53  		str := fmt.Sprint("the specified command is nil")
    54  		return nil, makeError(ErrInvalidType, str)
    55  	}
    56  
    57  	// Create a slice of interface values in the order of the struct fields
    58  	// while respecting pointer fields as optional params and only adding
    59  	// them if they are non-nil.
    60  	params := makeParams(rt.Elem(), rv.Elem())
    61  
    62  	// Generate and marshal the final JSON-RPC request.
    63  	rawCmd, err := NewRequest(id, method, params)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	return json.Marshal(rawCmd)
    68  }
    69  
    70  // checkNumParams ensures the supplied number of params is at least the minimum
    71  // required number for the command and less than the maximum allowed.
    72  func checkNumParams(numParams int, info *methodInfo) error {
    73  	if numParams < info.numReqParams || numParams > info.maxParams {
    74  		if info.numReqParams == info.maxParams {
    75  			str := fmt.Sprintf("wrong number of params (expected "+
    76  				"%d, received %d)", info.numReqParams,
    77  				numParams)
    78  			return makeError(ErrNumParams, str)
    79  		}
    80  
    81  		str := fmt.Sprintf("wrong number of params (expected "+
    82  			"between %d and %d, received %d)", info.numReqParams,
    83  			info.maxParams, numParams)
    84  		return makeError(ErrNumParams, str)
    85  	}
    86  
    87  	return nil
    88  }
    89  
    90  // populateDefaults populates default values into any remaining optional struct
    91  // fields that did not have parameters explicitly provided.  The caller should
    92  // have previously checked that the number of parameters being passed is at
    93  // least the required number of parameters to avoid unnecessary work in this
    94  // function, but since required fields never have default values, it will work
    95  // properly even without the check.
    96  func populateDefaults(numParams int, info *methodInfo, rv reflect.Value) {
    97  	// When there are no more parameters left in the supplied parameters,
    98  	// any remaining struct fields must be optional.  Thus, populate them
    99  	// with their associated default value as needed.
   100  	for i := numParams; i < info.maxParams; i++ {
   101  		rvf := rv.Field(i)
   102  		if defaultVal, ok := info.defaults[i]; ok {
   103  			rvf.Set(defaultVal)
   104  		}
   105  	}
   106  }
   107  
   108  // UnmarshalCmd unmarshals a JSON-RPC request into a suitable concrete command
   109  // so long as the method type contained within the marshalled request is
   110  // registered.
   111  func UnmarshalCmd(r *Request) (interface{}, error) {
   112  	registerLock.RLock()
   113  	rtp, ok := methodToConcreteType[r.Method]
   114  	info := methodToInfo[r.Method]
   115  	registerLock.RUnlock()
   116  	if !ok {
   117  		str := fmt.Sprintf("%q is not registered", r.Method)
   118  		return nil, makeError(ErrUnregisteredMethod, str)
   119  	}
   120  	rt := rtp.Elem()
   121  	rvp := reflect.New(rt)
   122  	rv := rvp.Elem()
   123  
   124  	// Ensure the number of parameters are correct.
   125  	numParams := len(r.Params)
   126  	if err := checkNumParams(numParams, &info); err != nil {
   127  		return nil, err
   128  	}
   129  
   130  	// Loop through each of the struct fields and unmarshal the associated
   131  	// parameter into them.
   132  	for i := 0; i < numParams; i++ {
   133  		rvf := rv.Field(i)
   134  		// Unmarshal the parameter into the struct field.
   135  		concreteVal := rvf.Addr().Interface()
   136  		if err := json.Unmarshal(r.Params[i], &concreteVal); err != nil {
   137  			// The most common error is the wrong type, so
   138  			// explicitly detect that error and make it nicer.
   139  			fieldName := strings.ToLower(rt.Field(i).Name)
   140  			if jerr, ok := err.(*json.UnmarshalTypeError); ok {
   141  				str := fmt.Sprintf("parameter #%d '%s' must "+
   142  					"be type %v (got %v)", i+1, fieldName,
   143  					jerr.Type, jerr.Value)
   144  				return nil, makeError(ErrInvalidType, str)
   145  			}
   146  
   147  			// Fallback to showing the underlying error.
   148  			str := fmt.Sprintf("parameter #%d '%s' failed to "+
   149  				"unmarshal: %v", i+1, fieldName, err)
   150  			return nil, makeError(ErrInvalidType, str)
   151  		}
   152  	}
   153  
   154  	// When there are less supplied parameters than the total number of
   155  	// params, any remaining struct fields must be optional.  Thus, populate
   156  	// them with their associated default value as needed.
   157  	if numParams < info.maxParams {
   158  		populateDefaults(numParams, &info, rv)
   159  	}
   160  
   161  	return rvp.Interface(), nil
   162  }
   163  
   164  // isNumeric returns whether the passed reflect kind is a signed or unsigned
   165  // integer of any magnitude or a float of any magnitude.
   166  func isNumeric(kind reflect.Kind) bool {
   167  	switch kind {
   168  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   169  		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
   170  		reflect.Uint64, reflect.Float32, reflect.Float64:
   171  
   172  		return true
   173  	}
   174  
   175  	return false
   176  }
   177  
   178  // typesMaybeCompatible returns whether the source type can possibly be
   179  // assigned to the destination type.  This is intended as a relatively quick
   180  // check to weed out obviously invalid conversions.
   181  func typesMaybeCompatible(dest reflect.Type, src reflect.Type) bool {
   182  	// The same types are obviously compatible.
   183  	if dest == src {
   184  		return true
   185  	}
   186  
   187  	// When both types are numeric, they are potentially compatibile.
   188  	srcKind := src.Kind()
   189  	destKind := dest.Kind()
   190  	if isNumeric(destKind) && isNumeric(srcKind) {
   191  		return true
   192  	}
   193  
   194  	if srcKind == reflect.String {
   195  		// Strings can potentially be converted to numeric types.
   196  		if isNumeric(destKind) {
   197  			return true
   198  		}
   199  
   200  		switch destKind {
   201  		// Strings can potentially be converted to bools by
   202  		// strconv.ParseBool.
   203  		case reflect.Bool:
   204  			return true
   205  
   206  		// Strings can be converted to any other type which has as
   207  		// underlying type of string.
   208  		case reflect.String:
   209  			return true
   210  
   211  		// Strings can potentially be converted to arrays, slice,
   212  		// structs, and maps via json.Unmarshal.
   213  		case reflect.Array, reflect.Slice, reflect.Struct, reflect.Map:
   214  			return true
   215  		}
   216  	}
   217  
   218  	return false
   219  }
   220  
   221  // baseType returns the type of the argument after indirecting through all
   222  // pointers along with how many indirections were necessary.
   223  func baseType(arg reflect.Type) (reflect.Type, int) {
   224  	var numIndirects int
   225  	for arg.Kind() == reflect.Ptr {
   226  		arg = arg.Elem()
   227  		numIndirects++
   228  	}
   229  	return arg, numIndirects
   230  }
   231  
   232  // assignField is the main workhorse for the NewCmd function which handles
   233  // assigning the provided source value to the destination field.  It supports
   234  // direct type assignments, indirection, conversion of numeric types, and
   235  // unmarshaling of strings into arrays, slices, structs, and maps via
   236  // json.Unmarshal.
   237  func assignField(paramNum int, fieldName string, dest reflect.Value, src reflect.Value) error {
   238  	// Just error now when the types have no chance of being compatible.
   239  	destBaseType, destIndirects := baseType(dest.Type())
   240  	srcBaseType, srcIndirects := baseType(src.Type())
   241  	if !typesMaybeCompatible(destBaseType, srcBaseType) {
   242  		str := fmt.Sprintf("parameter #%d '%s' must be type %v (got "+
   243  			"%v)", paramNum, fieldName, destBaseType, srcBaseType)
   244  		return makeError(ErrInvalidType, str)
   245  	}
   246  
   247  	// Check if it's possible to simply set the dest to the provided source.
   248  	// This is the case when the base types are the same or they are both
   249  	// pointers that can be indirected to be the same without needing to
   250  	// create pointers for the destination field.
   251  	if destBaseType == srcBaseType && srcIndirects >= destIndirects {
   252  		for i := 0; i < srcIndirects-destIndirects; i++ {
   253  			src = src.Elem()
   254  		}
   255  		dest.Set(src)
   256  		return nil
   257  	}
   258  
   259  	// When the destination has more indirects than the source, the extra
   260  	// pointers have to be created.  Only create enough pointers to reach
   261  	// the same level of indirection as the source so the dest can simply be
   262  	// set to the provided source when the types are the same.
   263  	destIndirectsRemaining := destIndirects
   264  	if destIndirects > srcIndirects {
   265  		indirectDiff := destIndirects - srcIndirects
   266  		for i := 0; i < indirectDiff; i++ {
   267  			dest.Set(reflect.New(dest.Type().Elem()))
   268  			dest = dest.Elem()
   269  			destIndirectsRemaining--
   270  		}
   271  	}
   272  
   273  	if destBaseType == srcBaseType {
   274  		dest.Set(src)
   275  		return nil
   276  	}
   277  
   278  	// Make any remaining pointers needed to get to the base dest type since
   279  	// the above direct assign was not possible and conversions are done
   280  	// against the base types.
   281  	for i := 0; i < destIndirectsRemaining; i++ {
   282  		dest.Set(reflect.New(dest.Type().Elem()))
   283  		dest = dest.Elem()
   284  	}
   285  
   286  	// Indirect through to the base source value.
   287  	for src.Kind() == reflect.Ptr {
   288  		src = src.Elem()
   289  	}
   290  
   291  	// Perform supported type conversions.
   292  	switch src.Kind() {
   293  	// Source value is a signed integer of various magnitude.
   294  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
   295  		reflect.Int64:
   296  
   297  		switch dest.Kind() {
   298  		// Destination is a signed integer of various magnitude.
   299  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
   300  			reflect.Int64:
   301  
   302  			srcInt := src.Int()
   303  			if dest.OverflowInt(srcInt) {
   304  				str := fmt.Sprintf("parameter #%d '%s' "+
   305  					"overflows destination type %v",
   306  					paramNum, fieldName, destBaseType)
   307  				return makeError(ErrInvalidType, str)
   308  			}
   309  
   310  			dest.SetInt(srcInt)
   311  
   312  		// Destination is an unsigned integer of various magnitude.
   313  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
   314  			reflect.Uint64:
   315  
   316  			srcInt := src.Int()
   317  			if srcInt < 0 || dest.OverflowUint(uint64(srcInt)) {
   318  				str := fmt.Sprintf("parameter #%d '%s' "+
   319  					"overflows destination type %v",
   320  					paramNum, fieldName, destBaseType)
   321  				return makeError(ErrInvalidType, str)
   322  			}
   323  			dest.SetUint(uint64(srcInt))
   324  
   325  		default:
   326  			str := fmt.Sprintf("parameter #%d '%s' must be type "+
   327  				"%v (got %v)", paramNum, fieldName, destBaseType,
   328  				srcBaseType)
   329  			return makeError(ErrInvalidType, str)
   330  		}
   331  
   332  	// Source value is an unsigned integer of various magnitude.
   333  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
   334  		reflect.Uint64:
   335  
   336  		switch dest.Kind() {
   337  		// Destination is a signed integer of various magnitude.
   338  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
   339  			reflect.Int64:
   340  
   341  			srcUint := src.Uint()
   342  			if srcUint > uint64(1<<63)-1 {
   343  				str := fmt.Sprintf("parameter #%d '%s' "+
   344  					"overflows destination type %v",
   345  					paramNum, fieldName, destBaseType)
   346  				return makeError(ErrInvalidType, str)
   347  			}
   348  			if dest.OverflowInt(int64(srcUint)) {
   349  				str := fmt.Sprintf("parameter #%d '%s' "+
   350  					"overflows destination type %v",
   351  					paramNum, fieldName, destBaseType)
   352  				return makeError(ErrInvalidType, str)
   353  			}
   354  			dest.SetInt(int64(srcUint))
   355  
   356  		// Destination is an unsigned integer of various magnitude.
   357  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
   358  			reflect.Uint64:
   359  
   360  			srcUint := src.Uint()
   361  			if dest.OverflowUint(srcUint) {
   362  				str := fmt.Sprintf("parameter #%d '%s' "+
   363  					"overflows destination type %v",
   364  					paramNum, fieldName, destBaseType)
   365  				return makeError(ErrInvalidType, str)
   366  			}
   367  			dest.SetUint(srcUint)
   368  
   369  		default:
   370  			str := fmt.Sprintf("parameter #%d '%s' must be type "+
   371  				"%v (got %v)", paramNum, fieldName, destBaseType,
   372  				srcBaseType)
   373  			return makeError(ErrInvalidType, str)
   374  		}
   375  
   376  	// Source value is a float.
   377  	case reflect.Float32, reflect.Float64:
   378  		destKind := dest.Kind()
   379  		if destKind != reflect.Float32 && destKind != reflect.Float64 {
   380  			str := fmt.Sprintf("parameter #%d '%s' must be type "+
   381  				"%v (got %v)", paramNum, fieldName, destBaseType,
   382  				srcBaseType)
   383  			return makeError(ErrInvalidType, str)
   384  		}
   385  
   386  		srcFloat := src.Float()
   387  		if dest.OverflowFloat(srcFloat) {
   388  			str := fmt.Sprintf("parameter #%d '%s' overflows "+
   389  				"destination type %v", paramNum, fieldName,
   390  				destBaseType)
   391  			return makeError(ErrInvalidType, str)
   392  		}
   393  		dest.SetFloat(srcFloat)
   394  
   395  	// Source value is a string.
   396  	case reflect.String:
   397  		switch dest.Kind() {
   398  		// String -> bool
   399  		case reflect.Bool:
   400  			b, err := strconv.ParseBool(src.String())
   401  			if err != nil {
   402  				str := fmt.Sprintf("parameter #%d '%s' must "+
   403  					"parse to a %v", paramNum, fieldName,
   404  					destBaseType)
   405  				return makeError(ErrInvalidType, str)
   406  			}
   407  			dest.SetBool(b)
   408  
   409  		// String -> signed integer of varying size.
   410  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
   411  			reflect.Int64:
   412  
   413  			srcInt, err := strconv.ParseInt(src.String(), 0, 0)
   414  			if err != nil {
   415  				str := fmt.Sprintf("parameter #%d '%s' must "+
   416  					"parse to a %v", paramNum, fieldName,
   417  					destBaseType)
   418  				return makeError(ErrInvalidType, str)
   419  			}
   420  			if dest.OverflowInt(srcInt) {
   421  				str := fmt.Sprintf("parameter #%d '%s' "+
   422  					"overflows destination type %v",
   423  					paramNum, fieldName, destBaseType)
   424  				return makeError(ErrInvalidType, str)
   425  			}
   426  			dest.SetInt(srcInt)
   427  
   428  		// String -> unsigned integer of varying size.
   429  		case reflect.Uint, reflect.Uint8, reflect.Uint16,
   430  			reflect.Uint32, reflect.Uint64:
   431  
   432  			srcUint, err := strconv.ParseUint(src.String(), 0, 0)
   433  			if err != nil {
   434  				str := fmt.Sprintf("parameter #%d '%s' must "+
   435  					"parse to a %v", paramNum, fieldName,
   436  					destBaseType)
   437  				return makeError(ErrInvalidType, str)
   438  			}
   439  			if dest.OverflowUint(srcUint) {
   440  				str := fmt.Sprintf("parameter #%d '%s' "+
   441  					"overflows destination type %v",
   442  					paramNum, fieldName, destBaseType)
   443  				return makeError(ErrInvalidType, str)
   444  			}
   445  			dest.SetUint(srcUint)
   446  
   447  		// String -> float of varying size.
   448  		case reflect.Float32, reflect.Float64:
   449  			srcFloat, err := strconv.ParseFloat(src.String(), 0)
   450  			if err != nil {
   451  				str := fmt.Sprintf("parameter #%d '%s' must "+
   452  					"parse to a %v", paramNum, fieldName,
   453  					destBaseType)
   454  				return makeError(ErrInvalidType, str)
   455  			}
   456  			if dest.OverflowFloat(srcFloat) {
   457  				str := fmt.Sprintf("parameter #%d '%s' "+
   458  					"overflows destination type %v",
   459  					paramNum, fieldName, destBaseType)
   460  				return makeError(ErrInvalidType, str)
   461  			}
   462  			dest.SetFloat(srcFloat)
   463  
   464  		// String -> string (typecast).
   465  		case reflect.String:
   466  			dest.SetString(src.String())
   467  
   468  		// String -> arrays, slices, structs, and maps via
   469  		// json.Unmarshal.
   470  		case reflect.Array, reflect.Slice, reflect.Struct, reflect.Map:
   471  			concreteVal := dest.Addr().Interface()
   472  			err := json.Unmarshal([]byte(src.String()), &concreteVal)
   473  			if err != nil {
   474  				str := fmt.Sprintf("parameter #%d '%s' must "+
   475  					"be valid JSON which unsmarshals to a %v",
   476  					paramNum, fieldName, destBaseType)
   477  				return makeError(ErrInvalidType, str)
   478  			}
   479  			dest.Set(reflect.ValueOf(concreteVal).Elem())
   480  		}
   481  	}
   482  
   483  	return nil
   484  }
   485  
   486  // NewCmd provides a generic mechanism to create a new command that can marshal
   487  // to a JSON-RPC request while respecting the requirements of the provided
   488  // method.  The method must have been registered with the package already along
   489  // with its type definition.  All methods associated with the commands exported
   490  // by this package are already registered by default.
   491  //
   492  // The arguments are most efficient when they are the exact same type as the
   493  // underlying field in the command struct associated with the the method,
   494  // however this function also will perform a variety of conversions to make it
   495  // more flexible.  This allows, for example, command line args which are strings
   496  // to be passed unaltered.  In particular, the following conversions are
   497  // supported:
   498  //
   499  //   - Conversion between any size signed or unsigned integer so long as the
   500  //     value does not overflow the destination type
   501  //   - Conversion between float32 and float64 so long as the value does not
   502  //     overflow the destination type
   503  //   - Conversion from string to boolean for everything strconv.ParseBool
   504  //     recognizes
   505  //   - Conversion from string to any size integer for everything
   506  //     strconv.ParseInt and strconv.ParseUint recognizes
   507  //   - Conversion from string to any size float for everything
   508  //     strconv.ParseFloat recognizes
   509  //   - Conversion from string to arrays, slices, structs, and maps by treating
   510  //     the string as marshalled JSON and calling json.Unmarshal into the
   511  //     destination field
   512  func NewCmd(method string, args ...interface{}) (interface{}, error) {
   513  	// Look up details about the provided method.  Any methods that aren't
   514  	// registered are an error.
   515  	registerLock.RLock()
   516  	rtp, ok := methodToConcreteType[method]
   517  	info := methodToInfo[method]
   518  	registerLock.RUnlock()
   519  	if !ok {
   520  		str := fmt.Sprintf("%q is not registered", method)
   521  		return nil, makeError(ErrUnregisteredMethod, str)
   522  	}
   523  
   524  	// Ensure the number of parameters are correct.
   525  	numParams := len(args)
   526  	if err := checkNumParams(numParams, &info); err != nil {
   527  		return nil, err
   528  	}
   529  
   530  	// Create the appropriate command type for the method.  Since all types
   531  	// are enforced to be a pointer to a struct at registration time, it's
   532  	// safe to indirect to the struct now.
   533  	rvp := reflect.New(rtp.Elem())
   534  	rv := rvp.Elem()
   535  	rt := rtp.Elem()
   536  
   537  	// Loop through each of the struct fields and assign the associated
   538  	// parameter into them after checking its type validity.
   539  	for i := 0; i < numParams; i++ {
   540  		// Attempt to assign each of the arguments to the according
   541  		// struct field.
   542  		rvf := rv.Field(i)
   543  		fieldName := strings.ToLower(rt.Field(i).Name)
   544  		err := assignField(i+1, fieldName, rvf, reflect.ValueOf(args[i]))
   545  		if err != nil {
   546  			return nil, err
   547  		}
   548  	}
   549  
   550  	return rvp.Interface(), nil
   551  }