github.com/btcsuite/btcd@v0.24.0/btcjson/cmdparse.go (about)

     1  // Copyright (c) 2014 The btcsuite developers
     2  // Use of this source code is governed by an ISC
     3  // license that can be found in the LICENSE file.
     4  
     5  package btcjson
     6  
     7  import (
     8  	"encoding/json"
     9  	"fmt"
    10  	"reflect"
    11  	"strconv"
    12  	"strings"
    13  )
    14  
    15  // makeParams creates a slice of interface values for the given struct.
    16  func makeParams(rt reflect.Type, rv reflect.Value) []interface{} {
    17  	numFields := rt.NumField()
    18  	params := make([]interface{}, 0, numFields)
    19  	lastParam := -1
    20  	for i := 0; i < numFields; i++ {
    21  		rtf := rt.Field(i)
    22  		rvf := rv.Field(i)
    23  		params = append(params, rvf.Interface())
    24  		if rtf.Type.Kind() == reflect.Ptr {
    25  			if rvf.IsNil() {
    26  				// Omit optional null params unless a non-null param follows
    27  				continue
    28  			}
    29  		}
    30  		lastParam = i
    31  	}
    32  	return params[:lastParam+1]
    33  }
    34  
    35  // MarshalCmd marshals the passed command to a JSON-RPC request byte slice that
    36  // is suitable for transmission to an RPC server.  The provided command type
    37  // must be a registered type.  All commands provided by this package are
    38  // registered by default.
    39  func MarshalCmd(rpcVersion RPCVersion, id interface{}, cmd interface{}) ([]byte, error) {
    40  	// Look up the cmd type and error out if not registered.
    41  	rt := reflect.TypeOf(cmd)
    42  	registerLock.RLock()
    43  	method, ok := concreteTypeToMethod[rt]
    44  	registerLock.RUnlock()
    45  	if !ok {
    46  		str := fmt.Sprintf("%q is not registered", method)
    47  		return nil, makeError(ErrUnregisteredMethod, str)
    48  	}
    49  
    50  	// The provided command must not be nil.
    51  	rv := reflect.ValueOf(cmd)
    52  	if rv.IsNil() {
    53  		str := "the specified command is nil"
    54  		return nil, makeError(ErrInvalidType, str)
    55  	}
    56  
    57  	// Create a slice of interface values in the order of the struct fields
    58  	// while respecting pointer fields as optional params and only adding
    59  	// them if they are non-nil.
    60  	params := makeParams(rt.Elem(), rv.Elem())
    61  
    62  	// Generate and marshal the final JSON-RPC request.
    63  	rawCmd, err := NewRequest(rpcVersion, id, method, params)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	return json.Marshal(rawCmd)
    68  }
    69  
    70  // checkNumParams ensures the supplied number of params is at least the minimum
    71  // required number for the command and less than the maximum allowed.
    72  func checkNumParams(numParams int, info *methodInfo) error {
    73  	if numParams < info.numReqParams || numParams > info.maxParams {
    74  		if info.numReqParams == info.maxParams {
    75  			str := fmt.Sprintf("wrong number of params (expected "+
    76  				"%d, received %d)", info.numReqParams,
    77  				numParams)
    78  			return makeError(ErrNumParams, str)
    79  		}
    80  
    81  		str := fmt.Sprintf("wrong number of params (expected "+
    82  			"between %d and %d, received %d)", info.numReqParams,
    83  			info.maxParams, numParams)
    84  		return makeError(ErrNumParams, str)
    85  	}
    86  
    87  	return nil
    88  }
    89  
    90  // populateDefaults populates default values into any remaining optional struct
    91  // fields that did not have parameters explicitly provided.  The caller should
    92  // have previously checked that the number of parameters being passed is at
    93  // least the required number of parameters to avoid unnecessary work in this
    94  // function, but since required fields never have default values, it will work
    95  // properly even without the check.
    96  func populateDefaults(numParams int, info *methodInfo, rv reflect.Value) {
    97  	// When there are no more parameters left in the supplied parameters,
    98  	// any remaining struct fields must be optional.  Thus, populate them
    99  	// with their associated default value as needed.
   100  	for i := numParams; i < info.maxParams; i++ {
   101  		rvf := rv.Field(i)
   102  		if defaultVal, ok := info.defaults[i]; ok {
   103  			rvf.Set(defaultVal)
   104  		}
   105  	}
   106  }
   107  
   108  // UnmarshalCmd unmarshals a JSON-RPC request into a suitable concrete command
   109  // so long as the method type contained within the marshalled request is
   110  // registered.
   111  func UnmarshalCmd(r *Request) (interface{}, error) {
   112  	registerLock.RLock()
   113  	rtp, ok := methodToConcreteType[r.Method]
   114  	info := methodToInfo[r.Method]
   115  	registerLock.RUnlock()
   116  	if !ok {
   117  		str := fmt.Sprintf("%q is not registered", r.Method)
   118  		return nil, makeError(ErrUnregisteredMethod, str)
   119  	}
   120  	rt := rtp.Elem()
   121  	rvp := reflect.New(rt)
   122  	rv := rvp.Elem()
   123  
   124  	// Ensure the number of parameters are correct.
   125  	numParams := len(r.Params)
   126  	if err := checkNumParams(numParams, &info); err != nil {
   127  		return nil, err
   128  	}
   129  
   130  	// Loop through each of the struct fields and unmarshal the associated
   131  	// parameter into them.
   132  	for i := 0; i < numParams; i++ {
   133  		rvf := rv.Field(i)
   134  		// Unmarshal the parameter into the struct field.
   135  		concreteVal := rvf.Addr().Interface()
   136  		if err := json.Unmarshal(r.Params[i], &concreteVal); err != nil {
   137  			// The most common error is the wrong type, so
   138  			// explicitly detect that error and make it nicer.
   139  			fieldName := strings.ToLower(rt.Field(i).Name)
   140  			if jerr, ok := err.(*json.UnmarshalTypeError); ok {
   141  				str := fmt.Sprintf("parameter #%d '%s' must "+
   142  					"be type %v (got %v)", i+1, fieldName,
   143  					jerr.Type, jerr.Value)
   144  				return nil, makeError(ErrInvalidType, str)
   145  			}
   146  
   147  			// Fallback to showing the underlying error.
   148  			str := fmt.Sprintf("parameter #%d '%s' failed to "+
   149  				"unmarshal: %v", i+1, fieldName, err)
   150  			return nil, makeError(ErrInvalidType, str)
   151  		}
   152  	}
   153  
   154  	// When there are less supplied parameters than the total number of
   155  	// params, any remaining struct fields must be optional.  Thus, populate
   156  	// them with their associated default value as needed.
   157  	if numParams < info.maxParams {
   158  		populateDefaults(numParams, &info, rv)
   159  	}
   160  
   161  	return rvp.Interface(), nil
   162  }
   163  
   164  // isNumeric returns whether the passed reflect kind is a signed or unsigned
   165  // integer of any magnitude or a float of any magnitude.
   166  func isNumeric(kind reflect.Kind) bool {
   167  	switch kind {
   168  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   169  		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
   170  		reflect.Uint64, reflect.Float32, reflect.Float64:
   171  
   172  		return true
   173  	}
   174  
   175  	return false
   176  }
   177  
   178  // typesMaybeCompatible returns whether the source type can possibly be
   179  // assigned to the destination type.  This is intended as a relatively quick
   180  // check to weed out obviously invalid conversions.
   181  func typesMaybeCompatible(dest reflect.Type, src reflect.Type) bool {
   182  	// The same types are obviously compatible.
   183  	if dest == src {
   184  		return true
   185  	}
   186  
   187  	// When both types are numeric, they are potentially compatible.
   188  	srcKind := src.Kind()
   189  	destKind := dest.Kind()
   190  	if isNumeric(destKind) && isNumeric(srcKind) {
   191  		return true
   192  	}
   193  
   194  	if srcKind == reflect.String {
   195  		// Strings can potentially be converted to numeric types.
   196  		if isNumeric(destKind) {
   197  			return true
   198  		}
   199  
   200  		switch destKind {
   201  		// Strings can potentially be converted to bools by
   202  		// strconv.ParseBool.
   203  		case reflect.Bool:
   204  			return true
   205  
   206  		// Strings can be converted to any other type which has as
   207  		// underlying type of string.
   208  		case reflect.String:
   209  			return true
   210  
   211  		// Strings can potentially be converted to arrays, slice,
   212  		// structs, and maps via json.Unmarshal.
   213  		case reflect.Array, reflect.Slice, reflect.Struct, reflect.Map:
   214  			return true
   215  		}
   216  	}
   217  
   218  	return false
   219  }
   220  
   221  // baseType returns the type of the argument after indirecting through all
   222  // pointers along with how many indirections were necessary.
   223  func baseType(arg reflect.Type) (reflect.Type, int) {
   224  	var numIndirects int
   225  	for arg.Kind() == reflect.Ptr {
   226  		arg = arg.Elem()
   227  		numIndirects++
   228  	}
   229  	return arg, numIndirects
   230  }
   231  
   232  // assignField is the main workhorse for the NewCmd function which handles
   233  // assigning the provided source value to the destination field.  It supports
   234  // direct type assignments, indirection, conversion of numeric types, and
   235  // unmarshaling of strings into arrays, slices, structs, and maps via
   236  // json.Unmarshal.
   237  func assignField(paramNum int, fieldName string, dest reflect.Value, src reflect.Value) error {
   238  	// Just error now when the types have no chance of being compatible.
   239  	destBaseType, destIndirects := baseType(dest.Type())
   240  	srcBaseType, srcIndirects := baseType(src.Type())
   241  	if !typesMaybeCompatible(destBaseType, srcBaseType) {
   242  		str := fmt.Sprintf("parameter #%d '%s' must be type %v (got "+
   243  			"%v)", paramNum, fieldName, destBaseType, srcBaseType)
   244  		return makeError(ErrInvalidType, str)
   245  	}
   246  
   247  	// Check if it's possible to simply set the dest to the provided source.
   248  	// This is the case when the base types are the same or they are both
   249  	// pointers that can be indirected to be the same without needing to
   250  	// create pointers for the destination field.
   251  	if destBaseType == srcBaseType && srcIndirects >= destIndirects {
   252  		for i := 0; i < srcIndirects-destIndirects; i++ {
   253  			src = src.Elem()
   254  		}
   255  		dest.Set(src)
   256  		return nil
   257  	}
   258  
   259  	// Optional variables can be set null using "null" string
   260  	if destIndirects > 0 && src.String() == "null" {
   261  		return nil
   262  	}
   263  
   264  	// When the destination has more indirects than the source, the extra
   265  	// pointers have to be created.  Only create enough pointers to reach
   266  	// the same level of indirection as the source so the dest can simply be
   267  	// set to the provided source when the types are the same.
   268  	destIndirectsRemaining := destIndirects
   269  	if destIndirects > srcIndirects {
   270  		indirectDiff := destIndirects - srcIndirects
   271  		for i := 0; i < indirectDiff; i++ {
   272  			dest.Set(reflect.New(dest.Type().Elem()))
   273  			dest = dest.Elem()
   274  			destIndirectsRemaining--
   275  		}
   276  	}
   277  
   278  	if destBaseType == srcBaseType {
   279  		dest.Set(src)
   280  		return nil
   281  	}
   282  
   283  	// Make any remaining pointers needed to get to the base dest type since
   284  	// the above direct assign was not possible and conversions are done
   285  	// against the base types.
   286  	for i := 0; i < destIndirectsRemaining; i++ {
   287  		dest.Set(reflect.New(dest.Type().Elem()))
   288  		dest = dest.Elem()
   289  	}
   290  
   291  	// Indirect through to the base source value.
   292  	for src.Kind() == reflect.Ptr {
   293  		src = src.Elem()
   294  	}
   295  
   296  	// Perform supported type conversions.
   297  	switch src.Kind() {
   298  	// Source value is a signed integer of various magnitude.
   299  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
   300  		reflect.Int64:
   301  
   302  		switch dest.Kind() {
   303  		// Destination is a signed integer of various magnitude.
   304  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
   305  			reflect.Int64:
   306  
   307  			srcInt := src.Int()
   308  			if dest.OverflowInt(srcInt) {
   309  				str := fmt.Sprintf("parameter #%d '%s' "+
   310  					"overflows destination type %v",
   311  					paramNum, fieldName, destBaseType)
   312  				return makeError(ErrInvalidType, str)
   313  			}
   314  
   315  			dest.SetInt(srcInt)
   316  
   317  		// Destination is an unsigned integer of various magnitude.
   318  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
   319  			reflect.Uint64:
   320  
   321  			srcInt := src.Int()
   322  			if srcInt < 0 || dest.OverflowUint(uint64(srcInt)) {
   323  				str := fmt.Sprintf("parameter #%d '%s' "+
   324  					"overflows destination type %v",
   325  					paramNum, fieldName, destBaseType)
   326  				return makeError(ErrInvalidType, str)
   327  			}
   328  			dest.SetUint(uint64(srcInt))
   329  
   330  		default:
   331  			str := fmt.Sprintf("parameter #%d '%s' must be type "+
   332  				"%v (got %v)", paramNum, fieldName, destBaseType,
   333  				srcBaseType)
   334  			return makeError(ErrInvalidType, str)
   335  		}
   336  
   337  	// Source value is an unsigned integer of various magnitude.
   338  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
   339  		reflect.Uint64:
   340  
   341  		switch dest.Kind() {
   342  		// Destination is a signed integer of various magnitude.
   343  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
   344  			reflect.Int64:
   345  
   346  			srcUint := src.Uint()
   347  			if srcUint > uint64(1<<63)-1 {
   348  				str := fmt.Sprintf("parameter #%d '%s' "+
   349  					"overflows destination type %v",
   350  					paramNum, fieldName, destBaseType)
   351  				return makeError(ErrInvalidType, str)
   352  			}
   353  			if dest.OverflowInt(int64(srcUint)) {
   354  				str := fmt.Sprintf("parameter #%d '%s' "+
   355  					"overflows destination type %v",
   356  					paramNum, fieldName, destBaseType)
   357  				return makeError(ErrInvalidType, str)
   358  			}
   359  			dest.SetInt(int64(srcUint))
   360  
   361  		// Destination is an unsigned integer of various magnitude.
   362  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
   363  			reflect.Uint64:
   364  
   365  			srcUint := src.Uint()
   366  			if dest.OverflowUint(srcUint) {
   367  				str := fmt.Sprintf("parameter #%d '%s' "+
   368  					"overflows destination type %v",
   369  					paramNum, fieldName, destBaseType)
   370  				return makeError(ErrInvalidType, str)
   371  			}
   372  			dest.SetUint(srcUint)
   373  
   374  		default:
   375  			str := fmt.Sprintf("parameter #%d '%s' must be type "+
   376  				"%v (got %v)", paramNum, fieldName, destBaseType,
   377  				srcBaseType)
   378  			return makeError(ErrInvalidType, str)
   379  		}
   380  
   381  	// Source value is a float.
   382  	case reflect.Float32, reflect.Float64:
   383  		destKind := dest.Kind()
   384  		if destKind != reflect.Float32 && destKind != reflect.Float64 {
   385  			str := fmt.Sprintf("parameter #%d '%s' must be type "+
   386  				"%v (got %v)", paramNum, fieldName, destBaseType,
   387  				srcBaseType)
   388  			return makeError(ErrInvalidType, str)
   389  		}
   390  
   391  		srcFloat := src.Float()
   392  		if dest.OverflowFloat(srcFloat) {
   393  			str := fmt.Sprintf("parameter #%d '%s' overflows "+
   394  				"destination type %v", paramNum, fieldName,
   395  				destBaseType)
   396  			return makeError(ErrInvalidType, str)
   397  		}
   398  		dest.SetFloat(srcFloat)
   399  
   400  	// Source value is a string.
   401  	case reflect.String:
   402  		switch dest.Kind() {
   403  		// String -> bool
   404  		case reflect.Bool:
   405  			b, err := strconv.ParseBool(src.String())
   406  			if err != nil {
   407  				str := fmt.Sprintf("parameter #%d '%s' must "+
   408  					"parse to a %v", paramNum, fieldName,
   409  					destBaseType)
   410  				return makeError(ErrInvalidType, str)
   411  			}
   412  			dest.SetBool(b)
   413  
   414  		// String -> signed integer of varying size.
   415  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
   416  			reflect.Int64:
   417  
   418  			srcInt, err := strconv.ParseInt(src.String(), 0, 0)
   419  			if err != nil {
   420  				str := fmt.Sprintf("parameter #%d '%s' must "+
   421  					"parse to a %v", paramNum, fieldName,
   422  					destBaseType)
   423  				return makeError(ErrInvalidType, str)
   424  			}
   425  			if dest.OverflowInt(srcInt) {
   426  				str := fmt.Sprintf("parameter #%d '%s' "+
   427  					"overflows destination type %v",
   428  					paramNum, fieldName, destBaseType)
   429  				return makeError(ErrInvalidType, str)
   430  			}
   431  			dest.SetInt(srcInt)
   432  
   433  		// String -> unsigned integer of varying size.
   434  		case reflect.Uint, reflect.Uint8, reflect.Uint16,
   435  			reflect.Uint32, reflect.Uint64:
   436  
   437  			srcUint, err := strconv.ParseUint(src.String(), 0, 0)
   438  			if err != nil {
   439  				str := fmt.Sprintf("parameter #%d '%s' must "+
   440  					"parse to a %v", paramNum, fieldName,
   441  					destBaseType)
   442  				return makeError(ErrInvalidType, str)
   443  			}
   444  			if dest.OverflowUint(srcUint) {
   445  				str := fmt.Sprintf("parameter #%d '%s' "+
   446  					"overflows destination type %v",
   447  					paramNum, fieldName, destBaseType)
   448  				return makeError(ErrInvalidType, str)
   449  			}
   450  			dest.SetUint(srcUint)
   451  
   452  		// String -> float of varying size.
   453  		case reflect.Float32, reflect.Float64:
   454  			srcFloat, err := strconv.ParseFloat(src.String(), 0)
   455  			if err != nil {
   456  				str := fmt.Sprintf("parameter #%d '%s' must "+
   457  					"parse to a %v", paramNum, fieldName,
   458  					destBaseType)
   459  				return makeError(ErrInvalidType, str)
   460  			}
   461  			if dest.OverflowFloat(srcFloat) {
   462  				str := fmt.Sprintf("parameter #%d '%s' "+
   463  					"overflows destination type %v",
   464  					paramNum, fieldName, destBaseType)
   465  				return makeError(ErrInvalidType, str)
   466  			}
   467  			dest.SetFloat(srcFloat)
   468  
   469  		// String -> string (typecast).
   470  		case reflect.String:
   471  			dest.SetString(src.String())
   472  
   473  		// String -> arrays, slices, structs, and maps via
   474  		// json.Unmarshal.
   475  		case reflect.Array, reflect.Slice, reflect.Struct, reflect.Map:
   476  			concreteVal := dest.Addr().Interface()
   477  			err := json.Unmarshal([]byte(src.String()), &concreteVal)
   478  			if err != nil {
   479  				str := fmt.Sprintf("parameter #%d '%s' must "+
   480  					"be valid JSON which unsmarshals to a %v",
   481  					paramNum, fieldName, destBaseType)
   482  				return makeError(ErrInvalidType, str)
   483  			}
   484  			dest.Set(reflect.ValueOf(concreteVal).Elem())
   485  		}
   486  	}
   487  
   488  	return nil
   489  }
   490  
   491  // NewCmd provides a generic mechanism to create a new command that can marshal
   492  // to a JSON-RPC request while respecting the requirements of the provided
   493  // method.  The method must have been registered with the package already along
   494  // with its type definition.  All methods associated with the commands exported
   495  // by this package are already registered by default.
   496  //
   497  // The arguments are most efficient when they are the exact same type as the
   498  // underlying field in the command struct associated with the method,
   499  // however this function also will perform a variety of conversions to make it
   500  // more flexible.  This allows, for example, command line args which are strings
   501  // to be passed unaltered.  In particular, the following conversions are
   502  // supported:
   503  //
   504  //   - Conversion between any size signed or unsigned integer so long as the
   505  //     value does not overflow the destination type
   506  //   - Conversion between float32 and float64 so long as the value does not
   507  //     overflow the destination type
   508  //   - Conversion from string to boolean for everything strconv.ParseBool
   509  //     recognizes
   510  //   - Conversion from string to any size integer for everything
   511  //     strconv.ParseInt and strconv.ParseUint recognizes
   512  //   - Conversion from string to any size float for everything
   513  //     strconv.ParseFloat recognizes
   514  //   - Conversion from string to arrays, slices, structs, and maps by treating
   515  //     the string as marshalled JSON and calling json.Unmarshal into the
   516  //     destination field
   517  func NewCmd(method string, args ...interface{}) (interface{}, error) {
   518  	// Look up details about the provided method.  Any methods that aren't
   519  	// registered are an error.
   520  	registerLock.RLock()
   521  	rtp, ok := methodToConcreteType[method]
   522  	info := methodToInfo[method]
   523  	registerLock.RUnlock()
   524  	if !ok {
   525  		str := fmt.Sprintf("%q is not registered", method)
   526  		return nil, makeError(ErrUnregisteredMethod, str)
   527  	}
   528  
   529  	// Ensure the number of parameters are correct.
   530  	numParams := len(args)
   531  	if err := checkNumParams(numParams, &info); err != nil {
   532  		return nil, err
   533  	}
   534  
   535  	// Create the appropriate command type for the method.  Since all types
   536  	// are enforced to be a pointer to a struct at registration time, it's
   537  	// safe to indirect to the struct now.
   538  	rvp := reflect.New(rtp.Elem())
   539  	rv := rvp.Elem()
   540  	rt := rtp.Elem()
   541  
   542  	// Loop through each of the struct fields and assign the associated
   543  	// parameter into them after checking its type validity.
   544  	for i := 0; i < numParams; i++ {
   545  		// Attempt to assign each of the arguments to the according
   546  		// struct field.
   547  		rvf := rv.Field(i)
   548  		fieldName := strings.ToLower(rt.Field(i).Name)
   549  		err := assignField(i+1, fieldName, rvf, reflect.ValueOf(args[i]))
   550  		if err != nil {
   551  			return nil, err
   552  		}
   553  	}
   554  
   555  	return rvp.Interface(), nil
   556  }