github.com/ethersphere/bee/v2@v2.2.0/pkg/api/util.go (about)

     1  // Copyright 2020 The Swarm Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package api
     6  
     7  import (
     8  	"crypto/ecdsa"
     9  	"encoding/hex"
    10  	"errors"
    11  	"fmt"
    12  	"math/big"
    13  	"reflect"
    14  	"strconv"
    15  	"strings"
    16  
    17  	"github.com/ethereum/go-ethereum/common"
    18  	"github.com/ethersphere/bee/v2/pkg/pss"
    19  	"github.com/ethersphere/bee/v2/pkg/swarm"
    20  	"github.com/hashicorp/go-multierror"
    21  	"github.com/multiformats/go-multiaddr"
    22  )
    23  
    24  // mapStructureTagName represents the name of the tag used to map values.
    25  const mapStructureTagName = "map"
    26  
    27  // errHexLength reports an attempt to decode an odd-length input.
    28  // It's a drop-in replacement for hex.ErrLength.
    29  var errHexLength = errors.New("odd length hex string")
    30  
    31  // hexInvalidByteError values describe errors resulting
    32  // from an invalid byte in a hex string.
    33  // It's a drop-in replacement for hex.InvalidByteError.
    34  type hexInvalidByteError byte
    35  
    36  // Error implements the error interface.
    37  func (e hexInvalidByteError) Error() string {
    38  	return fmt.Sprintf("invalid hex byte: %#U", rune(e))
    39  }
    40  
    41  // parseError is returned when an entry cannot be parsed.
    42  type parseError struct {
    43  	Entry string
    44  	Value string
    45  	Cause error
    46  }
    47  
    48  // Error implements the error interface.
    49  func (e *parseError) Error() string {
    50  	return fmt.Sprintf("`%s=%v`: %v", e.Entry, e.Value, e.Cause)
    51  }
    52  
    53  // Unwrap implements the interface required by errors.Unwrap function.
    54  func (e *parseError) Unwrap() error {
    55  	return e.Cause
    56  }
    57  
    58  // Equal returns true if the given error
    59  // type and fields are equal to this error.
    60  // It is used to compare errors in tests.
    61  func (e *parseError) Equal(err error) bool {
    62  	var p *parseError
    63  	if !errors.As(err, &p) {
    64  		return false
    65  	}
    66  	return e.Entry == p.Entry && e.Value == p.Value && errors.Is(e.Cause, p.Cause)
    67  }
    68  
    69  // newParseError returns a new mapStructure error.
    70  // If the cause is strconv.NumError, its
    71  // underlying error is unwrapped and
    72  // used as a cause. The hex.InvalidByteError
    73  // and hex.ErrLength errors are replaced in
    74  // order to hide unnecessary information.
    75  func newParseError(entry, value string, cause error) error {
    76  	var numErr *strconv.NumError
    77  	if errors.As(cause, &numErr) {
    78  		cause = numErr.Err
    79  	}
    80  
    81  	var hexErr hex.InvalidByteError
    82  	if errors.As(cause, &hexErr) {
    83  		cause = hexInvalidByteError(hexErr)
    84  	}
    85  
    86  	if errors.Is(cause, hex.ErrLength) {
    87  		cause = errHexLength
    88  	}
    89  
    90  	return &parseError{
    91  		Entry: entry,
    92  		Value: value,
    93  		Cause: cause,
    94  	}
    95  }
    96  
    97  // flattenErrorsFormat flattens the errors in
    98  // the multierror.Error as a one-line string.
    99  var flattenErrorsFormat = func(es []error) string {
   100  	messages := make([]string, len(es))
   101  	for i, err := range es {
   102  		messages[i] = err.Error()
   103  	}
   104  	return fmt.Sprintf(
   105  		"%d error(s) occurred: %v",
   106  		len(es),
   107  		strings.Join(messages, "; "),
   108  	)
   109  }
   110  
   111  // mapStructure maps the input to the output values.
   112  // The input is one of the following:
   113  //   - map[string]string
   114  //   - map[string][]string
   115  //
   116  // In the second case, the first value of
   117  // the string array is taken as a value.
   118  //
   119  // The output struct fields can contain the
   120  // `map` tag that refers to the map input key.
   121  // For example:
   122  //
   123  //	type Output struct {
   124  //		BoolVal bool `map:"boolVal,omitempty"`
   125  //	}
   126  //
   127  // If the `map` tag is not present, the field name is used.
   128  // If the field name or the `map` tag is not present in
   129  // the input map, the field is skipped. If the map value
   130  // is empty and the` omitempty` tag is present then the
   131  // field is skipped.
   132  //
   133  // In case of parsing error, a new parseError is returned to the caller.
   134  // The caller can use the Unwrap method to get the original error.
   135  func mapStructure(input, output interface{}, hooks map[string]func(v string) (string, error)) (err error) {
   136  	if input == nil || output == nil {
   137  		return nil
   138  	}
   139  
   140  	defer func() {
   141  		if e := recover(); e != nil {
   142  			err = fmt.Errorf("%v", e)
   143  		}
   144  	}()
   145  
   146  	var (
   147  		inputVal  reflect.Value
   148  		outputVal reflect.Value
   149  	)
   150  
   151  	// Do input sanity checks.
   152  	inputVal = reflect.ValueOf(input)
   153  	if inputVal.Kind() == reflect.Ptr {
   154  		inputVal = inputVal.Elem()
   155  	}
   156  	switch {
   157  	case inputVal.Kind() != reflect.Map:
   158  		return errors.New("input is not a map")
   159  	case !inputVal.IsValid():
   160  		return nil
   161  	}
   162  
   163  	// Do output sanity checks.
   164  	outputVal = reflect.ValueOf(output)
   165  	switch {
   166  	case outputVal.Kind() != reflect.Ptr:
   167  		return errors.New("output is not a pointer")
   168  	case outputVal.Elem().Kind() != reflect.Struct:
   169  		return errors.New("output is not a struct")
   170  	}
   171  	outputVal = outputVal.Elem()
   172  
   173  	// set is the workhorse here, parsing and setting the values.
   174  	var set func(string, reflect.Value) error
   175  	set = func(value string, field reflect.Value) error {
   176  		switch fieldKind := field.Kind(); fieldKind {
   177  		case reflect.Ptr:
   178  			if field.IsNil() {
   179  				field.Set(reflect.New(field.Type().Elem()))
   180  			}
   181  			err := set(value, field.Elem())
   182  			if err != nil {
   183  				field.Set(reflect.Zero(field.Type())) // Clear the field on error.
   184  			}
   185  			return err
   186  		case reflect.Bool:
   187  			val, err := strconv.ParseBool(value)
   188  			if err != nil {
   189  				return err
   190  			}
   191  			field.SetBool(val)
   192  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   193  			val, err := strconv.ParseUint(value, 10, numberSize(fieldKind))
   194  			if err != nil {
   195  				return err
   196  			}
   197  			field.SetUint(val)
   198  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   199  			val, err := strconv.ParseInt(value, 10, numberSize(fieldKind))
   200  			if err != nil {
   201  				return err
   202  			}
   203  			field.SetInt(val)
   204  		case reflect.Float32, reflect.Float64:
   205  			val, err := strconv.ParseFloat(value, numberSize(fieldKind))
   206  			if err != nil {
   207  				return err
   208  			}
   209  			field.SetFloat(val)
   210  		case reflect.String:
   211  			field.SetString(value)
   212  		case reflect.Slice:
   213  			if value == "" {
   214  				return nil // Nil slice.
   215  			}
   216  			val, err := hex.DecodeString(value)
   217  			if err != nil {
   218  				return err
   219  			}
   220  			field.SetBytes(val)
   221  		case reflect.Array:
   222  			switch field.Interface().(type) {
   223  			case common.Hash:
   224  				val := common.HexToHash(value)
   225  				field.Set(reflect.ValueOf(val))
   226  			case common.Address:
   227  				val := common.HexToAddress(value)
   228  				field.Set(reflect.ValueOf(val))
   229  			}
   230  		case reflect.Struct:
   231  			switch field.Interface().(type) {
   232  			case big.Int:
   233  				val, ok := new(big.Int).SetString(value, 10)
   234  				if !ok {
   235  					return errors.New("invalid value")
   236  				}
   237  				field.Set(reflect.ValueOf(*val))
   238  			case swarm.Address:
   239  				val, err := swarm.ParseHexAddress(value)
   240  				if err != nil {
   241  					return err
   242  				}
   243  				field.Set(reflect.ValueOf(val))
   244  			case common.Hash:
   245  				val := common.HexToHash(value)
   246  				field.Set(reflect.ValueOf(val))
   247  			case ecdsa.PublicKey:
   248  				val, err := pss.ParseRecipient(value)
   249  				if err != nil {
   250  					return err
   251  				}
   252  				field.Set(reflect.ValueOf(*val))
   253  			}
   254  		case reflect.Interface:
   255  			switch field.Type() {
   256  			case reflect.TypeOf((*multiaddr.Multiaddr)(nil)).Elem():
   257  				val, err := multiaddr.NewMultiaddr(value)
   258  				if err != nil {
   259  					return err
   260  				}
   261  				field.Set(reflect.ValueOf(val))
   262  			}
   263  		default:
   264  			return fmt.Errorf("unsupported type %T", field.Interface())
   265  		}
   266  		return nil
   267  	}
   268  
   269  	// parseFieldTags parses the given field tags into name, hook, and omitempty.
   270  	parseFieldTags := func(field reflect.StructField) (name string, hook func(v string) (string, error), omitempty bool) {
   271  		hook = func(v string) (string, error) { return v, nil }
   272  
   273  		val, ok := field.Tag.Lookup(mapStructureTagName)
   274  		if !ok {
   275  			return field.Name, hook, false
   276  		}
   277  
   278  		tags := strings.SplitN(val, ",", 3)
   279  		name = tags[0]
   280  		for _, tag := range tags[1:] {
   281  			switch tag {
   282  			case "omitempty":
   283  				omitempty = true
   284  			default:
   285  				if len(hooks) == 0 {
   286  					panic(errors.New("zero registered hooks"))
   287  				}
   288  				hook, ok = hooks[tag]
   289  				if !ok {
   290  					panic(fmt.Errorf("unknown hook %q for field: %s", tag, field.Name))
   291  				}
   292  			}
   293  		}
   294  
   295  		return name, hook, omitempty
   296  	}
   297  
   298  	// Map input into output.
   299  	pErrs := &multierror.Error{ErrorFormat: flattenErrorsFormat}
   300  	for i := 0; i < outputVal.NumField(); i++ {
   301  		name, hook, omitempty := parseFieldTags(outputVal.Type().Field(i))
   302  
   303  		mKey := reflect.ValueOf(name)
   304  		mVal := inputVal.MapIndex(mKey)
   305  		if !mVal.IsValid() {
   306  			continue
   307  		}
   308  
   309  		value := flattenValue(mVal).String()
   310  		if omitempty && value == "" {
   311  			continue
   312  		}
   313  
   314  		trans, err := hook(value)
   315  		if err != nil {
   316  			pErrs = multierror.Append(pErrs, newParseError(name, value, err))
   317  			continue
   318  		}
   319  
   320  		if err := set(trans, outputVal.Field(i)); err != nil {
   321  			pErrs = multierror.Append(pErrs, newParseError(name, value, err))
   322  		}
   323  	}
   324  	return pErrs.ErrorOrNil()
   325  }
   326  
   327  // numberSize returns the size of the number in bits.
   328  func numberSize(k reflect.Kind) int {
   329  	switch k {
   330  	case reflect.Uint8, reflect.Int8:
   331  		return 8
   332  	case reflect.Uint16, reflect.Int16:
   333  		return 16
   334  	case reflect.Uint32, reflect.Int32, reflect.Float32:
   335  		return 32
   336  	case reflect.Uint64, reflect.Int64, reflect.Float64:
   337  		return 64
   338  	}
   339  	return 0
   340  }
   341  
   342  // flattenValue returns the first element of the value if it is a slice.
   343  func flattenValue(val reflect.Value) reflect.Value {
   344  	switch val.Kind() {
   345  	case reflect.Slice:
   346  		return val.Index(0)
   347  	}
   348  	return val
   349  }