github.com/mavryk-network/mvgo@v1.19.9/micheline/marshal.go (about)

     1  // Copyright (c) 2023 Blockwatch Data Inc.
     2  // Author: alex@blockwatch.cc
     3  
     4  package micheline
     5  
import (
	"encoding/hex"
	"fmt"
	"math/big"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/mavryk-network/mvgo/mavryk"
)
    16  
    17  type PrimMarshaler interface {
    18  	MarshalPrim() (Prim, error)
    19  }
    20  
    21  // SetPath replaces a nested primitive at path with dst.
    22  // Path segments are separated by slash (/).
    23  // Works on both type and value primitive trees.
    24  func (p *Prim) SetPath(path string, dst Prim) error {
    25  	index, err := p.getIndex(path)
    26  	if err != nil {
    27  		return err
    28  	}
    29  	return p.SetIndex(index, dst)
    30  }
    31  
    32  // SetPathExt replaces a nested primitive at path with dst if the primitive matches
    33  // the expected type. Path segments are separated by slash (/).
    34  // Works on best on value primitive trees.
    35  func (p *Prim) SetPathExt(path string, typ PrimType, dst Prim) error {
    36  	index, err := p.getIndex(path)
    37  	if err != nil {
    38  		return err
    39  	}
    40  	return p.SetIndexExt(index, typ, dst)
    41  }
    42  
    43  // SetIndex replaces a nested primitive at path index with dst.
    44  func (p *Prim) SetIndex(index []int, dst Prim) error {
    45  	prim := p
    46  	for _, v := range index {
    47  		if v < 0 || len(prim.Args) <= v {
    48  			return fmt.Errorf("micheline: index %d out of bounds", v)
    49  		}
    50  		prim = &prim.Args[v]
    51  	}
    52  	*prim = dst
    53  	return nil
    54  }
    55  
    56  // SetIndexExt replaces a nested primitive at path index if the primitive matches the
    57  // expected primitive type. This function works best with value trees which
    58  // lack opcode info. Use as extra cross-check when replacing prims.
    59  func (p *Prim) SetIndexExt(index []int, typ PrimType, dst Prim) error {
    60  	prim := p
    61  	for _, v := range index {
    62  		if v < 0 || len(prim.Args) <= v {
    63  			return fmt.Errorf("micheline: index %d out of bounds", v)
    64  		}
    65  		prim = &prim.Args[v]
    66  	}
    67  	if prim.Type != typ {
    68  		return fmt.Errorf("micheline: unexpected type %s at path %v", prim.Type, index)
    69  	}
    70  	*prim = dst
    71  	return nil
    72  }
    73  
    74  // Marshal takes a scalar or nested Go type and populates a Micheline
    75  // primitive tree compatible with type t. This method is compatible
    76  // with most contract entrypoints, contract storage, bigmap values, etc.
    77  // Use optimized to control whether the target prims contain values in
    78  // optimized form (binary addresses, numeric timestamps) or string form.
    79  //
    80  // Note: This is work in progress. Several data types are still unsupported
    81  // and entrypoint mapping requires some extra boilerplate:
    82  //
    83  //	// Entrypoint example (without error handling for brevity)
    84  //	eps, _ := script.Entrypoints(true)
    85  //	ep, _ := eps["name"]
    86  //
    87  //	// marshal to prim tree
    88  //	// Note: be mindful of the way entrypoint typedefs are structured:
    89  //	// - 1 arg: use scalar value in ep.Typedef[0]
    90  //	// - >1 arg: use entire list in ep.Typedef but wrap into struct
    91  //	typ := ep.Typedef[0]
    92  //	if len(ep.Typedef) > 1 {
    93  //	    typ = micheline.Typedef{
    94  //	        Name: micheline.CONST_ENTRYPOINT,
    95  //	        Type: micheline.TypeStruct,
    96  //	        Args: ep.Typedef,
    97  //	    }
    98  //	}
    99  //
   100  //	// then use the type to marshal into primitives
   101  //	prim, err := typ.Marshal(args, true)
   102  func (t Typedef) Marshal(v any, optimized bool) (Prim, error) {
   103  	return t.marshal(v, optimized, 0)
   104  }
   105  
   106  func (t Typedef) marshal(v any, optimized bool, depth int) (Prim, error) {
   107  	// fmt.Printf("Marshal %T %v => %#v\n", v, v, t)
   108  	if t.Optional {
   109  		val := v
   110  		if t.Name != "" && val != nil {
   111  			vals, ok := v.(map[string]any)
   112  			if ok {
   113  				val, ok = vals[t.Name]
   114  				if !ok {
   115  					return InvalidPrim, fmt.Errorf("missing arg %s", t.Name)
   116  				}
   117  			}
   118  		}
   119  		if val != nil {
   120  			t.Optional = false
   121  			p, err := t.marshal(val, optimized, depth+1)
   122  			if err != nil {
   123  				return InvalidPrim, err
   124  			}
   125  			return NewOption(p), nil
   126  		} else {
   127  			return NewOption(), nil
   128  		}
   129  	}
   130  	switch t.Type {
   131  	case TypeUnion:
   132  		// find the named union element in map
   133  		vals, ok := v.(map[string]any)
   134  		if !ok {
   135  			return InvalidPrim, fmt.Errorf("invalid type %T on union %s", v, t.Name)
   136  		}
   137  		var child Typedef
   138  		for _, n := range t.Args {
   139  			if _, ok := vals[n.Name]; ok {
   140  				child = n
   141  				break
   142  			}
   143  		}
   144  		// marshal child type
   145  		p, err := child.marshal(vals[child.Name], optimized, depth+1)
   146  		if err != nil {
   147  			return InvalidPrim, err
   148  		}
   149  		// produce OR tree for child's path
   150  		return NewUnion(child.Path[depth:], p), nil
   151  
   152  	case TypeStruct:
   153  		vals, ok := v.(map[string]any)
   154  		if !ok {
   155  			return InvalidPrim, fmt.Errorf("invalid type %T on struct %s", v, t.Name)
   156  		}
   157  		// for values with nested named structs try if name exists
   158  		if m, ok := vals[t.Name]; t.Name != "" && ok {
   159  			fmt.Printf("Unpacking nested struct %s\n", t.Name)
   160  			vals, ok = m.(map[string]any)
   161  			if !ok {
   162  				return InvalidPrim, fmt.Errorf("invalid type %T on nested struct %s", m, t.Name)
   163  			}
   164  		}
   165  		prims := []Prim{}
   166  		for _, v := range t.Args {
   167  			p, err := v.marshal(vals[v.Name], optimized, depth+1)
   168  			if err != nil {
   169  				return InvalidPrim, err
   170  			}
   171  			prims = append(prims, p)
   172  		}
   173  		if len(prims) > 2 {
   174  			// reconstruct struct structure as Pair tree from type paths
   175  			var root Prim
   176  			for i, v := range prims {
   177  				root.Insert(v, t.Args[i].Path[depth:])
   178  			}
   179  			return root, nil
   180  		}
   181  		return NewPair(prims[0], prims[1]), nil
   182  
   183  	case "list", "set":
   184  		if v == nil {
   185  			return NewSeq(), nil
   186  		}
   187  		listVals, ok := v.([]any)
   188  		if !ok {
   189  			// use nested value for named lists
   190  			vals, ok := v.(map[string]any)
   191  			if !ok {
   192  				return InvalidPrim, fmt.Errorf("invalid list/set type %T on field %s, must be map[string]any", v, t.Name)
   193  			}
   194  			list, ok := vals[t.Name]
   195  			if !ok {
   196  				return InvalidPrim, fmt.Errorf("missing list/set arg %s", t.Name)
   197  			}
   198  			listVals, ok = list.([]any)
   199  			if !ok {
   200  				return InvalidPrim, fmt.Errorf("invalid list/set type %T on field %s, must be []any", list, t.Name)
   201  			}
   202  		}
   203  		prims := []Prim{}
   204  		for _, v := range listVals {
   205  			p, err := t.Args[0].marshal(v, optimized, depth+1)
   206  			if err != nil {
   207  				return InvalidPrim, err
   208  			}
   209  			prims = append(prims, p)
   210  		}
   211  		return NewSeq(prims...), nil
   212  
   213  	case "map", "big_map":
   214  		if v == nil {
   215  			return NewMap(), nil
   216  		}
   217  		vals, ok := v.(map[string]any)
   218  		if !ok {
   219  			return InvalidPrim, fmt.Errorf("invalid map type %T on field %s, must be map[string]any", v, t.Name)
   220  		}
   221  		// for top-level maps (in entrypoints etc) try if map name is part of value tree
   222  		if depth == 0 {
   223  			if m, ok := vals[t.Name]; ok {
   224  				vals, ok = m.(map[string]any)
   225  				if !ok {
   226  					return InvalidPrim, fmt.Errorf("invalid map type %T on field %s, must be map[string]any", m, t.Name)
   227  				}
   228  			}
   229  		}
   230  		prims := []Prim{}
   231  		for n, v := range vals {
   232  			key, err := ParsePrim(t.Left(), n, optimized)
   233  			if err != nil {
   234  				return InvalidPrim, err
   235  			}
   236  			value, err := t.Right().marshal(v, optimized, depth+1)
   237  			if err != nil {
   238  				return InvalidPrim, err
   239  			}
   240  			prims = append(prims, NewMapElem(key, value))
   241  		}
   242  		return NewMap(prims...), nil
   243  
   244  	case "lambda":
   245  		switch val := v.(type) {
   246  		case string:
   247  			var p Prim
   248  			err := p.UnmarshalJSON([]byte(val))
   249  			return p, err
   250  		case PrimMarshaler:
   251  			return val.MarshalPrim()
   252  		case Prim:
   253  			return val, nil
   254  		default:
   255  			return InvalidPrim, fmt.Errorf("unsupported type %T for lambda on field %s", v, t.Name)
   256  		}
   257  
   258  	default:
   259  		// scalar
   260  		oc := t.OpCode()
   261  		if !oc.IsValid() {
   262  			return InvalidPrim, fmt.Errorf("invalid type code %s on field %s", t.Type, t.Name)
   263  		}
   264  		if v == nil {
   265  			if oc == T_UNIT {
   266  				return NewCode(D_UNIT), nil
   267  			}
   268  			return InvalidPrim, fmt.Errorf("missing arg %s (%s)", t.Name, t.Type)
   269  		}
   270  		switch val := v.(type) {
   271  		case map[string]any:
   272  			// recurse unpack the named value from this map
   273  			return t.marshal(val[t.Name], optimized, depth)
   274  		case Prim:
   275  			return val, nil
   276  		case PrimMarshaler:
   277  			return val.MarshalPrim( /* optimized */ )
   278  		case string:
   279  			// parse anything from string (supports config file and API map[string]string)
   280  			return ParsePrim(t, val, optimized)
   281  		case []byte:
   282  			return NewBytes(val), nil
   283  		case bool:
   284  			if val {
   285  				return NewCode(D_TRUE), nil
   286  			}
   287  			return NewCode(D_FALSE), nil
   288  		case int:
   289  			switch oc {
   290  			case T_BYTES:
   291  				return NewBytes([]byte(strconv.FormatInt(int64(val), 10))), nil
   292  			case T_STRING:
   293  				return NewString(strconv.FormatInt(int64(val), 10)), nil
   294  			case T_TIMESTAMP:
   295  				if optimized {
   296  					return NewInt64(int64(val)), nil
   297  				}
   298  				return NewString(time.Unix(int64(val), 0).UTC().Format(time.RFC3339)), nil
   299  			case T_INT, T_NAT, T_MUMAV:
   300  				return NewInt64(int64(val)), nil
   301  			default:
   302  				return InvalidPrim, fmt.Errorf("unsupported type conversion %T to opcode %s for on field %s", v, t.Type, t.Name)
   303  			}
   304  		case int64:
   305  			switch oc {
   306  			case T_BYTES:
   307  				return NewBytes([]byte(strconv.FormatInt(val, 10))), nil
   308  			case T_STRING:
   309  				return NewString(strconv.FormatInt(val, 10)), nil
   310  			case T_TIMESTAMP:
   311  				if optimized {
   312  					return NewInt64(val), nil
   313  				}
   314  				return NewString(time.Unix(val, 0).UTC().Format(time.RFC3339)), nil
   315  			case T_INT, T_NAT, T_MUMAV:
   316  				return NewInt64(val), nil
   317  			default:
   318  				return InvalidPrim, fmt.Errorf("unsupported type conversion %T to opcode %s on field %s", v, t.Type, t.Name)
   319  			}
   320  		case time.Time:
   321  			if optimized {
   322  				return NewInt64(val.Unix()), nil
   323  			}
   324  			return NewString(val.UTC().Format(time.RFC3339)), nil
   325  		case mavryk.Address:
   326  			if optimized {
   327  				switch oc {
   328  				case T_KEY_HASH:
   329  					return NewKeyHash(val), nil
   330  				case T_ADDRESS:
   331  					return NewAddress(val), nil
   332  				default:
   333  					return InvalidPrim, fmt.Errorf("unsupported type conversion from %T to opcode %s on field %s", v, t.Type, t.Name)
   334  				}
   335  			}
   336  			return NewString(val.String()), nil
   337  		case mavryk.Key:
   338  			if optimized {
   339  				return NewBytes(val.Bytes()), nil
   340  			}
   341  			return NewString(val.String()), nil
   342  		case mavryk.Signature:
   343  			if optimized {
   344  				return NewBytes(val.Bytes()), nil
   345  			}
   346  			return NewString(val.String()), nil
   347  		case mavryk.ChainIdHash:
   348  			return NewString(val.String()), nil
   349  
   350  		default:
   351  			// TODO
   352  			return InvalidPrim, fmt.Errorf("unsupported type %T for opcode %s on field %s", v, t.Type, t.Name)
   353  		}
   354  	}
   355  }
   356  
   357  func ParsePrim(typ Typedef, val string, optimized bool) (p Prim, err error) {
   358  	p = InvalidPrim
   359  	if !typ.OpCode().IsTypeCode() {
   360  		err = fmt.Errorf("invalid type code %q", typ)
   361  		return
   362  	}
   363  	switch typ.OpCode() {
   364  	case T_INT, T_NAT, T_MUMAV:
   365  		i := big.NewInt(0)
   366  		err = i.UnmarshalText([]byte(val))
   367  		p = NewBig(i)
   368  	case T_STRING:
   369  		p = NewString(val)
   370  	case T_BYTES:
   371  		if buf, err2 := hex.DecodeString(val); err2 != nil {
   372  			p = NewBytes([]byte(val))
   373  		} else {
   374  			p = NewBytes(buf)
   375  		}
   376  	case T_BOOL:
   377  		var b bool
   378  		b, err = strconv.ParseBool(val)
   379  		if b {
   380  			p = NewCode(D_TRUE)
   381  		} else {
   382  			p = NewCode(D_FALSE)
   383  		}
   384  	case T_TIMESTAMP:
   385  		// either RFC3339 or UNIX seconds
   386  		var tm time.Time
   387  		if strings.Contains(val, "T") {
   388  			tm, err = time.Parse(time.RFC3339, val)
   389  		} else {
   390  			var i int64
   391  			i, err = strconv.ParseInt(val, 10, 64)
   392  			tm = time.Unix(i, 0).UTC()
   393  		}
   394  		if optimized {
   395  			p = NewInt64(tm.Unix())
   396  		} else {
   397  			p = NewString(tm.Format(time.RFC3339))
   398  		}
   399  	case T_KEY_HASH:
   400  		var addr mavryk.Address
   401  		addr, err = mavryk.ParseAddress(val)
   402  		if optimized {
   403  			p = NewKeyHash(addr)
   404  		} else {
   405  			p = NewString(addr.String())
   406  		}
   407  	case T_ADDRESS:
   408  		var addr mavryk.Address
   409  		addr, err = mavryk.ParseAddress(val)
   410  		if optimized {
   411  			p = NewAddress(addr)
   412  		} else {
   413  			p = NewString(addr.String())
   414  		}
   415  	case T_KEY:
   416  		var key mavryk.Key
   417  		key, err = mavryk.ParseKey(val)
   418  		if optimized {
   419  			p = NewBytes(key.Bytes())
   420  		} else {
   421  			p = NewString(key.String())
   422  		}
   423  
   424  	case T_SIGNATURE:
   425  		var sig mavryk.Signature
   426  		sig, err = mavryk.ParseSignature(val)
   427  		if optimized {
   428  			p = NewBytes(sig.Bytes())
   429  		} else {
   430  			p = NewString(sig.String())
   431  		}
   432  
   433  	case T_UNIT:
   434  		if val == D_UNIT.String() || val == "" {
   435  			p = NewCode(D_UNIT)
   436  		} else {
   437  			err = fmt.Errorf("micheline: invalid value %q for unit type", val)
   438  		}
   439  
   440  	case T_PAIR:
   441  		// parse comma-separated list into map using type lables from typedef
   442  		// note: this only supports simple structs which is probably enough
   443  		// because bigmap keys must be comparable types
   444  		m := make(map[string]any)
   445  		for i, v := range strings.Split(val, ",") {
   446  			// find i-th child in typedef
   447  			if len(typ.Args) < i-1 {
   448  				err = fmt.Errorf("micheline: invalid value for bigmap key struct type %s", typ.Name)
   449  				return
   450  			}
   451  			m[typ.Args[i].Name] = v
   452  		}
   453  		return typ.marshal(m, optimized, 0)
   454  
   455  	default:
   456  		err = fmt.Errorf("micheline: unsupported big_map key type %s", typ)
   457  	}
   458  
   459  	if err != nil {
   460  		p = InvalidPrim
   461  	}
   462  	return
   463  }
   464  
   465  func (p *Prim) Insert(src Prim, path []int) {
   466  	if !p.IsValid() {
   467  		*p = NewPair(Prim{}, Prim{})
   468  	}
   469  
   470  	if len(p.Args) <= path[0] {
   471  		cp := make([]Prim, path[0]+1)
   472  		copy(cp, p.Args)
   473  		p.Args = cp
   474  		// convert to sequence
   475  		p.Type = PrimSequence
   476  		p.OpCode = 0
   477  	}
   478  
   479  	if len(path) == 1 {
   480  		p.Args[path[0]] = src
   481  		return
   482  	}
   483  
   484  	p.Args[path[0]].Insert(src, path[1:])
   485  }