github.com/lmittmann/w3@v0.20.0/internal/abi/copy.go (about)

     1  package abi
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"math/big"
     7  	"reflect"
     8  
     9  	"github.com/ethereum/go-ethereum/accounts/abi"
    10  	"github.com/ethereum/go-ethereum/common"
    11  )
    12  
    13  var (
    14  	errUnassignable = errors.New("unassignable")
    15  
    16  	// src non slice/array/struct types
    17  	srcBasicTypes = map[reflect.Type]struct{}{
    18  		reflect.TypeFor[bool]():           {},
    19  		reflect.TypeFor[uint]():           {},
    20  		reflect.TypeFor[uint8]():          {},
    21  		reflect.TypeFor[uint16]():         {},
    22  		reflect.TypeFor[uint32]():         {},
    23  		reflect.TypeFor[uint64]():         {},
    24  		reflect.TypeFor[int]():            {},
    25  		reflect.TypeFor[int8]():           {},
    26  		reflect.TypeFor[int16]():          {},
    27  		reflect.TypeFor[int32]():          {},
    28  		reflect.TypeFor[int64]():          {},
    29  		reflect.TypeFor[[1]byte]():        {},
    30  		reflect.TypeFor[[2]byte]():        {},
    31  		reflect.TypeFor[[3]byte]():        {},
    32  		reflect.TypeFor[[4]byte]():        {},
    33  		reflect.TypeFor[[5]byte]():        {},
    34  		reflect.TypeFor[[6]byte]():        {},
    35  		reflect.TypeFor[[7]byte]():        {},
    36  		reflect.TypeFor[[8]byte]():        {},
    37  		reflect.TypeFor[[9]byte]():        {},
    38  		reflect.TypeFor[[10]byte]():       {},
    39  		reflect.TypeFor[[11]byte]():       {},
    40  		reflect.TypeFor[[12]byte]():       {},
    41  		reflect.TypeFor[[13]byte]():       {},
    42  		reflect.TypeFor[[14]byte]():       {},
    43  		reflect.TypeFor[[15]byte]():       {},
    44  		reflect.TypeFor[[16]byte]():       {},
    45  		reflect.TypeFor[[17]byte]():       {},
    46  		reflect.TypeFor[[18]byte]():       {},
    47  		reflect.TypeFor[[19]byte]():       {},
    48  		reflect.TypeFor[[20]byte]():       {},
    49  		reflect.TypeFor[[21]byte]():       {},
    50  		reflect.TypeFor[[22]byte]():       {},
    51  		reflect.TypeFor[[23]byte]():       {},
    52  		reflect.TypeFor[[24]byte]():       {},
    53  		reflect.TypeFor[[25]byte]():       {},
    54  		reflect.TypeFor[[26]byte]():       {},
    55  		reflect.TypeFor[[27]byte]():       {},
    56  		reflect.TypeFor[[28]byte]():       {},
    57  		reflect.TypeFor[[29]byte]():       {},
    58  		reflect.TypeFor[[30]byte]():       {},
    59  		reflect.TypeFor[[31]byte]():       {},
    60  		reflect.TypeFor[[32]byte]():       {},
    61  		reflect.TypeFor[common.Address](): {},
    62  		reflect.TypeFor[common.Hash]():    {},
    63  		reflect.TypeFor[string]():         {},
    64  		reflect.TypeFor[[]byte]():         {},
    65  		reflect.TypeFor[*big.Int]():       {},
    66  		reflect.TypeFor[big.Int]():        {},
    67  	}
    68  )
    69  
    70  // Copy shallow copies the value src to dst. If src is an anonymous struct or an
    71  // array/slice of anonymous structs, the fields of the anonymous struct are
    72  // copied to dst.
    73  func Copy(dst, src any) error {
    74  	// check if dst is valid
    75  	if dst == nil {
    76  		return fmt.Errorf("abi: decode nil")
    77  	}
    78  
    79  	rDst := reflect.ValueOf(dst)
    80  	if rDst.Kind() != reflect.Pointer {
    81  		return fmt.Errorf("abi: decode non-pointer %T", dst)
    82  	}
    83  	if rDst.IsNil() {
    84  		return fmt.Errorf("abi: decode nil %T", dst)
    85  	}
    86  
    87  	err := rCopy(
    88  		dereference(rDst),
    89  		reflect.ValueOf(src),
    90  	)
    91  	if errors.Is(err, errUnassignable) {
    92  		return fmt.Errorf("abi: can't assign %T to %T", src, dst)
    93  	} else if err != nil {
    94  		return fmt.Errorf("abi: %w", err)
    95  	}
    96  
    97  	return nil
    98  }
    99  
   100  func rCopy(dst, src reflect.Value) error {
   101  	if _, ok := srcBasicTypes[src.Type()]; ok {
   102  		return set(dst, reference(src))
   103  	} else if k := src.Kind(); k == reflect.Struct {
   104  		return setStruct(dst, src)
   105  	} else if k == reflect.Slice {
   106  		return setSlice(dst, src)
   107  	} else if k == reflect.Array {
   108  		return setArray(dst, src)
   109  	} else {
   110  		return fmt.Errorf("unsupported src type %T", src.Interface())
   111  	}
   112  }
   113  
   114  func set(dst, src reflect.Value) error {
   115  	if src.Kind() != reflect.Ptr && dst.Kind() == reflect.Ptr {
   116  		src = reference(src)
   117  	} else if src.Kind() == reflect.Pointer && dst.Kind() != reflect.Pointer {
   118  		src = src.Elem()
   119  	}
   120  
   121  	st, dt := src.Type(), dst.Type()
   122  	if !st.AssignableTo(dt) {
   123  		if st.ConvertibleTo(dt) {
   124  			src = src.Convert(dt)
   125  		} else {
   126  			return errUnassignable
   127  		}
   128  	}
   129  
   130  	if dst.CanAddr() {
   131  		dst.Set(src)
   132  	} else {
   133  		dst.Elem().Set(src.Elem())
   134  	}
   135  	return nil
   136  }
   137  
   138  func setStruct(dst, src reflect.Value) error {
   139  	if dst.Kind() == reflect.Pointer {
   140  		if dst.IsNil() {
   141  			dst.Set(reflect.New(dst.Type().Elem()))
   142  		}
   143  		dst = dst.Elem()
   144  	}
   145  
   146  	st, dt := src.Type(), dst.Type()
   147  
   148  	// field tag mapping (tags take precedence over names)
   149  	srcFields := make(map[string]reflect.StructField)
   150  	for i := range src.NumField() {
   151  		field := st.Field(i)
   152  		srcFields[field.Name] = field
   153  	}
   154  
   155  	for i := range dst.NumField() {
   156  		dstField := dt.Field(i)
   157  		srcField, ok := srcFields[dstField.Name]
   158  		if !ok {
   159  			if tag, ok := dstField.Tag.Lookup("abi"); ok {
   160  				name := abi.ToCamelCase(tag)
   161  				if srcField, ok = srcFields[name]; !ok {
   162  					continue
   163  				}
   164  			} else {
   165  				continue
   166  			}
   167  		}
   168  
   169  		rCopy(
   170  			dst.FieldByName(dstField.Name),
   171  			src.FieldByName(srcField.Name),
   172  		)
   173  	}
   174  	return nil
   175  }
   176  
   177  func setSlice(dst, src reflect.Value) error {
   178  	if dst.IsNil() && dst.Kind() == reflect.Pointer {
   179  		dst = reflect.New(dst.Type().Elem())
   180  	}
   181  	if dst.Kind() == reflect.Pointer {
   182  		dst.Elem().Set(reflect.MakeSlice(dst.Elem().Type(), src.Len(), src.Len()))
   183  	} else {
   184  		dst.Set(reflect.MakeSlice(dst.Type(), src.Len(), src.Len()))
   185  	}
   186  
   187  	for i := range src.Len() {
   188  		if dst.Kind() == reflect.Pointer {
   189  			rCopy(dst.Elem().Index(i), src.Index(i))
   190  		} else {
   191  			rCopy(dst.Index(i), src.Index(i))
   192  		}
   193  	}
   194  	return nil
   195  }
   196  
   197  func setArray(dst, src reflect.Value) error {
   198  	if dst.Kind() == reflect.Pointer && dst.IsNil() {
   199  		dst = reflect.New(dst.Type().Elem())
   200  	}
   201  
   202  	for i := range src.Len() {
   203  		if dst.Kind() == reflect.Pointer {
   204  			rCopy(dst.Elem().Index(i), src.Index(i))
   205  		} else {
   206  			rCopy(dst.Index(i), src.Index(i))
   207  		}
   208  	}
   209  	return nil
   210  }
   211  
   212  func dereference(v reflect.Value) reflect.Value {
   213  	for v.Kind() == reflect.Pointer && v.Elem().Kind() == reflect.Pointer {
   214  		v = v.Elem()
   215  	}
   216  	return v
   217  }
   218  
   219  func reference(v reflect.Value) reflect.Value {
   220  	if v.Kind() != reflect.Pointer {
   221  		if v.CanAddr() {
   222  			v = v.Addr()
   223  		} else {
   224  			switch vv := v.Interface().(type) {
   225  			case bool:
   226  				v = reflect.ValueOf(&vv)
   227  			case uint:
   228  				v = reflect.ValueOf(&vv)
   229  			case uint8:
   230  				v = reflect.ValueOf(&vv)
   231  			case uint16:
   232  				v = reflect.ValueOf(&vv)
   233  			case uint32:
   234  				v = reflect.ValueOf(&vv)
   235  			case uint64:
   236  				v = reflect.ValueOf(&vv)
   237  			case int:
   238  				v = reflect.ValueOf(&vv)
   239  			case int8:
   240  				v = reflect.ValueOf(&vv)
   241  			case int16:
   242  				v = reflect.ValueOf(&vv)
   243  			case int32:
   244  				v = reflect.ValueOf(&vv)
   245  			case int64:
   246  				v = reflect.ValueOf(&vv)
   247  			case [1]byte:
   248  				v = reflect.ValueOf(&vv)
   249  			case [2]byte:
   250  				v = reflect.ValueOf(&vv)
   251  			case [3]byte:
   252  				v = reflect.ValueOf(&vv)
   253  			case [4]byte:
   254  				v = reflect.ValueOf(&vv)
   255  			case [5]byte:
   256  				v = reflect.ValueOf(&vv)
   257  			case [6]byte:
   258  				v = reflect.ValueOf(&vv)
   259  			case [7]byte:
   260  				v = reflect.ValueOf(&vv)
   261  			case [8]byte:
   262  				v = reflect.ValueOf(&vv)
   263  			case [9]byte:
   264  				v = reflect.ValueOf(&vv)
   265  			case [10]byte:
   266  				v = reflect.ValueOf(&vv)
   267  			case [11]byte:
   268  				v = reflect.ValueOf(&vv)
   269  			case [12]byte:
   270  				v = reflect.ValueOf(&vv)
   271  			case [13]byte:
   272  				v = reflect.ValueOf(&vv)
   273  			case [14]byte:
   274  				v = reflect.ValueOf(&vv)
   275  			case [15]byte:
   276  				v = reflect.ValueOf(&vv)
   277  			case [16]byte:
   278  				v = reflect.ValueOf(&vv)
   279  			case [17]byte:
   280  				v = reflect.ValueOf(&vv)
   281  			case [18]byte:
   282  				v = reflect.ValueOf(&vv)
   283  			case [19]byte:
   284  				v = reflect.ValueOf(&vv)
   285  			case [20]byte:
   286  				v = reflect.ValueOf(&vv)
   287  			case [21]byte:
   288  				v = reflect.ValueOf(&vv)
   289  			case [22]byte:
   290  				v = reflect.ValueOf(&vv)
   291  			case [23]byte:
   292  				v = reflect.ValueOf(&vv)
   293  			case [24]byte:
   294  				v = reflect.ValueOf(&vv)
   295  			case [25]byte:
   296  				v = reflect.ValueOf(&vv)
   297  			case [26]byte:
   298  				v = reflect.ValueOf(&vv)
   299  			case [27]byte:
   300  				v = reflect.ValueOf(&vv)
   301  			case [28]byte:
   302  				v = reflect.ValueOf(&vv)
   303  			case [29]byte:
   304  				v = reflect.ValueOf(&vv)
   305  			case [30]byte:
   306  				v = reflect.ValueOf(&vv)
   307  			case [31]byte:
   308  				v = reflect.ValueOf(&vv)
   309  			case [32]byte:
   310  				v = reflect.ValueOf(&vv)
   311  			case common.Address:
   312  				v = reflect.ValueOf(&vv)
   313  			case common.Hash:
   314  				v = reflect.ValueOf(&vv)
   315  			case string:
   316  				v = reflect.ValueOf(&vv)
   317  			case []byte:
   318  				v = reflect.ValueOf(&vv)
   319  			}
   320  		}
   321  	}
   322  	return v
   323  }