github.com/lbryio/lbcd@v0.22.119/btcjson/cmdparse.go (about)

     1  // Copyright (c) 2014 The btcsuite developers
     2  // Use of this source code is governed by an ISC
     3  // license that can be found in the LICENSE file.
     4  
     5  package btcjson
     6  
     7  import (
     8  	"encoding/json"
     9  	"fmt"
    10  	"reflect"
    11  	"strconv"
    12  	"strings"
    13  )
    14  
    15  // makeParams creates a slice of interface values for the given struct.
    16  func makeParams(rt reflect.Type, rv reflect.Value) []interface{} {
    17  	numFields := rt.NumField()
    18  	params := make([]interface{}, 0, numFields)
    19  	lastParam := -1
    20  	for i := 0; i < numFields; i++ {
    21  		rtf := rt.Field(i)
    22  		rvf := rv.Field(i)
    23  		params = append(params, rvf.Interface())
    24  		if rtf.Type.Kind() == reflect.Ptr {
    25  			if rvf.IsNil() {
    26  				// Omit optional null params unless a non-null param follows
    27  				continue
    28  			}
    29  		}
    30  		lastParam = i
    31  	}
    32  	return params[:lastParam+1]
    33  }
    34  
    35  // MarshalCmd marshals the passed command to a JSON-RPC request byte slice that
    36  // is suitable for transmission to an RPC server.  The provided command type
    37  // must be a registered type.  All commands provided by this package are
    38  // registered by default.
    39  func MarshalCmd(rpcVersion RPCVersion, id interface{}, cmd interface{}) ([]byte, error) {
    40  	// Look up the cmd type and error out if not registered.
    41  	rt := reflect.TypeOf(cmd)
    42  	registerLock.RLock()
    43  	method, ok := concreteTypeToMethod[rt]
    44  	registerLock.RUnlock()
    45  	if !ok {
    46  		str := fmt.Sprintf("%q is not registered", method)
    47  		return nil, makeError(ErrUnregisteredMethod, str)
    48  	}
    49  
    50  	// The provided command must not be nil.
    51  	rv := reflect.ValueOf(cmd)
    52  	if rv.IsNil() {
    53  		str := "the specified command is nil"
    54  		return nil, makeError(ErrInvalidType, str)
    55  	}
    56  
    57  	// Create a slice of interface values in the order of the struct fields
    58  	// while respecting pointer fields as optional params and only adding
    59  	// them if they are non-nil.
    60  	params := makeParams(rt.Elem(), rv.Elem())
    61  
    62  	// Generate and marshal the final JSON-RPC request.
    63  	rawCmd, err := NewRequest(rpcVersion, id, method, params)
    64  	if err != nil {
    65  		return nil, err
    66  	}
    67  	return json.Marshal(rawCmd)
    68  }
    69  
    70  // checkNumParams ensures the supplied number of params is at least the minimum
    71  // required number for the command and less than the maximum allowed.
    72  func checkNumParams(numParams int, info *methodInfo) error {
    73  	if numParams < info.numReqParams || numParams > info.maxParams {
    74  		if info.numReqParams == info.maxParams {
    75  			str := fmt.Sprintf("wrong number of params (expected "+
    76  				"%d, received %d)", info.numReqParams,
    77  				numParams)
    78  			return makeError(ErrNumParams, str)
    79  		}
    80  
    81  		str := fmt.Sprintf("wrong number of params (expected "+
    82  			"between %d and %d, received %d)", info.numReqParams,
    83  			info.maxParams, numParams)
    84  		return makeError(ErrNumParams, str)
    85  	}
    86  
    87  	return nil
    88  }
    89  
    90  // populateDefaults populates default values into any remaining optional struct
    91  // fields that did not have parameters explicitly provided.  The caller should
    92  // have previously checked that the number of parameters being passed is at
    93  // least the required number of parameters to avoid unnecessary work in this
    94  // function, but since required fields never have default values, it will work
    95  // properly even without the check.
    96  func populateDefaults(numParams int, info *methodInfo, rv reflect.Value) {
    97  	// When there are no more parameters left in the supplied parameters,
    98  	// any remaining struct fields must be optional.  Thus, populate them
    99  	// with their associated default value as needed.
   100  	for i := numParams; i < info.maxParams; i++ {
   101  		rvf := rv.Field(i)
   102  		if defaultVal, ok := info.defaults[i]; ok {
   103  			rvf.Set(defaultVal)
   104  		}
   105  	}
   106  }
   107  
   108  // UnmarshalCmd unmarshals a JSON-RPC request into a suitable concrete command
   109  // so long as the method type contained within the marshalled request is
   110  // registered.
   111  func UnmarshalCmd(r *Request) (interface{}, error) {
   112  	registerLock.RLock()
   113  	rtp, ok := methodToConcreteType[r.Method]
   114  	info := methodToInfo[r.Method]
   115  	registerLock.RUnlock()
   116  	if !ok {
   117  		str := fmt.Sprintf("%q is not registered", r.Method)
   118  		return nil, makeError(ErrUnregisteredMethod, str)
   119  	}
   120  	rt := rtp.Elem()
   121  	rvp := reflect.New(rt)
   122  	rv := rvp.Elem()
   123  
   124  	// Ensure the number of parameters are correct.
   125  	numParams := len(r.Params)
   126  	if err := checkNumParams(numParams, &info); err != nil {
   127  		return nil, err
   128  	}
   129  
   130  	// Loop through each of the struct fields and unmarshal the associated
   131  	// parameter into them.
   132  	for i := 0; i < numParams; i++ {
   133  		rvf := rv.Field(i)
   134  		// Unmarshal the parameter into the struct field.
   135  		concreteVal := rvf.Addr().Interface()
   136  		if err := json.Unmarshal(r.Params[i], &concreteVal); err != nil {
   137  			// Parse Integer into Bool for compatibility with lbrycrd.
   138  			if rvf.Kind() == reflect.Ptr &&
   139  				rvf.Elem().Type().Kind() == reflect.Bool {
   140  				boolInt, errBoolInt := strconv.Atoi(string(r.Params[i]))
   141  				if errBoolInt == nil {
   142  					rvf.Elem().SetBool(boolInt != 0)
   143  					continue
   144  				}
   145  			}
   146  
   147  			// The most common error is the wrong type, so
   148  			// explicitly detect that error and make it nicer.
   149  			fieldName := strings.ToLower(rt.Field(i).Name)
   150  			if jerr, ok := err.(*json.UnmarshalTypeError); ok {
   151  				str := fmt.Sprintf("parameter #%d '%s' must "+
   152  					"be type %v (got %v)", i+1, fieldName,
   153  					jerr.Type, jerr.Value)
   154  				return nil, makeError(ErrInvalidType, str)
   155  			}
   156  
   157  			// Fallback to showing the underlying error.
   158  			str := fmt.Sprintf("parameter #%d '%s' failed to "+
   159  				"unmarshal: %v", i+1, fieldName, err)
   160  			return nil, makeError(ErrInvalidType, str)
   161  		}
   162  	}
   163  
   164  	// When there are less supplied parameters than the total number of
   165  	// params, any remaining struct fields must be optional.  Thus, populate
   166  	// them with their associated default value as needed.
   167  	if numParams < info.maxParams {
   168  		populateDefaults(numParams, &info, rv)
   169  	}
   170  
   171  	return rvp.Interface(), nil
   172  }
   173  
   174  // isNumeric returns whether the passed reflect kind is a signed or unsigned
   175  // integer of any magnitude or a float of any magnitude.
   176  func isNumeric(kind reflect.Kind) bool {
   177  	switch kind {
   178  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   179  		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
   180  		reflect.Uint64, reflect.Float32, reflect.Float64:
   181  
   182  		return true
   183  	}
   184  
   185  	return false
   186  }
   187  
   188  // typesMaybeCompatible returns whether the source type can possibly be
   189  // assigned to the destination type.  This is intended as a relatively quick
   190  // check to weed out obviously invalid conversions.
   191  func typesMaybeCompatible(dest reflect.Type, src reflect.Type) bool {
   192  	// The same types are obviously compatible.
   193  	if dest == src {
   194  		return true
   195  	}
   196  
   197  	// When both types are numeric, they are potentially compatible.
   198  	srcKind := src.Kind()
   199  	destKind := dest.Kind()
   200  	if isNumeric(destKind) && isNumeric(srcKind) {
   201  		return true
   202  	}
   203  
   204  	if srcKind == reflect.String {
   205  		// Strings can potentially be converted to numeric types.
   206  		if isNumeric(destKind) {
   207  			return true
   208  		}
   209  
   210  		switch destKind {
   211  		// Strings can potentially be converted to bools by
   212  		// strconv.ParseBool.
   213  		case reflect.Bool:
   214  			return true
   215  
   216  		// Strings can be converted to any other type which has as
   217  		// underlying type of string.
   218  		case reflect.String:
   219  			return true
   220  
   221  		// Strings can potentially be converted to arrays, slice,
   222  		// structs, and maps via json.Unmarshal.
   223  		case reflect.Array, reflect.Slice, reflect.Struct, reflect.Map:
   224  			return true
   225  		}
   226  	}
   227  
   228  	return false
   229  }
   230  
   231  // baseType returns the type of the argument after indirecting through all
   232  // pointers along with how many indirections were necessary.
   233  func baseType(arg reflect.Type) (reflect.Type, int) {
   234  	var numIndirects int
   235  	for arg.Kind() == reflect.Ptr {
   236  		arg = arg.Elem()
   237  		numIndirects++
   238  	}
   239  	return arg, numIndirects
   240  }
   241  
   242  // assignField is the main workhorse for the NewCmd function which handles
   243  // assigning the provided source value to the destination field.  It supports
   244  // direct type assignments, indirection, conversion of numeric types, and
   245  // unmarshaling of strings into arrays, slices, structs, and maps via
   246  // json.Unmarshal.
   247  func assignField(paramNum int, fieldName string, dest reflect.Value, src reflect.Value) error {
   248  	// Just error now when the types have no chance of being compatible.
   249  	destBaseType, destIndirects := baseType(dest.Type())
   250  	srcBaseType, srcIndirects := baseType(src.Type())
   251  	if !typesMaybeCompatible(destBaseType, srcBaseType) {
   252  		str := fmt.Sprintf("parameter #%d '%s' must be type %v (got "+
   253  			"%v)", paramNum, fieldName, destBaseType, srcBaseType)
   254  		return makeError(ErrInvalidType, str)
   255  	}
   256  
   257  	// Check if it's possible to simply set the dest to the provided source.
   258  	// This is the case when the base types are the same or they are both
   259  	// pointers that can be indirected to be the same without needing to
   260  	// create pointers for the destination field.
   261  	if destBaseType == srcBaseType && srcIndirects >= destIndirects {
   262  		for i := 0; i < srcIndirects-destIndirects; i++ {
   263  			src = src.Elem()
   264  		}
   265  		dest.Set(src)
   266  		return nil
   267  	}
   268  
   269  	// Optional variables can be set null using "null" string
   270  	if destIndirects > 0 && src.String() == "null" {
   271  		return nil
   272  	}
   273  
   274  	// When the destination has more indirects than the source, the extra
   275  	// pointers have to be created.  Only create enough pointers to reach
   276  	// the same level of indirection as the source so the dest can simply be
   277  	// set to the provided source when the types are the same.
   278  	destIndirectsRemaining := destIndirects
   279  	if destIndirects > srcIndirects {
   280  		indirectDiff := destIndirects - srcIndirects
   281  		for i := 0; i < indirectDiff; i++ {
   282  			dest.Set(reflect.New(dest.Type().Elem()))
   283  			dest = dest.Elem()
   284  			destIndirectsRemaining--
   285  		}
   286  	}
   287  
   288  	if destBaseType == srcBaseType {
   289  		dest.Set(src)
   290  		return nil
   291  	}
   292  
   293  	// Make any remaining pointers needed to get to the base dest type since
   294  	// the above direct assign was not possible and conversions are done
   295  	// against the base types.
   296  	for i := 0; i < destIndirectsRemaining; i++ {
   297  		dest.Set(reflect.New(dest.Type().Elem()))
   298  		dest = dest.Elem()
   299  	}
   300  
   301  	// Indirect through to the base source value.
   302  	for src.Kind() == reflect.Ptr {
   303  		src = src.Elem()
   304  	}
   305  
   306  	// Perform supported type conversions.
   307  	switch src.Kind() {
   308  	// Source value is a signed integer of various magnitude.
   309  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
   310  		reflect.Int64:
   311  
   312  		switch dest.Kind() {
   313  		// Destination is a signed integer of various magnitude.
   314  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
   315  			reflect.Int64:
   316  
   317  			srcInt := src.Int()
   318  			if dest.OverflowInt(srcInt) {
   319  				str := fmt.Sprintf("parameter #%d '%s' "+
   320  					"overflows destination type %v",
   321  					paramNum, fieldName, destBaseType)
   322  				return makeError(ErrInvalidType, str)
   323  			}
   324  
   325  			dest.SetInt(srcInt)
   326  
   327  		// Destination is an unsigned integer of various magnitude.
   328  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
   329  			reflect.Uint64:
   330  
   331  			srcInt := src.Int()
   332  			if srcInt < 0 || dest.OverflowUint(uint64(srcInt)) {
   333  				str := fmt.Sprintf("parameter #%d '%s' "+
   334  					"overflows destination type %v",
   335  					paramNum, fieldName, destBaseType)
   336  				return makeError(ErrInvalidType, str)
   337  			}
   338  			dest.SetUint(uint64(srcInt))
   339  
   340  		default:
   341  			str := fmt.Sprintf("parameter #%d '%s' must be type "+
   342  				"%v (got %v)", paramNum, fieldName, destBaseType,
   343  				srcBaseType)
   344  			return makeError(ErrInvalidType, str)
   345  		}
   346  
   347  	// Source value is an unsigned integer of various magnitude.
   348  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
   349  		reflect.Uint64:
   350  
   351  		switch dest.Kind() {
   352  		// Destination is a signed integer of various magnitude.
   353  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
   354  			reflect.Int64:
   355  
   356  			srcUint := src.Uint()
   357  			if srcUint > uint64(1<<63)-1 {
   358  				str := fmt.Sprintf("parameter #%d '%s' "+
   359  					"overflows destination type %v",
   360  					paramNum, fieldName, destBaseType)
   361  				return makeError(ErrInvalidType, str)
   362  			}
   363  			if dest.OverflowInt(int64(srcUint)) {
   364  				str := fmt.Sprintf("parameter #%d '%s' "+
   365  					"overflows destination type %v",
   366  					paramNum, fieldName, destBaseType)
   367  				return makeError(ErrInvalidType, str)
   368  			}
   369  			dest.SetInt(int64(srcUint))
   370  
   371  		// Destination is an unsigned integer of various magnitude.
   372  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32,
   373  			reflect.Uint64:
   374  
   375  			srcUint := src.Uint()
   376  			if dest.OverflowUint(srcUint) {
   377  				str := fmt.Sprintf("parameter #%d '%s' "+
   378  					"overflows destination type %v",
   379  					paramNum, fieldName, destBaseType)
   380  				return makeError(ErrInvalidType, str)
   381  			}
   382  			dest.SetUint(srcUint)
   383  
   384  		default:
   385  			str := fmt.Sprintf("parameter #%d '%s' must be type "+
   386  				"%v (got %v)", paramNum, fieldName, destBaseType,
   387  				srcBaseType)
   388  			return makeError(ErrInvalidType, str)
   389  		}
   390  
   391  	// Source value is a float.
   392  	case reflect.Float32, reflect.Float64:
   393  		destKind := dest.Kind()
   394  		if destKind != reflect.Float32 && destKind != reflect.Float64 {
   395  			str := fmt.Sprintf("parameter #%d '%s' must be type "+
   396  				"%v (got %v)", paramNum, fieldName, destBaseType,
   397  				srcBaseType)
   398  			return makeError(ErrInvalidType, str)
   399  		}
   400  
   401  		srcFloat := src.Float()
   402  		if dest.OverflowFloat(srcFloat) {
   403  			str := fmt.Sprintf("parameter #%d '%s' overflows "+
   404  				"destination type %v", paramNum, fieldName,
   405  				destBaseType)
   406  			return makeError(ErrInvalidType, str)
   407  		}
   408  		dest.SetFloat(srcFloat)
   409  
   410  	// Source value is a string.
   411  	case reflect.String:
   412  		switch dest.Kind() {
   413  		// String -> bool
   414  		case reflect.Bool:
   415  			b, err := strconv.ParseBool(src.String())
   416  			if err != nil {
   417  				str := fmt.Sprintf("parameter #%d '%s' must "+
   418  					"parse to a %v", paramNum, fieldName,
   419  					destBaseType)
   420  				return makeError(ErrInvalidType, str)
   421  			}
   422  			dest.SetBool(b)
   423  
   424  		// String -> signed integer of varying size.
   425  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32,
   426  			reflect.Int64:
   427  
   428  			srcInt, err := strconv.ParseInt(src.String(), 0, 0)
   429  			if err != nil {
   430  				str := fmt.Sprintf("parameter #%d '%s' must "+
   431  					"parse to a %v", paramNum, fieldName,
   432  					destBaseType)
   433  				return makeError(ErrInvalidType, str)
   434  			}
   435  			if dest.OverflowInt(srcInt) {
   436  				str := fmt.Sprintf("parameter #%d '%s' "+
   437  					"overflows destination type %v",
   438  					paramNum, fieldName, destBaseType)
   439  				return makeError(ErrInvalidType, str)
   440  			}
   441  			dest.SetInt(srcInt)
   442  
   443  		// String -> unsigned integer of varying size.
   444  		case reflect.Uint, reflect.Uint8, reflect.Uint16,
   445  			reflect.Uint32, reflect.Uint64:
   446  
   447  			srcUint, err := strconv.ParseUint(src.String(), 0, 0)
   448  			if err != nil {
   449  				str := fmt.Sprintf("parameter #%d '%s' must "+
   450  					"parse to a %v", paramNum, fieldName,
   451  					destBaseType)
   452  				return makeError(ErrInvalidType, str)
   453  			}
   454  			if dest.OverflowUint(srcUint) {
   455  				str := fmt.Sprintf("parameter #%d '%s' "+
   456  					"overflows destination type %v",
   457  					paramNum, fieldName, destBaseType)
   458  				return makeError(ErrInvalidType, str)
   459  			}
   460  			dest.SetUint(srcUint)
   461  
   462  		// String -> float of varying size.
   463  		case reflect.Float32, reflect.Float64:
   464  			srcFloat, err := strconv.ParseFloat(src.String(), 0)
   465  			if err != nil {
   466  				str := fmt.Sprintf("parameter #%d '%s' must "+
   467  					"parse to a %v", paramNum, fieldName,
   468  					destBaseType)
   469  				return makeError(ErrInvalidType, str)
   470  			}
   471  			if dest.OverflowFloat(srcFloat) {
   472  				str := fmt.Sprintf("parameter #%d '%s' "+
   473  					"overflows destination type %v",
   474  					paramNum, fieldName, destBaseType)
   475  				return makeError(ErrInvalidType, str)
   476  			}
   477  			dest.SetFloat(srcFloat)
   478  
   479  		// String -> string (typecast).
   480  		case reflect.String:
   481  			dest.SetString(src.String())
   482  
   483  		// String -> arrays, slices, structs, and maps via
   484  		// json.Unmarshal.
   485  		case reflect.Array, reflect.Slice, reflect.Struct, reflect.Map:
   486  			concreteVal := dest.Addr().Interface()
   487  			err := json.Unmarshal([]byte(src.String()), &concreteVal)
   488  			if err != nil {
   489  				str := fmt.Sprintf("parameter #%d '%s' must "+
   490  					"be valid JSON which unsmarshals to a %v",
   491  					paramNum, fieldName, destBaseType)
   492  				return makeError(ErrInvalidType, str)
   493  			}
   494  			dest.Set(reflect.ValueOf(concreteVal).Elem())
   495  		}
   496  	}
   497  
   498  	return nil
   499  }
   500  
   501  // NewCmd provides a generic mechanism to create a new command that can marshal
   502  // to a JSON-RPC request while respecting the requirements of the provided
   503  // method.  The method must have been registered with the package already along
   504  // with its type definition.  All methods associated with the commands exported
   505  // by this package are already registered by default.
   506  //
   507  // The arguments are most efficient when they are the exact same type as the
   508  // underlying field in the command struct associated with the the method,
   509  // however this function also will perform a variety of conversions to make it
   510  // more flexible.  This allows, for example, command line args which are strings
   511  // to be passed unaltered.  In particular, the following conversions are
   512  // supported:
   513  //
   514  //   - Conversion between any size signed or unsigned integer so long as the
   515  //     value does not overflow the destination type
   516  //   - Conversion between float32 and float64 so long as the value does not
   517  //     overflow the destination type
   518  //   - Conversion from string to boolean for everything strconv.ParseBool
   519  //     recognizes
   520  //   - Conversion from string to any size integer for everything
   521  //     strconv.ParseInt and strconv.ParseUint recognizes
   522  //   - Conversion from string to any size float for everything
   523  //     strconv.ParseFloat recognizes
   524  //   - Conversion from string to arrays, slices, structs, and maps by treating
   525  //     the string as marshalled JSON and calling json.Unmarshal into the
   526  //     destination field
   527  func NewCmd(method string, args ...interface{}) (interface{}, error) {
   528  	// Look up details about the provided method.  Any methods that aren't
   529  	// registered are an error.
   530  	registerLock.RLock()
   531  	rtp, ok := methodToConcreteType[method]
   532  	info := methodToInfo[method]
   533  	registerLock.RUnlock()
   534  	if !ok {
   535  		str := fmt.Sprintf("%q is not registered", method)
   536  		return nil, makeError(ErrUnregisteredMethod, str)
   537  	}
   538  
   539  	// Ensure the number of parameters are correct.
   540  	numParams := len(args)
   541  	if err := checkNumParams(numParams, &info); err != nil {
   542  		return nil, err
   543  	}
   544  
   545  	// Create the appropriate command type for the method.  Since all types
   546  	// are enforced to be a pointer to a struct at registration time, it's
   547  	// safe to indirect to the struct now.
   548  	rvp := reflect.New(rtp.Elem())
   549  	rv := rvp.Elem()
   550  	rt := rtp.Elem()
   551  
   552  	// Loop through each of the struct fields and assign the associated
   553  	// parameter into them after checking its type validity.
   554  	for i := 0; i < numParams; i++ {
   555  		// Attempt to assign each of the arguments to the according
   556  		// struct field.
   557  		rvf := rv.Field(i)
   558  		fieldName := strings.ToLower(rt.Field(i).Name)
   559  		err := assignField(i+1, fieldName, rvf, reflect.ValueOf(args[i]))
   560  		if err != nil {
   561  			return nil, err
   562  		}
   563  	}
   564  
   565  	return rvp.Interface(), nil
   566  }