github.com/mavryk-network/mvgo@v1.19.9/micheline/unmarshal.go (about)

     1  // Copyright (c) 2020-2022 Blockwatch Data Inc.
     2  // Author: alex@blockwatch.cc
     3  
     4  package micheline
     5  
     6  import (
     7  	"encoding"
     8  	"fmt"
     9  	"reflect"
    10  	"strconv"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/mavryk-network/mvgo/mavryk"
    15  )
    16  
    17  type PrimUnmarshaler interface {
    18  	UnmarshalPrim(Prim) error
    19  }
    20  
    21  // FindLabel searches a nested type annotation path. Must be used on a type prim.
    22  // Path segments are separated by dot (.)
    23  func (p Prim) FindLabel(label string) (Prim, bool) {
    24  	idx, ok := p.LabelIndex(label)
    25  	if !ok {
    26  		return InvalidPrim, false
    27  	}
    28  	prim, _ := p.GetIndex(idx)
    29  	return prim, true
    30  }
    31  
    32  // LabelIndex returns the indexed path to a type annotation label and true
    33  // if path exists. Path segments are separated by dot (.)
    34  func (p Prim) LabelIndex(label string) ([]int, bool) {
    35  	return p.findLabelPath(strings.Split(label, "."), nil)
    36  }
    37  
    38  func (p Prim) findLabelPath(path []string, idx []int) ([]int, bool) {
    39  	prim := p
    40  next:
    41  	for {
    42  		if len(path) == 0 {
    43  			return idx, true
    44  		}
    45  		var found bool
    46  		for i, v := range prim.Args {
    47  			if v.HasAnno() && v.MatchesAnno(path[0]) {
    48  				idx = append(idx, i)
    49  				path = path[1:]
    50  				prim = v
    51  				found = true
    52  				continue next
    53  			}
    54  		}
    55  
    56  		for i, v := range prim.Args {
    57  			if v.HasAnno() {
    58  				continue
    59  			}
    60  			idx2, ok := v.findLabelPath(path, append(idx, i))
    61  			if ok {
    62  				path = path[:0]
    63  				found = true
    64  				idx = idx2
    65  				continue next
    66  			}
    67  		}
    68  		if !found {
    69  			return nil, false
    70  		}
    71  	}
    72  }
    73  
    74  // GetPath returns a nested primitive at path. Path segments are separated by slash (/).
    75  // Works on both type and value primitive trees.
    76  func (p Prim) GetPath(path string) (Prim, error) {
    77  	index, err := p.getIndex(path)
    78  	if err != nil {
    79  		return InvalidPrim, err
    80  	}
    81  	return p.GetIndex(index)
    82  }
    83  
    84  // GetPathExt returns a nested primitive at path if the primitive matches
    85  // the expected opcode. Path segments are separated by slash (/).
    86  // Works on both type and value primitive trees.
    87  func (p Prim) GetPathExt(path string, typ OpCode) (Prim, error) {
    88  	prim, err := p.GetPath(path)
    89  	if err != nil {
    90  		return prim, err
    91  	}
    92  	if prim.OpCode != typ {
    93  		return InvalidPrim, fmt.Errorf("micheline: unexpected type %s at path %v", prim.OpCode, path)
    94  	}
    95  	return prim, nil
    96  }
    97  
    98  func (p Prim) getIndex(path string) ([]int, error) {
    99  	index := make([]int, 0)
   100  	path = strings.TrimPrefix(path, "/")
   101  	path = strings.TrimSuffix(path, "/")
   102  	if len(path) == 0 {
   103  		return nil, nil
   104  	}
   105  	for i, v := range strings.Split(path, "/") {
   106  		switch v {
   107  		case "L", "l", "0":
   108  			index = append(index, 0)
   109  		case "R", "r", "1":
   110  			index = append(index, 1)
   111  		default:
   112  			idx, err := strconv.Atoi(v)
   113  			if err != nil {
   114  				return nil, fmt.Errorf("micheline: invalid path component '%v' at pos %d", v, i)
   115  			}
   116  			index = append(index, idx)
   117  		}
   118  	}
   119  	return index, nil
   120  }
   121  
   122  // HasIndex returns true when a nested primitive exists at path defined by index.
   123  func (p Prim) HasIndex(index []int) bool {
   124  	prim := p
   125  	for _, v := range index {
   126  		if v < 0 || len(prim.Args) <= v {
   127  			return false
   128  		}
   129  		prim = prim.Args[v]
   130  	}
   131  	return true
   132  }
   133  
   134  // GetIndex returns a nested primitive at path index.
   135  func (p Prim) GetIndex(index []int) (Prim, error) {
   136  	prim := p
   137  	for _, v := range index {
   138  		if v < 0 || len(prim.Args) <= v {
   139  			return InvalidPrim, fmt.Errorf("micheline: index %d out of bounds", v)
   140  		}
   141  		prim = prim.Args[v]
   142  	}
   143  	return prim, nil
   144  }
   145  
   146  // GetIndex returns a nested primitive at path index if the primitive matches the
   147  // expected opcode. This only works on type trees. Value trees lack opcode info.
   148  func (p Prim) GetIndexExt(index []int, typ OpCode) (Prim, error) {
   149  	prim, err := p.GetIndex(index)
   150  	if err != nil {
   151  		return InvalidPrim, err
   152  	}
   153  	if prim.OpCode != typ {
   154  		return InvalidPrim, fmt.Errorf("micheline: unexpected type %s at path %v", prim.OpCode, index)
   155  	}
   156  	return prim, nil
   157  }
   158  
   159  // Decode unmarshals a prim tree into a Go struct. The mapping uses Go struct tags
   160  // to define primitive paths that are mapped to each struct member. Types are
   161  // converted between Micheline and Go when possible.
   162  //
   163  // Examples of struct field tags and their meanings:
   164  //
   165  //	// maps Micheline path 0/0/0 to string field and fails on type mismatch
   166  //	Field string `prim:",path=0/0/1"`
   167  //
   168  //	// ignore type errors and do not update struct field
   169  //	Field string  `prim:",path=0/0/1,nofail"`
   170  //
   171  //	// ignore struct field
   172  //	Field string  `prim:"-"`
   173  func (p Prim) Decode(v interface{}) error {
   174  	val := reflect.ValueOf(v)
   175  	if val.Kind() != reflect.Ptr {
   176  		return fmt.Errorf("micheline: non-pointer passed to Decode: %s %s", val.Kind(), val.Type().String())
   177  	}
   178  	val = reflect.Indirect(val)
   179  	if val.Kind() != reflect.Struct {
   180  		return fmt.Errorf("micheline: non-struct passed to Decode %s %s", val.Kind(), val.Type().String())
   181  	}
   182  	return p.unmarshal(val)
   183  }
   184  
   185  func (p Prim) unmarshal(val reflect.Value) error {
   186  	val = derefValue(val)
   187  	if val.CanInterface() && val.Type().Implements(primUnmarshalerType) {
   188  		// This is an unmarshaler with a non-pointer receiver,
   189  		// so it's likely to be incorrect, but we do what we're told.
   190  		return val.Interface().(PrimUnmarshaler).UnmarshalPrim(p)
   191  	}
   192  	if val.CanAddr() {
   193  		pv := val.Addr()
   194  		if pv.CanInterface() && pv.Type().Implements(primUnmarshalerType) {
   195  			return pv.Interface().(PrimUnmarshaler).UnmarshalPrim(p)
   196  		}
   197  	}
   198  
   199  	tinfo, err := getTypeInfo(indirectType(val.Type()))
   200  	if err != nil {
   201  		return err
   202  	}
   203  	for _, finfo := range tinfo.fields {
   204  		dst := finfo.value(val)
   205  		if !dst.IsValid() {
   206  			continue
   207  		}
   208  		if dst.Kind() == reflect.Ptr {
   209  			if dst.IsNil() && dst.CanSet() {
   210  				dst.Set(reflect.New(dst.Type()))
   211  			}
   212  			dst = dst.Elem()
   213  		}
   214  		pp, err := p.GetIndex(finfo.path)
   215  		if err != nil {
   216  			if finfo.nofail {
   217  				continue
   218  			}
   219  			return err
   220  		}
   221  		switch finfo.typ {
   222  		case T_BYTES:
   223  			if dst.CanAddr() {
   224  				pv := dst.Addr()
   225  				if pv.CanInterface() {
   226  					if pv.Type().Implements(binaryUnmarshalerType) {
   227  						if err := pv.Interface().(encoding.BinaryUnmarshaler).UnmarshalBinary(pp.Bytes); err != nil {
   228  							if !finfo.nofail {
   229  								return err
   230  							}
   231  						}
   232  						break
   233  					}
   234  					if pv.Type().Implements(textUnmarshalerType) {
   235  						if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText(pp.Bytes); err != nil {
   236  							if !finfo.nofail {
   237  								return err
   238  							}
   239  						}
   240  						break
   241  					}
   242  				}
   243  			}
   244  			buf := make([]byte, len(pp.Bytes))
   245  			copy(buf, pp.Bytes)
   246  			dst.SetBytes(buf)
   247  		case T_STRING:
   248  			if dst.CanAddr() {
   249  				pv := dst.Addr()
   250  				if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
   251  					if pp.Bytes != nil {
   252  						if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText(pp.Bytes); err != nil {
   253  							if !finfo.nofail {
   254  								return err
   255  							}
   256  						}
   257  						break
   258  					}
   259  					if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(pp.String)); err != nil {
   260  						if !finfo.nofail {
   261  							return err
   262  						}
   263  					}
   264  					break
   265  				}
   266  			}
   267  			dst.SetString(pp.String)
   268  		case T_INT, T_NAT:
   269  			if dst.CanAddr() {
   270  				pv := dst.Addr()
   271  				if pv.CanInterface() && pv.Type().Implements(textUnmarshalerType) {
   272  					if pp.Int != nil {
   273  						if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(pp.Int.Text(10))); err != nil {
   274  							if !finfo.nofail {
   275  								return err
   276  							}
   277  						}
   278  						break
   279  					}
   280  					if err := pv.Interface().(encoding.TextUnmarshaler).UnmarshalText([]byte(pp.String)); err != nil {
   281  						if !finfo.nofail {
   282  							return err
   283  						}
   284  					}
   285  					break
   286  				}
   287  			}
   288  			switch dst.Type().Kind() {
   289  			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   290  				dst.SetUint(uint64(pp.Int.Int64()))
   291  			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   292  				dst.SetInt(pp.Int.Int64())
   293  			}
   294  
   295  		case T_BOOL:
   296  			dst.SetBool(pp.OpCode == D_TRUE)
   297  		case T_TIMESTAMP:
   298  			if pp.Int != nil {
   299  				dst.Set(reflect.ValueOf(time.Unix(pp.Int.Int64(), 0).UTC()))
   300  			} else {
   301  				tm, err := time.Parse(time.RFC3339, pp.String)
   302  				if err != nil {
   303  					if !finfo.nofail {
   304  						return err
   305  					}
   306  				}
   307  				dst.Set(reflect.ValueOf(tm))
   308  			}
   309  		case T_ADDRESS:
   310  			var (
   311  				addr mavryk.Address
   312  				err  error
   313  			)
   314  			if pp.Bytes != nil {
   315  				err = addr.Decode(pp.Bytes)
   316  			} else {
   317  				err = addr.UnmarshalText([]byte(pp.String))
   318  			}
   319  			if err != nil && !finfo.nofail {
   320  				return err
   321  			}
   322  			dst.Set(reflect.ValueOf(addr))
   323  		case T_KEY:
   324  			var (
   325  				key mavryk.Key
   326  				err error
   327  			)
   328  			if pp.Bytes != nil {
   329  				err = key.UnmarshalBinary(pp.Bytes)
   330  			} else {
   331  				err = key.UnmarshalText([]byte(pp.String))
   332  			}
   333  			if err != nil && !finfo.nofail {
   334  				return err
   335  			}
   336  			dst.Set(reflect.ValueOf(key))
   337  		case T_SIGNATURE:
   338  			var (
   339  				sig mavryk.Signature
   340  				err error
   341  			)
   342  			if pp.Bytes != nil {
   343  				err = sig.UnmarshalBinary(pp.Bytes)
   344  			} else {
   345  				err = sig.UnmarshalText([]byte(pp.String))
   346  			}
   347  			if err != nil && !finfo.nofail {
   348  				return err
   349  			}
   350  			dst.Set(reflect.ValueOf(sig))
   351  		case T_CHAIN_ID:
   352  			var (
   353  				chain mavryk.ChainIdHash
   354  				err   error
   355  			)
   356  			if pp.Bytes != nil {
   357  				err = chain.UnmarshalBinary(pp.Bytes)
   358  			} else {
   359  				err = chain.UnmarshalText([]byte(pp.String))
   360  			}
   361  			if err != nil && !finfo.nofail {
   362  				return err
   363  			}
   364  			dst.Set(reflect.ValueOf(chain))
   365  		case T_LIST:
   366  			styp := dst.Type()
   367  			if dst.IsNil() {
   368  				dst.Set(reflect.MakeSlice(styp, 0, len(pp.Args)))
   369  			}
   370  			for idx, ppp := range pp.Args {
   371  				sval := reflect.New(styp.Elem())
   372  				if sval.Type().Kind() == reflect.Ptr && sval.IsNil() && sval.CanSet() {
   373  					sval.Set(reflect.New(sval.Type().Elem()))
   374  				}
   375  				// decode from value prim
   376  				if err := ppp.unmarshal(sval); err != nil && !finfo.nofail {
   377  					return err
   378  				}
   379  				dst.SetLen(idx + 1)
   380  				dst.Index(idx).Set(sval.Elem())
   381  			}
   382  		case T_MAP:
   383  			mtyp := dst.Type()
   384  			switch mtyp.Key().Kind() {
   385  			case reflect.String:
   386  			default:
   387  				return fmt.Errorf("micheline: only string keys are supported for map %s", finfo.name)
   388  			}
   389  			if dst.IsNil() {
   390  				dst.Set(reflect.MakeMap(mtyp))
   391  			}
   392  
   393  			// process ELT args
   394  			for _, ppp := range pp.Args {
   395  				// must be an ELT
   396  				if ppp.OpCode != D_ELT {
   397  					return fmt.Errorf("micheline: expected ELT data for map field %s, got %s",
   398  						finfo.name, ppp.Dump())
   399  				}
   400  				// decode string from ELT key
   401  				k, err := NewKey(ppp.Args[0].BuildType(), ppp.Args[0])
   402  				if err != nil {
   403  					return fmt.Errorf("micheline: cannot convert ELT key for field %s val=%s: %v",
   404  						finfo.name, ppp.Args[0].Dump(), err)
   405  				}
   406  				name := k.String()
   407  
   408  				// allocate value
   409  				mval := reflect.New(mtyp.Elem()).Elem()
   410  				if mval.Type().Kind() == reflect.Ptr && mval.IsNil() && mval.CanSet() {
   411  					mval.Set(reflect.New(mval.Type().Elem()))
   412  				}
   413  
   414  				// decode from value prim
   415  				if err := ppp.Args[1].unmarshal(mval); err != nil && !finfo.nofail {
   416  					return err
   417  				}
   418  
   419  				// assign to map
   420  				dst.SetMapIndex(reflect.ValueOf(name), mval)
   421  			}
   422  		default:
   423  			return fmt.Errorf("micheline: unsupported prim %#v for struct field %s", pp, finfo.name)
   424  
   425  		}
   426  	}
   427  	return nil
   428  }