github.com/hyperledger/aries-framework-go@v0.3.2/pkg/didcomm/transport/ws/pool.go (about) 1 /* 2 Copyright SecureKey Technologies Inc. All Rights Reserved. 3 SPDX-License-Identifier: Apache-2.0 4 */ 5 6 package ws 7 8 import ( 9 "context" 10 "encoding/json" 11 "fmt" 12 "strings" 13 "sync" 14 "time" 15 16 "nhooyr.io/websocket" 17 18 cryptoapi "github.com/hyperledger/aries-framework-go/pkg/crypto" 19 "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/decorator" 20 "github.com/hyperledger/aries-framework-go/pkg/didcomm/protocol/didexchange" 21 "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport" 22 "github.com/hyperledger/aries-framework-go/pkg/didcomm/transport/internal" 23 "github.com/hyperledger/aries-framework-go/pkg/doc/did" 24 "github.com/hyperledger/aries-framework-go/pkg/vdr/fingerprint" 25 "github.com/hyperledger/aries-framework-go/pkg/vdr/peer" 26 ) 27 28 const ( 29 // TODO configure ping request frequency. 30 pingFrequency = 30 * time.Second 31 32 // legacyKeyLen key length. 33 legacyKeyLen = 32 34 ) 35 36 type connPool struct { 37 connMap map[string]*websocket.Conn 38 sync.RWMutex 39 packager transport.Packager 40 msgHandler transport.InboundMessageHandler 41 } 42 43 // nolint: gochecknoglobals 44 var pool = make(map[string]*connPool) 45 46 func getConnPool(prov transport.Provider) *connPool { 47 id := prov.AriesFrameworkID() 48 49 if _, ok := pool[id]; !ok { 50 pool[id] = &connPool{ 51 connMap: make(map[string]*websocket.Conn), 52 packager: prov.Packager(), 53 msgHandler: prov.InboundMessageHandler(), 54 } 55 } 56 57 return pool[id] 58 } 59 60 func (d *connPool) add(verKey string, wsConn *websocket.Conn) { 61 d.Lock() 62 defer d.Unlock() 63 64 d.connMap[verKey] = wsConn 65 } 66 67 func (d *connPool) fetch(verKey string) *websocket.Conn { 68 d.RLock() 69 defer d.RUnlock() 70 71 return d.connMap[verKey] 72 } 73 74 func (d *connPool) remove(verKey string) { 75 d.Lock() 76 defer d.Unlock() 77 78 delete(d.connMap, verKey) 79 } 80 81 func (d *connPool) listener(conn *websocket.Conn, outbound bool) { 82 verKeys := []string{} 83 84 defer d.close(conn, verKeys) 85 86 go keepConnAlive(conn, outbound, pingFrequency) 87 88 for { 89 _, message, err := conn.Read(context.Background()) 90 if err != nil { 91 if websocket.CloseStatus(err) != websocket.StatusNormalClosure { 92 logger.Errorf("Error reading request message: %v", err) 93 } 94 95 break 96 } 97 98 unpackMsg, err := internal.UnpackMessage(message, d.packager, "ws") 99 if err != nil { 100 logger.Errorf("%w", err) 101 102 continue 103 } 104 105 trans := &decorator.Transport{} 106 107 err = json.Unmarshal(unpackMsg.Message, trans) 108 if err != nil { 109 logger.Errorf("unmarshal transport decorator : %v", err) 110 } 111 112 d.addKey(unpackMsg, trans, conn) 113 114 messageHandler := d.msgHandler 115 116 err = messageHandler(unpackMsg) 117 if err != nil { 118 logger.Errorf("incoming msg processing failed: %v", err) 119 } 120 } 121 } 122 123 func (d *connPool) addKey(unpackMsg *transport.Envelope, trans *decorator.Transport, conn *websocket.Conn) { 124 var fromKey string 125 126 if len(unpackMsg.FromKey) == legacyKeyLen { 127 fromKey, _ = fingerprint.CreateDIDKey(unpackMsg.FromKey) 128 } else { 129 fromPubKey := &cryptoapi.PublicKey{} 130 131 err := json.Unmarshal(unpackMsg.FromKey, fromPubKey) 132 if err != nil { 133 logger.Debugf("addKey: unpackMsg.FromKey is not a public key [err: %s]. "+ 134 "It will not be added to the ws connection.", err) 135 } else { 136 fromKey = fromPubKey.KID 137 } 138 } 139 140 if trans.ReturnRoute != nil && trans.ReturnRoute.Value == decorator.TransportReturnRouteAll { 141 if fromKey != "" { 142 d.add(fromKey, conn) 143 } 144 145 keyAgreementIDs := checkKeyAgreementIDs(unpackMsg.Message) 146 147 for _, kaID := range keyAgreementIDs { 148 d.add(kaID, conn) 149 } 150 151 if fromKey == "" && len(keyAgreementIDs) == 0 { 152 logger.Warnf("addKey: no key is linked to ws connection.") 153 } 154 } 155 } 156 157 func (d *connPool) close(conn *websocket.Conn, verKeys []string) { 158 if err := conn.Close(websocket.StatusNormalClosure, 159 "closing the connection"); websocket.CloseStatus(err) != websocket.StatusNormalClosure { 160 logger.Errorf("connection close error") 161 } 162 163 for _, v := range verKeys { 164 d.remove(v) 165 } 166 } 167 168 func checkKeyAgreementIDs(message []byte) []string { 169 var err1, err2 error 170 171 var doc *did.Doc 172 173 doc, err1 = didCommV1PeerDoc(message) 174 175 if err1 != nil { 176 doc, err2 = didCommV2PeerDoc(message) 177 } 178 179 if err1 != nil && err2 != nil { 180 logger.Debugf("failed to find a DIDComm DID doc in websocket message, will not add any keyAgreementIDs."+ 181 " DIDComm V1 parse result=[%s], DIDComm V2 parse result=[%s]", err1.Error(), err2.Error()) 182 183 return nil 184 } 185 186 return docKeyAgreementIDs(doc) 187 } 188 189 func didCommV1PeerDoc(message []byte) (*did.Doc, error) { 190 req := &didexchange.Request{} 191 192 err := json.Unmarshal(message, req) 193 if err != nil { 194 return nil, fmt.Errorf("unmarshal request message failed: %w", err) 195 } 196 197 if req.DocAttach == nil { 198 return nil, fmt.Errorf("fetch message attachment/attachmentData is empty") 199 } 200 201 data, err := req.DocAttach.Data.Fetch() 202 if err != nil { 203 return nil, fmt.Errorf("fetch message attachment data failed: %w", err) 204 } 205 206 doc := &did.Doc{} 207 208 err = json.Unmarshal(data, doc) 209 if err != nil { 210 return nil, fmt.Errorf("unmarshal DID doc from attachment data failed: %w", err) 211 } 212 213 return doc, nil 214 } 215 216 type msgFromField struct { 217 From string `json:"from"` 218 } 219 220 func didCommV2PeerDoc(message []byte) (*did.Doc, error) { 221 msg := &msgFromField{} 222 223 err := json.Unmarshal(message, msg) 224 if err != nil { 225 return nil, fmt.Errorf("unmarshal message as didcomm/v2 failed: %w", err) 226 } 227 228 if msg.From == "" { 229 return nil, fmt.Errorf("message has no didcomm/v2 'from' field") 230 } 231 232 didURL, err := did.ParseDIDURL(msg.From) 233 if err != nil { 234 return nil, fmt.Errorf("'from' field not did url: %w", err) 235 } 236 237 if didURL.Method != "peer" { 238 return nil, fmt.Errorf("'from' DID not peer DID") 239 } 240 241 stateQueries := didURL.Queries["initialState"] 242 if len(stateQueries) == 0 { 243 return nil, fmt.Errorf("peer DID URL has no initialState parameter") 244 } 245 246 doc, err := peer.DocFromGenesisDelta(stateQueries[0]) 247 if err != nil { 248 return nil, fmt.Errorf("failed to parse initialState into DID doc: %w", err) 249 } 250 251 return doc, nil 252 } 253 254 func docKeyAgreementIDs(doc *did.Doc) []string { 255 var keyAgreementIDs []string 256 257 for _, ka := range doc.KeyAgreement { 258 kaID := ka.VerificationMethod.ID 259 if strings.HasPrefix(kaID, "#") { 260 kaID = doc.ID + kaID 261 } 262 263 keyAgreementIDs = append(keyAgreementIDs, kaID) 264 } 265 266 return keyAgreementIDs 267 }