github.com/ethersphere/bee/v2@v2.2.0/pkg/api/util.go (about) 1 // Copyright 2020 The Swarm Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package api 6 7 import ( 8 "crypto/ecdsa" 9 "encoding/hex" 10 "errors" 11 "fmt" 12 "math/big" 13 "reflect" 14 "strconv" 15 "strings" 16 17 "github.com/ethereum/go-ethereum/common" 18 "github.com/ethersphere/bee/v2/pkg/pss" 19 "github.com/ethersphere/bee/v2/pkg/swarm" 20 "github.com/hashicorp/go-multierror" 21 "github.com/multiformats/go-multiaddr" 22 ) 23 24 // mapStructureTagName represents the name of the tag used to map values. 25 const mapStructureTagName = "map" 26 27 // errHexLength reports an attempt to decode an odd-length input. 28 // It's a drop-in replacement for hex.ErrLength. 29 var errHexLength = errors.New("odd length hex string") 30 31 // hexInvalidByteError values describe errors resulting 32 // from an invalid byte in a hex string. 33 // It's a drop-in replacement for hex.InvalidByteError. 34 type hexInvalidByteError byte 35 36 // Error implements the error interface. 37 func (e hexInvalidByteError) Error() string { 38 return fmt.Sprintf("invalid hex byte: %#U", rune(e)) 39 } 40 41 // parseError is returned when an entry cannot be parsed. 42 type parseError struct { 43 Entry string 44 Value string 45 Cause error 46 } 47 48 // Error implements the error interface. 49 func (e *parseError) Error() string { 50 return fmt.Sprintf("`%s=%v`: %v", e.Entry, e.Value, e.Cause) 51 } 52 53 // Unwrap implements the interface required by errors.Unwrap function. 54 func (e *parseError) Unwrap() error { 55 return e.Cause 56 } 57 58 // Equal returns true if the given error 59 // type and fields are equal to this error. 60 // It is used to compare errors in tests. 
61 func (e *parseError) Equal(err error) bool { 62 var p *parseError 63 if !errors.As(err, &p) { 64 return false 65 } 66 return e.Entry == p.Entry && e.Value == p.Value && errors.Is(e.Cause, p.Cause) 67 } 68 69 // newParseError returns a new mapStructure error. 70 // If the cause is strconv.NumError, its 71 // underlying error is unwrapped and 72 // used as a cause. The hex.InvalidByteError 73 // and hex.ErrLength errors are replaced in 74 // order to hide unnecessary information. 75 func newParseError(entry, value string, cause error) error { 76 var numErr *strconv.NumError 77 if errors.As(cause, &numErr) { 78 cause = numErr.Err 79 } 80 81 var hexErr hex.InvalidByteError 82 if errors.As(cause, &hexErr) { 83 cause = hexInvalidByteError(hexErr) 84 } 85 86 if errors.Is(cause, hex.ErrLength) { 87 cause = errHexLength 88 } 89 90 return &parseError{ 91 Entry: entry, 92 Value: value, 93 Cause: cause, 94 } 95 } 96 97 // flattenErrorsFormat flattens the errors in 98 // the multierror.Error as a one-line string. 99 var flattenErrorsFormat = func(es []error) string { 100 messages := make([]string, len(es)) 101 for i, err := range es { 102 messages[i] = err.Error() 103 } 104 return fmt.Sprintf( 105 "%d error(s) occurred: %v", 106 len(es), 107 strings.Join(messages, "; "), 108 ) 109 } 110 111 // mapStructure maps the input to the output values. 112 // The input is one of the following: 113 // - map[string]string 114 // - map[string][]string 115 // 116 // In the second case, the first value of 117 // the string array is taken as a value. 118 // 119 // The output struct fields can contain the 120 // `map` tag that refers to the map input key. 121 // For example: 122 // 123 // type Output struct { 124 // BoolVal bool `map:"boolVal,omitempty"` 125 // } 126 // 127 // If the `map` tag is not present, the field name is used. 128 // If the field name or the `map` tag is not present in 129 // the input map, the field is skipped. 
If the map value 130 // is empty and the` omitempty` tag is present then the 131 // field is skipped. 132 // 133 // In case of parsing error, a new parseError is returned to the caller. 134 // The caller can use the Unwrap method to get the original error. 135 func mapStructure(input, output interface{}, hooks map[string]func(v string) (string, error)) (err error) { 136 if input == nil || output == nil { 137 return nil 138 } 139 140 defer func() { 141 if e := recover(); e != nil { 142 err = fmt.Errorf("%v", e) 143 } 144 }() 145 146 var ( 147 inputVal reflect.Value 148 outputVal reflect.Value 149 ) 150 151 // Do input sanity checks. 152 inputVal = reflect.ValueOf(input) 153 if inputVal.Kind() == reflect.Ptr { 154 inputVal = inputVal.Elem() 155 } 156 switch { 157 case inputVal.Kind() != reflect.Map: 158 return errors.New("input is not a map") 159 case !inputVal.IsValid(): 160 return nil 161 } 162 163 // Do output sanity checks. 164 outputVal = reflect.ValueOf(output) 165 switch { 166 case outputVal.Kind() != reflect.Ptr: 167 return errors.New("output is not a pointer") 168 case outputVal.Elem().Kind() != reflect.Struct: 169 return errors.New("output is not a struct") 170 } 171 outputVal = outputVal.Elem() 172 173 // set is the workhorse here, parsing and setting the values. 174 var set func(string, reflect.Value) error 175 set = func(value string, field reflect.Value) error { 176 switch fieldKind := field.Kind(); fieldKind { 177 case reflect.Ptr: 178 if field.IsNil() { 179 field.Set(reflect.New(field.Type().Elem())) 180 } 181 err := set(value, field.Elem()) 182 if err != nil { 183 field.Set(reflect.Zero(field.Type())) // Clear the field on error. 
184 } 185 return err 186 case reflect.Bool: 187 val, err := strconv.ParseBool(value) 188 if err != nil { 189 return err 190 } 191 field.SetBool(val) 192 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 193 val, err := strconv.ParseUint(value, 10, numberSize(fieldKind)) 194 if err != nil { 195 return err 196 } 197 field.SetUint(val) 198 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 199 val, err := strconv.ParseInt(value, 10, numberSize(fieldKind)) 200 if err != nil { 201 return err 202 } 203 field.SetInt(val) 204 case reflect.Float32, reflect.Float64: 205 val, err := strconv.ParseFloat(value, numberSize(fieldKind)) 206 if err != nil { 207 return err 208 } 209 field.SetFloat(val) 210 case reflect.String: 211 field.SetString(value) 212 case reflect.Slice: 213 if value == "" { 214 return nil // Nil slice. 215 } 216 val, err := hex.DecodeString(value) 217 if err != nil { 218 return err 219 } 220 field.SetBytes(val) 221 case reflect.Array: 222 switch field.Interface().(type) { 223 case common.Hash: 224 val := common.HexToHash(value) 225 field.Set(reflect.ValueOf(val)) 226 case common.Address: 227 val := common.HexToAddress(value) 228 field.Set(reflect.ValueOf(val)) 229 } 230 case reflect.Struct: 231 switch field.Interface().(type) { 232 case big.Int: 233 val, ok := new(big.Int).SetString(value, 10) 234 if !ok { 235 return errors.New("invalid value") 236 } 237 field.Set(reflect.ValueOf(*val)) 238 case swarm.Address: 239 val, err := swarm.ParseHexAddress(value) 240 if err != nil { 241 return err 242 } 243 field.Set(reflect.ValueOf(val)) 244 case common.Hash: 245 val := common.HexToHash(value) 246 field.Set(reflect.ValueOf(val)) 247 case ecdsa.PublicKey: 248 val, err := pss.ParseRecipient(value) 249 if err != nil { 250 return err 251 } 252 field.Set(reflect.ValueOf(*val)) 253 } 254 case reflect.Interface: 255 switch field.Type() { 256 case reflect.TypeOf((*multiaddr.Multiaddr)(nil)).Elem(): 257 val, err := 
multiaddr.NewMultiaddr(value) 258 if err != nil { 259 return err 260 } 261 field.Set(reflect.ValueOf(val)) 262 } 263 default: 264 return fmt.Errorf("unsupported type %T", field.Interface()) 265 } 266 return nil 267 } 268 269 // parseFieldTags parses the given field tags into name, hook, and omitempty. 270 parseFieldTags := func(field reflect.StructField) (name string, hook func(v string) (string, error), omitempty bool) { 271 hook = func(v string) (string, error) { return v, nil } 272 273 val, ok := field.Tag.Lookup(mapStructureTagName) 274 if !ok { 275 return field.Name, hook, false 276 } 277 278 tags := strings.SplitN(val, ",", 3) 279 name = tags[0] 280 for _, tag := range tags[1:] { 281 switch tag { 282 case "omitempty": 283 omitempty = true 284 default: 285 if len(hooks) == 0 { 286 panic(errors.New("zero registered hooks")) 287 } 288 hook, ok = hooks[tag] 289 if !ok { 290 panic(fmt.Errorf("unknown hook %q for field: %s", tag, field.Name)) 291 } 292 } 293 } 294 295 return name, hook, omitempty 296 } 297 298 // Map input into output. 299 pErrs := &multierror.Error{ErrorFormat: flattenErrorsFormat} 300 for i := 0; i < outputVal.NumField(); i++ { 301 name, hook, omitempty := parseFieldTags(outputVal.Type().Field(i)) 302 303 mKey := reflect.ValueOf(name) 304 mVal := inputVal.MapIndex(mKey) 305 if !mVal.IsValid() { 306 continue 307 } 308 309 value := flattenValue(mVal).String() 310 if omitempty && value == "" { 311 continue 312 } 313 314 trans, err := hook(value) 315 if err != nil { 316 pErrs = multierror.Append(pErrs, newParseError(name, value, err)) 317 continue 318 } 319 320 if err := set(trans, outputVal.Field(i)); err != nil { 321 pErrs = multierror.Append(pErrs, newParseError(name, value, err)) 322 } 323 } 324 return pErrs.ErrorOrNil() 325 } 326 327 // numberSize returns the size of the number in bits. 
// numberSize returns the size of the number in bits.
// It returns 0 for reflect.Int, reflect.Uint and any kind not listed;
// strconv.ParseInt and strconv.ParseUint treat a bit size of 0 as the
// size of int and uint respectively.
func numberSize(k reflect.Kind) int {
	switch k {
	case reflect.Uint8, reflect.Int8:
		return 8
	case reflect.Uint16, reflect.Int16:
		return 16
	case reflect.Uint32, reflect.Int32, reflect.Float32:
		return 32
	case reflect.Uint64, reflect.Int64, reflect.Float64:
		return 64
	}
	return 0
}

// flattenValue returns the first element of val if it is a slice, and
// val itself otherwise. An empty (or nil) slice yields the zero value of
// its element type (e.g. "" for []string) rather than panicking on an
// out-of-range index.
func flattenValue(val reflect.Value) reflect.Value {
	if val.Kind() != reflect.Slice {
		return val
	}
	if val.Len() == 0 {
		return reflect.Zero(val.Type().Elem())
	}
	return val.Index(0)
}