github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/transport/ws/pool.go (about)

     1  /*
     2  Copyright SecureKey Technologies Inc. All Rights Reserved.
     3  SPDX-License-Identifier: Apache-2.0
     4  */
     5  
     6  package ws
     7  
     8  import (
     9  	"context"
    10  	"encoding/json"
    11  	"fmt"
    12  	"strings"
    13  	"sync"
    14  	"time"
    15  
    16  	"nhooyr.io/websocket"
    17  
    18  	cryptoapi "github.com/hyperledger/aries-framework-go/pkg/crypto"
    19  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator"
    20  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/didexchange"
    21  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport"
    22  	"github.com/hyperledger/aries-framework-go/pkg/didcomm/transport/internal"
    23  	"github.com/hyperledger/aries-framework-go/pkg/doc/did"
    24  	"github.com/hyperledger/aries-framework-go/pkg/vdr/fingerprint"
    25  	"github.com/hyperledger/aries-framework-go/pkg/vdr/peer"
    26  )
    27  
    28  const (
    29  	// TODO configure ping request frequency.
    30  	pingFrequency = 30 * time.Second
    31  
    32  	// legacyKeyLen key length.
    33  	legacyKeyLen = 32
    34  )
    35  
    36  type connPool struct {
    37  	connMap map[string]*websocket.Conn
    38  	sync.RWMutex
    39  	packager   transport.Packager
    40  	msgHandler transport.InboundMessageHandler
    41  }
    42  
    43  // nolint: gochecknoglobals
    44  var pool = make(map[string]*connPool)
    45  
    46  func getConnPool(prov transport.Provider) *connPool {
    47  	id := prov.AriesFrameworkID()
    48  
    49  	if _, ok := pool[id]; !ok {
    50  		pool[id] = &connPool{
    51  			connMap:    make(map[string]*websocket.Conn),
    52  			packager:   prov.Packager(),
    53  			msgHandler: prov.InboundMessageHandler(),
    54  		}
    55  	}
    56  
    57  	return pool[id]
    58  }
    59  
    60  func (d *connPool) add(verKey string, wsConn *websocket.Conn) {
    61  	d.Lock()
    62  	defer d.Unlock()
    63  
    64  	d.connMap[verKey] = wsConn
    65  }
    66  
    67  func (d *connPool) fetch(verKey string) *websocket.Conn {
    68  	d.RLock()
    69  	defer d.RUnlock()
    70  
    71  	return d.connMap[verKey]
    72  }
    73  
    74  func (d *connPool) remove(verKey string) {
    75  	d.Lock()
    76  	defer d.Unlock()
    77  
    78  	delete(d.connMap, verKey)
    79  }
    80  
    81  func (d *connPool) listener(conn *websocket.Conn, outbound bool) {
    82  	verKeys := []string{}
    83  
    84  	defer d.close(conn, verKeys)
    85  
    86  	go keepConnAlive(conn, outbound, pingFrequency)
    87  
    88  	for {
    89  		_, message, err := conn.Read(context.Background())
    90  		if err != nil {
    91  			if websocket.CloseStatus(err) != websocket.StatusNormalClosure {
    92  				logger.Errorf("Error reading request message: %v", err)
    93  			}
    94  
    95  			break
    96  		}
    97  
    98  		unpackMsg, err := internal.UnpackMessage(message, d.packager, "ws")
    99  		if err != nil {
   100  			logger.Errorf("%w", err)
   101  
   102  			continue
   103  		}
   104  
   105  		trans := &decorator.Transport{}
   106  
   107  		err = json.Unmarshal(unpackMsg.Message, trans)
   108  		if err != nil {
   109  			logger.Errorf("unmarshal transport decorator : %v", err)
   110  		}
   111  
   112  		d.addKey(unpackMsg, trans, conn)
   113  
   114  		messageHandler := d.msgHandler
   115  
   116  		err = messageHandler(unpackMsg)
   117  		if err != nil {
   118  			logger.Errorf("incoming msg processing failed: %v", err)
   119  		}
   120  	}
   121  }
   122  
   123  func (d *connPool) addKey(unpackMsg *transport.Envelope, trans *decorator.Transport, conn *websocket.Conn) {
   124  	var fromKey string
   125  
   126  	if len(unpackMsg.FromKey) == legacyKeyLen {
   127  		fromKey, _ = fingerprint.CreateDIDKey(unpackMsg.FromKey)
   128  	} else {
   129  		fromPubKey := &cryptoapi.PublicKey{}
   130  
   131  		err := json.Unmarshal(unpackMsg.FromKey, fromPubKey)
   132  		if err != nil {
   133  			logger.Debugf("addKey: unpackMsg.FromKey is not a public key [err: %s]. "+
   134  				"It will not be added to the ws connection.", err)
   135  		} else {
   136  			fromKey = fromPubKey.KID
   137  		}
   138  	}
   139  
   140  	if trans.ReturnRoute != nil && trans.ReturnRoute.Value == decorator.TransportReturnRouteAll {
   141  		if fromKey != "" {
   142  			d.add(fromKey, conn)
   143  		}
   144  
   145  		keyAgreementIDs := checkKeyAgreementIDs(unpackMsg.Message)
   146  
   147  		for _, kaID := range keyAgreementIDs {
   148  			d.add(kaID, conn)
   149  		}
   150  
   151  		if fromKey == "" && len(keyAgreementIDs) == 0 {
   152  			logger.Warnf("addKey: no key is linked to ws connection.")
   153  		}
   154  	}
   155  }
   156  
   157  func (d *connPool) close(conn *websocket.Conn, verKeys []string) {
   158  	if err := conn.Close(websocket.StatusNormalClosure,
   159  		"closing the connection"); websocket.CloseStatus(err) != websocket.StatusNormalClosure {
   160  		logger.Errorf("connection close error")
   161  	}
   162  
   163  	for _, v := range verKeys {
   164  		d.remove(v)
   165  	}
   166  }
   167  
   168  func checkKeyAgreementIDs(message []byte) []string {
   169  	var err1, err2 error
   170  
   171  	var doc *did.Doc
   172  
   173  	doc, err1 = didCommV1PeerDoc(message)
   174  
   175  	if err1 != nil {
   176  		doc, err2 = didCommV2PeerDoc(message)
   177  	}
   178  
   179  	if err1 != nil && err2 != nil {
   180  		logger.Debugf("failed to find a DIDComm DID doc in websocket message, will not add any keyAgreementIDs."+
   181  			" DIDComm V1 parse result=[%s], DIDComm V2 parse result=[%s]", err1.Error(), err2.Error())
   182  
   183  		return nil
   184  	}
   185  
   186  	return docKeyAgreementIDs(doc)
   187  }
   188  
   189  func didCommV1PeerDoc(message []byte) (*did.Doc, error) {
   190  	req := &didexchange.Request{}
   191  
   192  	err := json.Unmarshal(message, req)
   193  	if err != nil {
   194  		return nil, fmt.Errorf("unmarshal request message failed: %w", err)
   195  	}
   196  
   197  	if req.DocAttach == nil {
   198  		return nil, fmt.Errorf("fetch message attachment/attachmentData is empty")
   199  	}
   200  
   201  	data, err := req.DocAttach.Data.Fetch()
   202  	if err != nil {
   203  		return nil, fmt.Errorf("fetch message attachment data failed: %w", err)
   204  	}
   205  
   206  	doc := &did.Doc{}
   207  
   208  	err = json.Unmarshal(data, doc)
   209  	if err != nil {
   210  		return nil, fmt.Errorf("unmarshal DID doc from attachment data failed: %w", err)
   211  	}
   212  
   213  	return doc, nil
   214  }
   215  
   216  type msgFromField struct {
   217  	From string `json:"from"`
   218  }
   219  
   220  func didCommV2PeerDoc(message []byte) (*did.Doc, error) {
   221  	msg := &msgFromField{}
   222  
   223  	err := json.Unmarshal(message, msg)
   224  	if err != nil {
   225  		return nil, fmt.Errorf("unmarshal message as didcomm/v2 failed: %w", err)
   226  	}
   227  
   228  	if msg.From == "" {
   229  		return nil, fmt.Errorf("message has no didcomm/v2 'from' field")
   230  	}
   231  
   232  	didURL, err := did.ParseDIDURL(msg.From)
   233  	if err != nil {
   234  		return nil, fmt.Errorf("'from' field not did url: %w", err)
   235  	}
   236  
   237  	if didURL.Method != "peer" {
   238  		return nil, fmt.Errorf("'from' DID not peer DID")
   239  	}
   240  
   241  	stateQueries := didURL.Queries["initialState"]
   242  	if len(stateQueries) == 0 {
   243  		return nil, fmt.Errorf("peer DID URL has no initialState parameter")
   244  	}
   245  
   246  	doc, err := peer.DocFromGenesisDelta(stateQueries[0])
   247  	if err != nil {
   248  		return nil, fmt.Errorf("failed to parse initialState into DID doc: %w", err)
   249  	}
   250  
   251  	return doc, nil
   252  }
   253  
   254  func docKeyAgreementIDs(doc *did.Doc) []string {
   255  	var keyAgreementIDs []string
   256  
   257  	for _, ka := range doc.KeyAgreement {
   258  		kaID := ka.VerificationMethod.ID
   259  		if strings.HasPrefix(kaID, "#") {
   260  			kaID = doc.ID + kaID
   261  		}
   262  
   263  		keyAgreementIDs = append(keyAgreementIDs, kaID)
   264  	}
   265  
   266  	return keyAgreementIDs
   267  }