github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/common/service/did_comm_msg.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  SPDX-License-Identifier: Apache-2.0
     4  */
     5  
     6  package service
     7  
     8  import (
     9  	"encoding/base64"
    10  	"encoding/json"
    11  	"errors"
    12  	"fmt"
    13  	"reflect"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/google/uuid"
    18  	"github.com/mitchellh/mapstructure"
    19  
    20  	"github.com/hyperledger/aries-framework-go/pkg/doc/did"
    21  )
    22  
    23  const (
    24  	jsonIDV1           = "@id"
    25  	jsonIDV2           = "id"
    26  	jsonTypeV1         = "@type"
    27  	jsonTypeV2         = "type"
    28  	jsonThread         = "~thread"
    29  	jsonThreadID       = "thid"
    30  	jsonParentThreadID = "pthid"
    31  	jsonMetadata       = "_internal_metadata"
    32  
    33  	basePIURI = "https://didcomm.org/"
    34  	oldPIURI  = "did:sov:BzCbsNYhMrjHiqZDTUASHg;spec/"
    35  )
    36  
    37  // Version represents DIDComm protocol version.
    38  type Version string
    39  
    40  // DIDComm versions.
    41  const (
    42  	V1 Version = "v1"
    43  	V2 Version = "v2"
    44  )
    45  
    46  // Metadata may contain additional payload for the protocol. It might be populated by the client/protocol
    47  // for outbound messages. If metadata were populated, the messenger will automatically add it to the incoming
    48  // messages by the threadID. If Metadata is <nil> in the outbound message the previous payload
    49  // will be added to the incoming message. Otherwise, the payload will be rewritten.
    50  // NOTE: Metadata is not a part of the JSON message. The payload will not be sent to another agent.
    51  // Metadata should be used by embedding it to the model structure. e.g
    52  // 	type A struct {
    53  // 		Metadata `json:",squash"`
    54  // 	}
    55  type Metadata struct {
    56  	Payload map[string]interface{} `json:"_internal_metadata,omitempty"`
    57  }
    58  
    59  // DIDCommMsgMap did comm msg.
    60  type DIDCommMsgMap map[string]interface{}
    61  
    62  // UnmarshalJSON implements the json.Unmarshaler interface.
    63  func (m *DIDCommMsgMap) UnmarshalJSON(b []byte) error {
    64  	defer func() {
    65  		if (*m) != nil {
    66  			// sets empty metadata
    67  			(*m)[jsonMetadata] = map[string]interface{}{}
    68  		}
    69  	}()
    70  
    71  	return json.Unmarshal(b, (*map[string]interface{})(m))
    72  }
    73  
    74  // MarshalJSON implements the json.Marshaler interface.
    75  func (m DIDCommMsgMap) MarshalJSON() ([]byte, error) {
    76  	if m != nil {
    77  		metadata := m[jsonMetadata]
    78  		delete(m, jsonMetadata)
    79  
    80  		defer func() { m[jsonMetadata] = metadata }()
    81  	}
    82  
    83  	return json.Marshal(map[string]interface{}(m))
    84  }
    85  
    86  // ParseDIDCommMsgMap returns DIDCommMsg with Header.
    87  func ParseDIDCommMsgMap(payload []byte) (DIDCommMsgMap, error) {
    88  	var msg DIDCommMsgMap
    89  
    90  	err := json.Unmarshal(payload, &msg)
    91  	if err != nil {
    92  		return nil, fmt.Errorf("invalid payload data format: %w", err)
    93  	}
    94  
    95  	// Interop: accept old PIURI when it's used, as we handle backwards-compatibility at a more fine-grained level.
    96  	_, ok := msg[jsonTypeV1]
    97  	if typ := msg.Type(); typ != "" && ok {
    98  		msg[jsonTypeV1] = strings.Replace(typ, oldPIURI, basePIURI, 1)
    99  	}
   100  
   101  	return msg, nil
   102  }
   103  
   104  // IsDIDCommV2 returns true iff the message is a DIDComm/v2 message, false iff the message is a DIDComm/v1 message,
   105  // and an error if neither case applies.
   106  func IsDIDCommV2(msg *DIDCommMsgMap) (bool, error) {
   107  	_, hasIDV2 := (*msg)["id"]
   108  	_, hasTypeV2 := (*msg)["type"]
   109  	// TODO: some present-proof v3 messages forget to include the body, enable the hasBodyV2 check when that is fixed.
   110  	// TODO: see issue: https://github.com/hyperledger/aries-framework-go/issues/3039
   111  	// _, hasBodyV2 := (*msg)["body"]
   112  
   113  	if hasIDV2 || hasTypeV2 /* && hasBodyV2 */ {
   114  		return true, nil
   115  	}
   116  
   117  	_, hasIDV1 := (*msg)["@id"]
   118  	_, hasTypeV1 := (*msg)["@type"]
   119  
   120  	if hasIDV1 || hasTypeV1 {
   121  		return false, nil
   122  	}
   123  
   124  	return false, fmt.Errorf("not a valid didcomm v1 or v2 message")
   125  }
   126  
   127  // NewDIDCommMsgMap converts structure(model) to DIDCommMsgMap.
   128  func NewDIDCommMsgMap(v interface{}) DIDCommMsgMap {
   129  	// NOTE: do not try to replace it with mapstructure pkg
   130  	// it doesn't work as expected, at least time.Time won't be converted
   131  	msg := toMap(v)
   132  
   133  	// sets empty metadata
   134  	msg[jsonMetadata] = map[string]interface{}{}
   135  
   136  	_, hasIDV1 := msg["@id"]
   137  	_, hasTypeV1 := msg["@type"]
   138  	_, hasIDV2 := msg["id"]
   139  	_, hasTypeV2 := msg["type"]
   140  
   141  	if hasIDV1 || hasIDV2 {
   142  		return msg
   143  	}
   144  
   145  	if hasTypeV2 && !hasIDV2 {
   146  		msg["id"] = uuid.New().String()
   147  	} else if hasTypeV1 && !hasIDV1 {
   148  		msg["@id"] = uuid.New().String()
   149  	}
   150  
   151  	return msg
   152  }
   153  
   154  // ThreadID returns msg ~thread.thid if there is no ~thread.thid returns msg @id
   155  // message is invalid if ~thread.thid exist and @id is absent.
   156  func (m DIDCommMsgMap) ThreadID() (string, error) {
   157  	if m == nil {
   158  		return "", ErrInvalidMessage
   159  	}
   160  
   161  	thid, err := m.threadIDV1()
   162  	if err == nil || !errors.Is(err, ErrThreadIDNotFound) {
   163  		return thid, err
   164  	}
   165  
   166  	return m.threadIDV2()
   167  }
   168  
   169  func (m DIDCommMsgMap) threadIDV2() (string, error) {
   170  	id := m.idV2()
   171  
   172  	threadID, ok := m[jsonThreadID].(string)
   173  	if ok && threadID != "" {
   174  		if id == "" {
   175  			return "", ErrInvalidMessage
   176  		}
   177  
   178  		return threadID, nil
   179  	}
   180  
   181  	if id != "" {
   182  		return id, nil
   183  	}
   184  
   185  	return "", ErrThreadIDNotFound
   186  }
   187  
   188  func (m DIDCommMsgMap) threadIDV1() (string, error) {
   189  	msgID := m.idV1()
   190  	thread, ok := m[jsonThread].(map[string]interface{})
   191  
   192  	if ok && thread[jsonThreadID] != nil {
   193  		var thID string
   194  		if v, ok := thread[jsonThreadID].(string); ok {
   195  			thID = v
   196  		}
   197  
   198  		// if message has ~thread.thid but @id is absent this is invalid message
   199  		if len(thID) > 0 && msgID == "" {
   200  			return "", ErrInvalidMessage
   201  		}
   202  
   203  		if len(thID) > 0 {
   204  			return thID, nil
   205  		}
   206  	}
   207  
   208  	// we need to return it only if there is no ~thread.thid
   209  	if len(msgID) > 0 {
   210  		return msgID, nil
   211  	}
   212  
   213  	return "", ErrThreadIDNotFound
   214  }
   215  
   216  // Metadata returns message metadata.
   217  func (m DIDCommMsgMap) Metadata() map[string]interface{} {
   218  	if m[jsonMetadata] == nil {
   219  		return nil
   220  	}
   221  
   222  	metadata, ok := m[jsonMetadata].(map[string]interface{})
   223  	if !ok {
   224  		return nil
   225  	}
   226  
   227  	return metadata
   228  }
   229  
   230  func (m DIDCommMsgMap) typeV1() string {
   231  	if m == nil || m[jsonTypeV1] == nil {
   232  		return ""
   233  	}
   234  
   235  	res, ok := m[jsonTypeV1].(string)
   236  	if !ok {
   237  		return ""
   238  	}
   239  
   240  	return res
   241  }
   242  
   243  func (m DIDCommMsgMap) typeV2() string {
   244  	if m == nil || m[jsonTypeV2] == nil {
   245  		return ""
   246  	}
   247  
   248  	res, ok := m[jsonTypeV2].(string)
   249  	if !ok {
   250  		return ""
   251  	}
   252  
   253  	return res
   254  }
   255  
   256  // Type returns the message type.
   257  func (m DIDCommMsgMap) Type() string {
   258  	if val := m.typeV1(); val != "" {
   259  		return val
   260  	}
   261  
   262  	return m.typeV2()
   263  }
   264  
   265  // ParentThreadID returns the message parent threadID.
   266  func (m DIDCommMsgMap) ParentThreadID() string {
   267  	if m == nil {
   268  		return ""
   269  	}
   270  
   271  	parentThreadID, ok := m[jsonParentThreadID].(string)
   272  	if ok && parentThreadID != "" {
   273  		return parentThreadID
   274  	}
   275  
   276  	if m[jsonThread] == nil {
   277  		return ""
   278  	}
   279  
   280  	if thread, ok := m[jsonThread].(map[string]interface{}); ok && thread != nil {
   281  		if pthID, ok := thread[jsonParentThreadID].(string); ok && pthID != "" {
   282  			return pthID
   283  		}
   284  	}
   285  
   286  	return ""
   287  }
   288  
   289  func (m DIDCommMsgMap) idV1() string {
   290  	if m == nil || m[jsonIDV1] == nil {
   291  		return ""
   292  	}
   293  
   294  	res, ok := m[jsonIDV1].(string)
   295  	if !ok {
   296  		return ""
   297  	}
   298  
   299  	return res
   300  }
   301  
   302  func (m DIDCommMsgMap) idV2() string {
   303  	if m == nil || m[jsonIDV2] == nil {
   304  		return ""
   305  	}
   306  
   307  	res, ok := m[jsonIDV2].(string)
   308  	if !ok {
   309  		return ""
   310  	}
   311  
   312  	return res
   313  }
   314  
   315  // ID returns the message id.
   316  func (m DIDCommMsgMap) ID() string {
   317  	if val := m.idV1(); val != "" {
   318  		return val
   319  	}
   320  
   321  	return m.idV2()
   322  }
   323  
   324  // Opt represents an option.
   325  type Opt func(o *options)
   326  
   327  type options struct {
   328  	V Version
   329  }
   330  
   331  func getOptions(opts ...Opt) *options {
   332  	o := &options{}
   333  
   334  	for i := range opts {
   335  		opts[i](o)
   336  	}
   337  
   338  	if o.V == "" {
   339  		o.V = V1
   340  	}
   341  
   342  	return o
   343  }
   344  
   345  // WithVersion specifies which version to use.
   346  func WithVersion(v Version) Opt {
   347  	return func(o *options) {
   348  		o.V = v
   349  	}
   350  }
   351  
   352  // SetID sets the message id.
   353  func (m DIDCommMsgMap) SetID(id string, opts ...Opt) {
   354  	if m == nil {
   355  		return
   356  	}
   357  
   358  	o := getOptions(opts...)
   359  
   360  	if o.V == V2 {
   361  		m[jsonIDV2] = id
   362  
   363  		return
   364  	}
   365  
   366  	m[jsonIDV1] = id
   367  }
   368  
   369  // SetThread sets the message thread.
   370  func (m DIDCommMsgMap) SetThread(thid, pthid string, opts ...Opt) {
   371  	if m == nil {
   372  		return
   373  	}
   374  
   375  	if thid == "" && pthid == "" {
   376  		return
   377  	}
   378  
   379  	o := getOptions(opts...)
   380  
   381  	if o.V == V2 {
   382  		if thid != "" {
   383  			m[jsonThreadID] = thid
   384  		}
   385  
   386  		if pthid != "" {
   387  			m[jsonParentThreadID] = pthid
   388  		}
   389  
   390  		return
   391  	}
   392  
   393  	thread := map[string]interface{}{}
   394  
   395  	if thid != "" {
   396  		thread[jsonThreadID] = thid
   397  	}
   398  
   399  	if pthid != "" {
   400  		thread[jsonParentThreadID] = pthid
   401  	}
   402  
   403  	m[jsonThread] = thread
   404  }
   405  
   406  // UnsetThread unsets thread.
   407  func (m DIDCommMsgMap) UnsetThread() {
   408  	if m == nil {
   409  		return
   410  	}
   411  
   412  	delete(m, jsonThread)
   413  	delete(m, jsonThreadID)
   414  	delete(m, jsonParentThreadID)
   415  }
   416  
   417  // MsgMapDecoder is implemented by objects that handle their own parsing from DIDCommMsgMap.
   418  type MsgMapDecoder interface {
   419  	FromDIDCommMsgMap(msgMap DIDCommMsgMap) error
   420  }
   421  
   422  // Decode converts message to  struct.
   423  func (m DIDCommMsgMap) Decode(v interface{}) error {
   424  	if dec, ok := v.(MsgMapDecoder); ok {
   425  		return dec.FromDIDCommMsgMap(m)
   426  	}
   427  
   428  	decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
   429  		DecodeHook:       decodeHook,
   430  		WeaklyTypedInput: true,
   431  		Result:           v,
   432  		TagName:          "json",
   433  	})
   434  	if err != nil {
   435  		return err
   436  	}
   437  
   438  	return decoder.Decode(m)
   439  }
   440  
   441  // Clone copies first level keys-values into another map (DIDCommMsgMap).
   442  func (m DIDCommMsgMap) Clone() DIDCommMsgMap {
   443  	if m == nil {
   444  		return nil
   445  	}
   446  
   447  	msg := DIDCommMsgMap{}
   448  	for k, v := range m {
   449  		msg[k] = v
   450  	}
   451  
   452  	return msg
   453  }
   454  
   455  func toMap(v interface{}) map[string]interface{} {
   456  	res := make(map[string]interface{})
   457  
   458  	// if it is pointer returns the value
   459  	rv := reflect.Indirect(reflect.ValueOf(v))
   460  	for rfv, field := range mapValueStructField(rv) {
   461  		// the default name is equal to field Name
   462  		name := field.Name
   463  
   464  		tags := strings.Split(field.Tag.Get(`json`), ",")
   465  		// if tag is not empty name is equal to tag
   466  		if tags[0] != "" {
   467  			name = tags[0]
   468  		}
   469  
   470  		res[name] = convert(rfv)
   471  	}
   472  
   473  	return res
   474  }
   475  
   476  func mapValueStructField(value reflect.Value) map[reflect.Value]reflect.StructField {
   477  	fields := make(map[reflect.Value]reflect.StructField)
   478  	rt := value.Type()
   479  
   480  	for i := 0; i < rt.NumField(); i++ {
   481  		rv, sf := value.Field(i), rt.Field(i)
   482  
   483  		tags := strings.Split(sf.Tag.Get(`json`), ",")
   484  
   485  		// the field should be ignored according to JSON tag `json:"-"`
   486  		if tags[0] == "-" {
   487  			continue
   488  		}
   489  
   490  		// the field should be ignored if it is empty according to JSON tag `json:",omitempty"`
   491  		// NOTE: works when omitempty it the last one
   492  		if tags[len(tags)-1] == "omitempty" {
   493  			if reflect.DeepEqual(reflect.Zero(rv.Type()).Interface(), rv.Interface()) {
   494  				continue
   495  			}
   496  		}
   497  
   498  		// unexported fields should be ignored as well
   499  		if sf.PkgPath != "" {
   500  			continue
   501  		}
   502  
   503  		// if it is an embedded field, we need to add it to the map
   504  		// NOTE: for now, the only embedded structure is supported
   505  		rv = reflect.Indirect(rv)
   506  		if sf.Anonymous && rv.Kind() == reflect.Struct {
   507  			// if an embedded field doesn't have a tag it means the same level
   508  			if tags[0] == "" {
   509  				for k, v := range mapValueStructField(rv) {
   510  					fields[k] = v
   511  				}
   512  
   513  				continue
   514  			}
   515  		}
   516  
   517  		fields[rv] = sf
   518  	}
   519  
   520  	return fields
   521  }
   522  
   523  func convert(val reflect.Value) interface{} {
   524  	switch reflect.Indirect(val).Kind() {
   525  	case reflect.Array, reflect.Slice:
   526  		res := make([]interface{}, val.Len())
   527  		for i := range res {
   528  			res[i] = convert(val.Index(i))
   529  		}
   530  
   531  		return res
   532  	case reflect.Map:
   533  		res := make(map[string]interface{}, val.Len())
   534  		for _, k := range val.MapKeys() {
   535  			res[k.String()] = convert(val.MapIndex(k))
   536  		}
   537  
   538  		return res
   539  	case reflect.Struct:
   540  		if res := toMap(val.Interface()); len(res) != 0 {
   541  			return res
   542  		}
   543  
   544  		return val.Interface()
   545  	}
   546  
   547  	return val.Interface()
   548  }
   549  
   550  func decodeHook(rt1, rt2 reflect.Type, v interface{}) (interface{}, error) {
   551  	if rt1.Kind() == reflect.String {
   552  		if rt2 == reflect.TypeOf(time.Time{}) {
   553  			return time.Parse(time.RFC3339, v.(string))
   554  		}
   555  
   556  		if rt2.Kind() == reflect.Slice && rt2.Elem().Kind() == reflect.Uint8 {
   557  			return base64.StdEncoding.DecodeString(v.(string))
   558  		}
   559  	}
   560  
   561  	if rt1.Kind() == reflect.Map && rt2.Kind() == reflect.Slice && rt2.Elem().Kind() == reflect.Uint8 {
   562  		return json.Marshal(v)
   563  	}
   564  
   565  	if rt2 == reflect.TypeOf(did.Doc{}) {
   566  		didDoc, err := json.Marshal(v)
   567  		if err != nil {
   568  			return nil, fmt.Errorf("error remarshaling to json: %w", err)
   569  		}
   570  
   571  		return did.ParseDocument(didDoc)
   572  	}
   573  
   574  	return v, nil
   575  }