github.com/InjectiveLabs/sdk-go@v1.53.0/eip712_cosmos.go (about)

     1  package ante
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"fmt"
     7  	"math/big"
     8  	"reflect"
     9  	"runtime/debug"
    10  	"strings"
    11  	"time"
    12  
    13  	"cosmossdk.io/math"
    14  	"github.com/cosmos/cosmos-sdk/codec"
    15  	codectypes "github.com/cosmos/cosmos-sdk/codec/types"
    16  	cosmtypes "github.com/cosmos/cosmos-sdk/types"
    17  	"github.com/cosmos/cosmos-sdk/x/auth/migrations/legacytx"
    18  	authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing"
    19  	"github.com/ethereum/go-ethereum/common"
    20  	ethmath "github.com/ethereum/go-ethereum/common/math"
    21  	"github.com/pkg/errors"
    22  
    23  	"github.com/InjectiveLabs/sdk-go/typeddata"
    24  )
    25  
    26  type EIP712Wrapper func(
    27  	cdc codec.ProtoCodecMarshaler,
    28  	chainID uint64,
    29  	signerData *authsigning.SignerData,
    30  	timeoutHeight uint64,
    31  	memo string,
    32  	feeInfo legacytx.StdFee,
    33  	msgs []cosmtypes.Msg,
    34  	feeDelegation *FeeDelegationOptions,
    35  ) (typeddata.TypedData, error)
    36  
    37  // WrapTxToEIP712 is an ultimate method that wraps Amino-encoded Cosmos Tx JSON data
    38  // into an EIP712-compatible request. All messages must be of the same type.
    39  func WrapTxToEIP712(
    40  	cdc codec.ProtoCodecMarshaler,
    41  	chainID uint64,
    42  	signerData *authsigning.SignerData,
    43  	timeoutHeight uint64,
    44  	memo string,
    45  	feeInfo legacytx.StdFee,
    46  	msgs []cosmtypes.Msg,
    47  	feeDelegation *FeeDelegationOptions,
    48  ) (typeddata.TypedData, error) {
    49  	data := legacytx.StdSignBytes(
    50  		signerData.ChainID,
    51  		signerData.AccountNumber,
    52  		signerData.Sequence,
    53  		timeoutHeight,
    54  		feeInfo,
    55  		msgs, memo,
    56  	)
    57  
    58  	txData := make(map[string]interface{})
    59  	if err := json.Unmarshal(data, &txData); err != nil {
    60  		err = errors.Wrap(err, "failed to unmarshal data provided into WrapTxToEIP712")
    61  		return typeddata.TypedData{}, err
    62  	}
    63  
    64  	domain := typeddata.TypedDataDomain{
    65  		Name:              "Injective Web3",
    66  		Version:           "1.0.0",
    67  		ChainId:           ethmath.NewHexOrDecimal256(int64(chainID)),
    68  		VerifyingContract: "cosmos",
    69  		Salt:              "0",
    70  	}
    71  
    72  	msgTypes, err := extractMsgTypes(cdc, "MsgValue", msgs[0])
    73  	if err != nil {
    74  		return typeddata.TypedData{}, err
    75  	}
    76  
    77  	if feeDelegation != nil {
    78  		feeInfo := txData["fee"].(map[string]interface{})
    79  		feeInfo["feePayer"] = feeDelegation.FeePayer.String()
    80  
    81  		// also patching msgTypes to include feePayer
    82  		msgTypes["Fee"] = []typeddata.Type{
    83  			{Name: "feePayer", Type: "string"},
    84  			{Name: "amount", Type: "Coin[]"},
    85  			{Name: "gas", Type: "string"},
    86  		}
    87  	}
    88  
    89  	var typedData = typeddata.TypedData{
    90  		Types:       msgTypes,
    91  		PrimaryType: "Tx",
    92  		Domain:      domain,
    93  		Message:     txData,
    94  	}
    95  
    96  	return typedData, nil
    97  }
    98  
    99  type FeeDelegationOptions struct {
   100  	FeePayer cosmtypes.AccAddress
   101  }
   102  
   103  func extractMsgTypes(cdc codec.ProtoCodecMarshaler, msgTypeName string, msg cosmtypes.Msg) (typeddata.Types, error) {
   104  	rootTypes := typeddata.Types{
   105  		"EIP712Domain": {
   106  			{
   107  				Name: "name",
   108  				Type: "string",
   109  			},
   110  			{
   111  				Name: "version",
   112  				Type: "string",
   113  			},
   114  			{
   115  				Name: "chainId",
   116  				Type: "uint256",
   117  			},
   118  			{
   119  				Name: "verifyingContract",
   120  				Type: "string",
   121  			},
   122  			{
   123  				Name: "salt",
   124  				Type: "string",
   125  			},
   126  		},
   127  		"Tx": {
   128  			{Name: "account_number", Type: "string"},
   129  			{Name: "chain_id", Type: "string"},
   130  			{Name: "fee", Type: "Fee"},
   131  			{Name: "memo", Type: "string"},
   132  			{Name: "msgs", Type: "Msg[]"},
   133  			{Name: "sequence", Type: "string"},
   134  			{Name: "timeout_height", Type: "string"},
   135  		},
   136  		"Fee": {
   137  			{Name: "amount", Type: "Coin[]"},
   138  			{Name: "gas", Type: "string"},
   139  		},
   140  		"Coin": {
   141  			{Name: "denom", Type: "string"},
   142  			{Name: "amount", Type: "string"},
   143  		},
   144  		"Msg": {
   145  			{Name: "type", Type: "string"},
   146  			{Name: "value", Type: msgTypeName},
   147  		},
   148  		msgTypeName: {},
   149  	}
   150  
   151  	err := walkFields(cdc, rootTypes, msgTypeName, msg)
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  
   156  	return rootTypes, nil
   157  }
   158  
   159  const typeDefPrefix = "_"
   160  
   161  func walkFields(cdc codec.ProtoCodecMarshaler, typeMap typeddata.Types, rootType string, in interface{}) (err error) {
   162  	defer doRecover(&err)
   163  
   164  	t := reflect.TypeOf(in)
   165  	v := reflect.ValueOf(in)
   166  
   167  	for {
   168  		if t.Kind() == reflect.Ptr || t.Kind() == reflect.Interface {
   169  			t = t.Elem()
   170  			v = v.Elem()
   171  			continue
   172  		}
   173  		break
   174  	}
   175  
   176  	err = traverseFields(cdc, typeMap, rootType, typeDefPrefix, t, v)
   177  	return
   178  }
   179  
   180  type cosmosAnyWrapper struct {
   181  	Type  string      `json:"type"`
   182  	Value interface{} `json:"value"`
   183  }
   184  
   185  func traverseFields(
   186  	cdc codec.ProtoCodecMarshaler,
   187  	typeMap typeddata.Types,
   188  	rootType string,
   189  	prefix string,
   190  	t reflect.Type,
   191  	v reflect.Value,
   192  ) (err error) {
   193  	n := t.NumField()
   194  
   195  	if prefix == typeDefPrefix {
   196  		if len(typeMap[rootType]) == n {
   197  			return nil
   198  		}
   199  	} else {
   200  		typeDef := sanitizeTypedef(prefix)
   201  		if len(typeMap[typeDef]) == n {
   202  			return nil
   203  		}
   204  	}
   205  
   206  	for i := 0; i < n; i++ {
   207  		var field reflect.Value
   208  		if v.IsValid() {
   209  			field = v.Field(i)
   210  		}
   211  
   212  		fieldType := t.Field(i).Type
   213  		fieldName := jsonNameFromTag(t.Field(i).Tag)
   214  		var isCollection bool
   215  		if fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice {
   216  			if field.Len() == 0 {
   217  				// skip empty collections from type mapping
   218  				continue
   219  			}
   220  
   221  			fieldType = fieldType.Elem()
   222  			field = field.Index(0)
   223  			isCollection = true
   224  		}
   225  
   226  		if fieldType == cosmosAnyType {
   227  			any := field.Interface().(*codectypes.Any)
   228  			anyWrapper := &cosmosAnyWrapper{
   229  				Type: any.TypeUrl,
   230  			}
   231  
   232  			err = cdc.UnpackAny(any, &anyWrapper.Value)
   233  			if err != nil {
   234  				err = errors.Wrap(err, "failed to unpack Any in msg struct")
   235  				return err
   236  			}
   237  
   238  			fieldType = reflect.TypeOf(anyWrapper)
   239  			field = reflect.ValueOf(anyWrapper)
   240  			// then continue as normal
   241  		}
   242  
   243  		for {
   244  			if fieldType.Kind() == reflect.Ptr {
   245  				fieldType = fieldType.Elem()
   246  
   247  				if field.IsValid() {
   248  					field = field.Elem()
   249  				}
   250  
   251  				continue
   252  			}
   253  
   254  			if fieldType.Kind() == reflect.Interface {
   255  				fieldType = reflect.TypeOf(field.Interface())
   256  				continue
   257  			}
   258  
   259  			if field.Kind() == reflect.Ptr {
   260  				field = field.Elem()
   261  				continue
   262  			}
   263  
   264  			break
   265  		}
   266  
   267  		for {
   268  			if fieldType.Kind() == reflect.Ptr {
   269  				fieldType = fieldType.Elem()
   270  
   271  				if field.IsValid() {
   272  					field = field.Elem()
   273  				}
   274  
   275  				continue
   276  			}
   277  
   278  			if fieldType.Kind() == reflect.Interface {
   279  				fieldType = reflect.TypeOf(field.Interface())
   280  				continue
   281  			}
   282  
   283  			if field.Kind() == reflect.Ptr {
   284  				field = field.Elem()
   285  				continue
   286  			}
   287  
   288  			break
   289  		}
   290  
   291  		fieldPrefix := fmt.Sprintf("%s.%s", prefix, fieldName)
   292  		ethTyp := typToEth(fieldType)
   293  		if ethTyp != "" {
   294  			if isCollection {
   295  				ethTyp += "[]"
   296  			}
   297  			if field.Kind() == reflect.String && field.Len() == 0 {
   298  				// skip empty strings from type mapping
   299  				continue
   300  			}
   301  			if prefix == typeDefPrefix {
   302  				typeMap[rootType] = append(typeMap[rootType], typeddata.Type{
   303  					Name: fieldName,
   304  					Type: ethTyp,
   305  				})
   306  			} else {
   307  				typeDef := sanitizeTypedef(prefix)
   308  				typeMap[typeDef] = append(typeMap[typeDef], typeddata.Type{
   309  					Name: fieldName,
   310  					Type: ethTyp,
   311  				})
   312  			}
   313  
   314  			continue
   315  		}
   316  
   317  		if fieldType.Kind() == reflect.Struct {
   318  			var fieldTypedef string
   319  			if isCollection {
   320  				fieldTypedef = sanitizeTypedef(fieldPrefix) + "[]"
   321  			} else {
   322  				fieldTypedef = sanitizeTypedef(fieldPrefix)
   323  			}
   324  
   325  			if prefix == typeDefPrefix {
   326  				typeMap[rootType] = append(typeMap[rootType], typeddata.Type{
   327  					Name: fieldName,
   328  					Type: fieldTypedef,
   329  				})
   330  			} else {
   331  				typeDef := sanitizeTypedef(prefix)
   332  				typeMap[typeDef] = append(typeMap[typeDef], typeddata.Type{
   333  					Name: fieldName,
   334  					Type: fieldTypedef,
   335  				})
   336  			}
   337  
   338  			err = traverseFields(cdc, typeMap, rootType, fieldPrefix, fieldType, field)
   339  			if err != nil {
   340  				return err
   341  			}
   342  
   343  			continue
   344  		}
   345  	}
   346  
   347  	return nil
   348  }
   349  
   350  func jsonNameFromTag(tag reflect.StructTag) string {
   351  	jsonTags := tag.Get("json")
   352  	parts := strings.Split(jsonTags, ",")
   353  	return parts[0]
   354  }
   355  
   356  // _.foo_bar.baz -> TypeFooBarBaz
   357  // this is needed for Geth's own signing code which doesn't
   358  // tolerate complex type names
   359  func sanitizeTypedef(str string) string {
   360  	buf := new(bytes.Buffer)
   361  	parts := strings.Split(str, ".")
   362  
   363  	for _, part := range parts {
   364  		if part == "_" {
   365  			buf.WriteString("Type")
   366  			continue
   367  		}
   368  
   369  		subparts := strings.Split(part, "_")
   370  		for _, subpart := range subparts {
   371  			buf.WriteString(strings.Title(subpart)) //nolint // strings is used for compat
   372  		}
   373  	}
   374  
   375  	return buf.String()
   376  }
   377  
   378  var (
   379  	hashType      = reflect.TypeOf(common.Hash{})
   380  	addressType   = reflect.TypeOf(common.Address{})
   381  	bigIntType    = reflect.TypeOf(big.Int{})
   382  	cosmIntType   = reflect.TypeOf(math.Int{})
   383  	cosmosAnyType = reflect.TypeOf(&codectypes.Any{})
   384  	timeType      = reflect.TypeOf(time.Time{})
   385  )
   386  
   387  // typToEth supports only basic types and arrays of basic types.
   388  // https://github.com/ethereum/EIPs/blob/master/EIPS/eip-712.md
   389  func typToEth(typ reflect.Type) string {
   390  	switch typ.Kind() {
   391  	case reflect.String:
   392  		return "string"
   393  	case reflect.Bool:
   394  		return "bool"
   395  	case reflect.Int:
   396  		return "int64"
   397  	case reflect.Int8:
   398  		return "int8"
   399  	case reflect.Int16:
   400  		return "int16"
   401  	case reflect.Int32:
   402  		return "int32"
   403  	case reflect.Int64:
   404  		return "int64"
   405  	case reflect.Uint:
   406  		return "uint64"
   407  	case reflect.Uint8:
   408  		return "uint8"
   409  	case reflect.Uint16:
   410  		return "uint16"
   411  	case reflect.Uint32:
   412  		return "uint32"
   413  	case reflect.Uint64:
   414  		return "uint64"
   415  	case reflect.Slice:
   416  		ethName := typToEth(typ.Elem())
   417  		if ethName != "" {
   418  			return ethName + "[]"
   419  		}
   420  	case reflect.Array:
   421  		ethName := typToEth(typ.Elem())
   422  		if ethName != "" {
   423  			return ethName + "[]"
   424  		}
   425  	case reflect.Ptr:
   426  		if typ.Elem().ConvertibleTo(bigIntType) ||
   427  			typ.Elem().ConvertibleTo(timeType) ||
   428  			typ.Elem().ConvertibleTo(cosmIntType) {
   429  			return "string"
   430  		}
   431  	case reflect.Struct:
   432  		if typ.ConvertibleTo(hashType) ||
   433  			typ.ConvertibleTo(addressType) ||
   434  			typ.ConvertibleTo(bigIntType) ||
   435  			typ.ConvertibleTo(timeType) ||
   436  			typ.ConvertibleTo(cosmIntType) {
   437  			return "string"
   438  		}
   439  	}
   440  
   441  	return ""
   442  }
   443  
   444  //nolint:gocritic // this is a handy way to return err in defered funcs
   445  func doRecover(err *error) {
   446  	if r := recover(); r != nil {
   447  		debug.PrintStack()
   448  
   449  		if e, ok := r.(error); ok {
   450  			e = errors.Wrap(e, "panicked with error")
   451  			*err = e
   452  			return
   453  		}
   454  
   455  		*err = errors.Errorf("%v", r)
   456  	}
   457  }
   458  
   459  func signableTypes() typeddata.Types {
   460  	return typeddata.Types{
   461  		"EIP712Domain": {
   462  			{
   463  				Name: "name",
   464  				Type: "string",
   465  			},
   466  			{
   467  				Name: "version",
   468  				Type: "string",
   469  			},
   470  			{
   471  				Name: "chainId",
   472  				Type: "uint256",
   473  			},
   474  			{
   475  				Name: "verifyingContract",
   476  				Type: "address",
   477  			},
   478  			{
   479  				Name: "salt",
   480  				Type: "string",
   481  			},
   482  		},
   483  		"Tx": {
   484  			{Name: "context", Type: "string"},
   485  			{Name: "msgs", Type: "string"},
   486  		},
   487  	}
   488  }
   489  
   490  func WrapTxToEIP712V2(
   491  	cdc codec.ProtoCodecMarshaler,
   492  	chainID uint64,
   493  	signerData *authsigning.SignerData,
   494  	timeoutHeight uint64,
   495  	memo string,
   496  	feeInfo legacytx.StdFee,
   497  	msgs []cosmtypes.Msg,
   498  	feeDelegation *FeeDelegationOptions,
   499  ) (typeddata.TypedData, error) {
   500  	domain := typeddata.TypedDataDomain{
   501  		Name:              "Injective Web3",
   502  		Version:           "1.0.0",
   503  		ChainId:           ethmath.NewHexOrDecimal256(int64(chainID)),
   504  		VerifyingContract: "0xCcCCccccCCCCcCCCCCCcCcCccCcCCCcCcccccccC",
   505  		Salt:              "0",
   506  	}
   507  
   508  	msgTypes := signableTypes()
   509  	msgsJsons := make([]json.RawMessage, len(msgs))
   510  	for idx, m := range msgs {
   511  		bzMsg, err := cdc.MarshalInterfaceJSON(m)
   512  		if err != nil {
   513  			return typeddata.TypedData{}, fmt.Errorf("cannot marshal json at index %d: %w", idx, err)
   514  		}
   515  
   516  		msgsJsons[idx] = bzMsg
   517  	}
   518  
   519  	bzMsgs, err := json.Marshal(msgsJsons)
   520  	if err != nil {
   521  		return typeddata.TypedData{}, fmt.Errorf("marshal json err: %w", err)
   522  	}
   523  
   524  	if feeDelegation != nil {
   525  		feeInfo.Payer = feeDelegation.FeePayer.String()
   526  	}
   527  
   528  	bzFee, err := json.Marshal(feeInfo)
   529  	if err != nil {
   530  		return typeddata.TypedData{}, fmt.Errorf("marshal fee info failed: %w", err)
   531  	}
   532  
   533  	context := map[string]interface{}{
   534  		"account_number": signerData.AccountNumber,
   535  		"sequence":       signerData.Sequence,
   536  		"timeout_height": timeoutHeight,
   537  		"chain_id":       signerData.ChainID,
   538  		"memo":           memo,
   539  		"fee":            json.RawMessage(bzFee),
   540  	}
   541  
   542  	bzTxContext, err := json.Marshal(context)
   543  	if err != nil {
   544  		return typeddata.TypedData{}, fmt.Errorf("marshal json err: %w", err)
   545  	}
   546  
   547  	var typedData = typeddata.TypedData{
   548  		Types:       msgTypes,
   549  		PrimaryType: "Tx",
   550  		Domain:      domain,
   551  		Message: typeddata.TypedDataMessage{
   552  			"context": string(bzTxContext),
   553  			"msgs":    string(bzMsgs),
   554  		},
   555  	}
   556  
   557  	return typedData, nil
   558  }