github.com/TrueCloudLab/frostfs-api-go/v2@v2.0.0-20230228134343-196241c4e79a/signature/sign.go (about) 1 package signature 2 3 import ( 4 "crypto/ecdsa" 5 "errors" 6 "fmt" 7 8 "github.com/TrueCloudLab/frostfs-api-go/v2/accounting" 9 "github.com/TrueCloudLab/frostfs-api-go/v2/container" 10 "github.com/TrueCloudLab/frostfs-api-go/v2/netmap" 11 "github.com/TrueCloudLab/frostfs-api-go/v2/object" 12 "github.com/TrueCloudLab/frostfs-api-go/v2/refs" 13 "github.com/TrueCloudLab/frostfs-api-go/v2/reputation" 14 "github.com/TrueCloudLab/frostfs-api-go/v2/session" 15 "github.com/TrueCloudLab/frostfs-api-go/v2/util/signature" 16 ) 17 18 type serviceRequest interface { 19 GetMetaHeader() *session.RequestMetaHeader 20 GetVerificationHeader() *session.RequestVerificationHeader 21 SetVerificationHeader(*session.RequestVerificationHeader) 22 } 23 24 type serviceResponse interface { 25 GetMetaHeader() *session.ResponseMetaHeader 26 GetVerificationHeader() *session.ResponseVerificationHeader 27 SetVerificationHeader(*session.ResponseVerificationHeader) 28 } 29 30 type stableMarshaler interface { 31 StableMarshal([]byte) []byte 32 StableSize() int 33 } 34 35 type StableMarshalerWrapper struct { 36 SM stableMarshaler 37 } 38 39 type metaHeader interface { 40 stableMarshaler 41 getOrigin() metaHeader 42 } 43 44 type verificationHeader interface { 45 stableMarshaler 46 47 GetBodySignature() *refs.Signature 48 SetBodySignature(*refs.Signature) 49 GetMetaSignature() *refs.Signature 50 SetMetaSignature(*refs.Signature) 51 GetOriginSignature() *refs.Signature 52 SetOriginSignature(*refs.Signature) 53 54 setOrigin(stableMarshaler) 55 getOrigin() verificationHeader 56 } 57 58 type requestMetaHeader struct { 59 *session.RequestMetaHeader 60 } 61 62 type responseMetaHeader struct { 63 *session.ResponseMetaHeader 64 } 65 66 type requestVerificationHeader struct { 67 *session.RequestVerificationHeader 68 } 69 70 type responseVerificationHeader struct { 71 *session.ResponseVerificationHeader 72 } 73 74 func (h *requestMetaHeader) getOrigin() metaHeader { 75 return &requestMetaHeader{ 76 RequestMetaHeader: h.GetOrigin(), 77 } 78 } 79 80 func (h *responseMetaHeader) getOrigin() metaHeader { 81 return &responseMetaHeader{ 82 ResponseMetaHeader: h.GetOrigin(), 83 } 84 } 85 86 func (h *requestVerificationHeader) getOrigin() verificationHeader { 87 if origin := h.GetOrigin(); origin != nil { 88 return &requestVerificationHeader{ 89 RequestVerificationHeader: origin, 90 } 91 } 92 93 return nil 94 } 95 96 func (h *requestVerificationHeader) setOrigin(m stableMarshaler) { 97 if m != nil { 98 h.SetOrigin(m.(*session.RequestVerificationHeader)) 99 } 100 } 101 102 func (r *responseVerificationHeader) getOrigin() verificationHeader { 103 if origin := r.GetOrigin(); origin != nil { 104 return &responseVerificationHeader{ 105 ResponseVerificationHeader: origin, 106 } 107 } 108 109 return nil 110 } 111 112 func (r *responseVerificationHeader) setOrigin(m stableMarshaler) { 113 if m != nil { 114 r.SetOrigin(m.(*session.ResponseVerificationHeader)) 115 } 116 } 117 118 func (s StableMarshalerWrapper) ReadSignedData(buf []byte) ([]byte, error) { 119 if s.SM != nil { 120 return s.SM.StableMarshal(buf), nil 121 } 122 123 return nil, nil 124 } 125 126 func (s StableMarshalerWrapper) SignedDataSize() int { 127 if s.SM != nil { 128 return s.SM.StableSize() 129 } 130 131 return 0 132 } 133 134 func SignServiceMessage(key *ecdsa.PrivateKey, msg interface{}) error { 135 var ( 136 body, meta, verifyOrigin stableMarshaler 137 verifyHdr verificationHeader 138 verifyHdrSetter func(verificationHeader) 139 ) 140 141 switch v := msg.(type) { 142 case nil: 143 return nil 144 case serviceRequest: 145 body = serviceMessageBody(v) 146 meta = v.GetMetaHeader() 147 verifyHdr = &requestVerificationHeader{new(session.RequestVerificationHeader)} 148 verifyHdrSetter = func(h verificationHeader) { 149 v.SetVerificationHeader(h.(*requestVerificationHeader).RequestVerificationHeader) 150 } 151 152 if h := v.GetVerificationHeader(); h != nil { 153 verifyOrigin = h 154 } 155 case serviceResponse: 156 body = serviceMessageBody(v) 157 meta = v.GetMetaHeader() 158 verifyHdr = &responseVerificationHeader{new(session.ResponseVerificationHeader)} 159 verifyHdrSetter = func(h verificationHeader) { 160 v.SetVerificationHeader(h.(*responseVerificationHeader).ResponseVerificationHeader) 161 } 162 163 if h := v.GetVerificationHeader(); h != nil { 164 verifyOrigin = h 165 } 166 default: 167 panic(fmt.Sprintf("unsupported session message %T", v)) 168 } 169 170 if verifyOrigin == nil { 171 // sign session message body 172 if err := signServiceMessagePart(key, body, verifyHdr.SetBodySignature); err != nil { 173 return fmt.Errorf("could not sign body: %w", err) 174 } 175 } 176 177 // sign meta header 178 if err := signServiceMessagePart(key, meta, verifyHdr.SetMetaSignature); err != nil { 179 return fmt.Errorf("could not sign meta header: %w", err) 180 } 181 182 // sign verification header origin 183 if err := signServiceMessagePart(key, verifyOrigin, verifyHdr.SetOriginSignature); err != nil { 184 return fmt.Errorf("could not sign origin of verification header: %w", err) 185 } 186 187 // wrap origin verification header 188 verifyHdr.setOrigin(verifyOrigin) 189 190 // update matryoshka verification header 191 verifyHdrSetter(verifyHdr) 192 193 return nil 194 } 195 196 func signServiceMessagePart(key *ecdsa.PrivateKey, part stableMarshaler, sigWrite func(*refs.Signature)) error { 197 var sig *refs.Signature 198 199 // sign part 200 if err := signature.SignDataWithHandler( 201 key, 202 &StableMarshalerWrapper{part}, 203 func(s *refs.Signature) { 204 sig = s 205 }, 206 ); err != nil { 207 return err 208 } 209 210 // write part signature 211 sigWrite(sig) 212 213 return nil 214 } 215 216 func VerifyServiceMessage(msg interface{}) error { 217 var ( 218 meta metaHeader 219 verify verificationHeader 220 ) 221 222 switch v := msg.(type) { 223 case nil: 224 return nil 225 case serviceRequest: 226 meta = &requestMetaHeader{ 227 RequestMetaHeader: v.GetMetaHeader(), 228 } 229 230 verify = &requestVerificationHeader{ 231 RequestVerificationHeader: v.GetVerificationHeader(), 232 } 233 case serviceResponse: 234 meta = &responseMetaHeader{ 235 ResponseMetaHeader: v.GetMetaHeader(), 236 } 237 238 verify = &responseVerificationHeader{ 239 ResponseVerificationHeader: v.GetVerificationHeader(), 240 } 241 default: 242 panic(fmt.Sprintf("unsupported session message %T", v)) 243 } 244 245 body := serviceMessageBody(msg) 246 size := body.StableSize() 247 if sz := meta.StableSize(); sz > size { 248 size = sz 249 } 250 if sz := verify.StableSize(); sz > size { 251 size = sz 252 } 253 254 buf := make([]byte, 0, size) 255 return verifyMatryoshkaLevel(body, meta, verify, buf) 256 } 257 258 func verifyMatryoshkaLevel(body stableMarshaler, meta metaHeader, verify verificationHeader, buf []byte) error { 259 if err := verifyServiceMessagePart(meta, verify.GetMetaSignature, buf); err != nil { 260 return fmt.Errorf("could not verify meta header: %w", err) 261 } 262 263 origin := verify.getOrigin() 264 265 if err := verifyServiceMessagePart(origin, verify.GetOriginSignature, buf); err != nil { 266 return fmt.Errorf("could not verify origin of verification header: %w", err) 267 } 268 269 if origin == nil { 270 if err := verifyServiceMessagePart(body, verify.GetBodySignature, buf); err != nil { 271 return fmt.Errorf("could not verify body: %w", err) 272 } 273 274 return nil 275 } 276 277 if verify.GetBodySignature() != nil { 278 return errors.New("body signature at the matryoshka upper level") 279 } 280 281 return verifyMatryoshkaLevel(body, meta.getOrigin(), origin, buf) 282 } 283 284 func verifyServiceMessagePart(part stableMarshaler, sigRdr func() *refs.Signature, buf []byte) error { 285 return signature.VerifyDataWithSource( 286 &StableMarshalerWrapper{part}, 287 sigRdr, 288 signature.WithBuffer(buf), 289 ) 290 } 291 292 func serviceMessageBody(req interface{}) stableMarshaler { 293 switch v := req.(type) { 294 default: 295 panic(fmt.Sprintf("unsupported session message %T", req)) 296 297 /* Accounting */ 298 case *accounting.BalanceRequest: 299 return v.GetBody() 300 case *accounting.BalanceResponse: 301 return v.GetBody() 302 303 /* Session */ 304 case *session.CreateRequest: 305 return v.GetBody() 306 case *session.CreateResponse: 307 return v.GetBody() 308 309 /* Container */ 310 case *container.PutRequest: 311 return v.GetBody() 312 case *container.PutResponse: 313 return v.GetBody() 314 case *container.DeleteRequest: 315 return v.GetBody() 316 case *container.DeleteResponse: 317 return v.GetBody() 318 case *container.GetRequest: 319 return v.GetBody() 320 case *container.GetResponse: 321 return v.GetBody() 322 case *container.ListRequest: 323 return v.GetBody() 324 case *container.ListResponse: 325 return v.GetBody() 326 case *container.SetExtendedACLRequest: 327 return v.GetBody() 328 case *container.SetExtendedACLResponse: 329 return v.GetBody() 330 case *container.GetExtendedACLRequest: 331 return v.GetBody() 332 case *container.GetExtendedACLResponse: 333 return v.GetBody() 334 case *container.AnnounceUsedSpaceRequest: 335 return v.GetBody() 336 case *container.AnnounceUsedSpaceResponse: 337 return v.GetBody() 338 339 /* Object */ 340 case *object.PutRequest: 341 return v.GetBody() 342 case *object.PutResponse: 343 return v.GetBody() 344 case *object.GetRequest: 345 return v.GetBody() 346 case *object.GetResponse: 347 return v.GetBody() 348 case *object.HeadRequest: 349 return v.GetBody() 350 case *object.HeadResponse: 351 return v.GetBody() 352 case *object.SearchRequest: 353 return v.GetBody() 354 case *object.SearchResponse: 355 return v.GetBody() 356 case *object.DeleteRequest: 357 return v.GetBody() 358 case *object.DeleteResponse: 359 return v.GetBody() 360 case *object.GetRangeRequest: 361 return v.GetBody() 362 case *object.GetRangeResponse: 363 return v.GetBody() 364 case *object.GetRangeHashRequest: 365 return v.GetBody() 366 case *object.GetRangeHashResponse: 367 return v.GetBody() 368 369 /* Netmap */ 370 case *netmap.LocalNodeInfoRequest: 371 return v.GetBody() 372 case *netmap.LocalNodeInfoResponse: 373 return v.GetBody() 374 case *netmap.NetworkInfoRequest: 375 return v.GetBody() 376 case *netmap.NetworkInfoResponse: 377 return v.GetBody() 378 case *netmap.SnapshotRequest: 379 return v.GetBody() 380 case *netmap.SnapshotResponse: 381 return v.GetBody() 382 383 /* Reputation */ 384 case *reputation.AnnounceLocalTrustRequest: 385 return v.GetBody() 386 case *reputation.AnnounceLocalTrustResponse: 387 return v.GetBody() 388 case *reputation.AnnounceIntermediateResultRequest: 389 return v.GetBody() 390 case *reputation.AnnounceIntermediateResultResponse: 391 return v.GetBody() 392 } 393 }