github.com/lestrrat-go/jwx/v2@v2.0.21/jws/message.go (about) 1 package jws 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 8 "github.com/lestrrat-go/jwx/v2/internal/base64" 9 "github.com/lestrrat-go/jwx/v2/internal/json" 10 "github.com/lestrrat-go/jwx/v2/internal/pool" 11 "github.com/lestrrat-go/jwx/v2/jwk" 12 ) 13 14 func NewSignature() *Signature { 15 return &Signature{} 16 } 17 18 func (s *Signature) DecodeCtx() DecodeCtx { 19 return s.dc 20 } 21 22 func (s *Signature) SetDecodeCtx(dc DecodeCtx) { 23 s.dc = dc 24 } 25 26 func (s Signature) PublicHeaders() Headers { 27 return s.headers 28 } 29 30 func (s *Signature) SetPublicHeaders(v Headers) *Signature { 31 s.headers = v 32 return s 33 } 34 35 func (s Signature) ProtectedHeaders() Headers { 36 return s.protected 37 } 38 39 func (s *Signature) SetProtectedHeaders(v Headers) *Signature { 40 s.protected = v 41 return s 42 } 43 44 func (s Signature) Signature() []byte { 45 return s.signature 46 } 47 48 func (s *Signature) SetSignature(v []byte) *Signature { 49 s.signature = v 50 return s 51 } 52 53 type signatureUnmarshalProbe struct { 54 Header Headers `json:"header,omitempty"` 55 Protected *string `json:"protected,omitempty"` 56 Signature *string `json:"signature,omitempty"` 57 } 58 59 func (s *Signature) UnmarshalJSON(data []byte) error { 60 var sup signatureUnmarshalProbe 61 sup.Header = NewHeaders() 62 if err := json.Unmarshal(data, &sup); err != nil { 63 return fmt.Errorf(`failed to unmarshal signature into temporary struct: %w`, err) 64 } 65 66 s.headers = sup.Header 67 if buf := sup.Protected; buf != nil { 68 src := []byte(*buf) 69 if !bytes.HasPrefix(src, []byte{'{'}) { 70 decoded, err := base64.Decode(src) 71 if err != nil { 72 return fmt.Errorf(`failed to base64 decode protected headers: %w`, err) 73 } 74 src = decoded 75 } 76 77 prt := NewHeaders() 78 //nolint:forcetypeassert 79 prt.(*stdHeaders).SetDecodeCtx(s.DecodeCtx()) 80 if err := json.Unmarshal(src, prt); err != nil { 81 return fmt.Errorf(`failed to unmarshal protected headers: %w`, err) 82 } 83 //nolint:forcetypeassert 84 prt.(*stdHeaders).SetDecodeCtx(nil) 85 s.protected = prt 86 } 87 88 if sup.Signature != nil { 89 decoded, err := base64.DecodeString(*sup.Signature) 90 if err != nil { 91 return fmt.Errorf(`failed to base decode signature: %w`, err) 92 } 93 s.signature = decoded 94 } 95 return nil 96 } 97 98 // Sign populates the signature field, with a signature generated by 99 // given the signer object and payload. 100 // 101 // The first return value is the raw signature in binary format. 102 // The second return value s the full three-segment signature 103 // (e.g. "eyXXXX.XXXXX.XXXX") 104 func (s *Signature) Sign(payload []byte, signer Signer, key interface{}) ([]byte, []byte, error) { 105 ctx, cancel := context.WithCancel(context.Background()) 106 defer cancel() 107 108 hdrs, err := mergeHeaders(ctx, s.headers, s.protected) 109 if err != nil { 110 return nil, nil, fmt.Errorf(`failed to merge headers: %w`, err) 111 } 112 113 if err := hdrs.Set(AlgorithmKey, signer.Algorithm()); err != nil { 114 return nil, nil, fmt.Errorf(`failed to set "alg": %w`, err) 115 } 116 117 // If the key is a jwk.Key instance, obtain the raw key 118 if jwkKey, ok := key.(jwk.Key); ok { 119 // If we have a key ID specified by this jwk.Key, use that in the header 120 if kid := jwkKey.KeyID(); kid != "" { 121 if err := hdrs.Set(jwk.KeyIDKey, kid); err != nil { 122 return nil, nil, fmt.Errorf(`set key ID from jwk.Key: %w`, err) 123 } 124 } 125 } 126 hdrbuf, err := json.Marshal(hdrs) 127 if err != nil { 128 return nil, nil, fmt.Errorf(`failed to marshal headers: %w`, err) 129 } 130 131 buf := pool.GetBytesBuffer() 132 defer pool.ReleaseBytesBuffer(buf) 133 134 buf.WriteString(base64.EncodeToString(hdrbuf)) 135 buf.WriteByte('.') 136 137 var plen int 138 b64 := getB64Value(hdrs) 139 if b64 { 140 encoded := base64.EncodeToString(payload) 141 plen = len(encoded) 142 buf.WriteString(encoded) 143 } else { 144 if !s.detached { 145 if bytes.Contains(payload, []byte{'.'}) { 146 return nil, nil, fmt.Errorf(`payload must not contain a "."`) 147 } 148 } 149 plen = len(payload) 150 buf.Write(payload) 151 } 152 153 signature, err := signer.Sign(buf.Bytes(), key) 154 if err != nil { 155 return nil, nil, fmt.Errorf(`failed to sign payload: %w`, err) 156 } 157 s.signature = signature 158 159 // Detached payload, this should be removed from the end result 160 if s.detached { 161 buf.Truncate(buf.Len() - plen) 162 } 163 164 buf.WriteByte('.') 165 buf.WriteString(base64.EncodeToString(signature)) 166 ret := make([]byte, buf.Len()) 167 copy(ret, buf.Bytes()) 168 169 return signature, ret, nil 170 } 171 172 func NewMessage() *Message { 173 return &Message{} 174 } 175 176 // Clears the internal raw buffer that was accumulated during 177 // the verify phase 178 func (m *Message) clearRaw() { 179 for _, sig := range m.signatures { 180 if protected := sig.protected; protected != nil { 181 if cr, ok := protected.(*stdHeaders); ok { 182 cr.raw = nil 183 } 184 } 185 } 186 } 187 188 func (m *Message) SetDecodeCtx(dc DecodeCtx) { 189 m.dc = dc 190 } 191 192 func (m *Message) DecodeCtx() DecodeCtx { 193 return m.dc 194 } 195 196 // Payload returns the decoded payload 197 func (m Message) Payload() []byte { 198 return m.payload 199 } 200 201 func (m *Message) SetPayload(v []byte) *Message { 202 m.payload = v 203 return m 204 } 205 206 func (m Message) Signatures() []*Signature { 207 return m.signatures 208 } 209 210 func (m *Message) AppendSignature(v *Signature) *Message { 211 m.signatures = append(m.signatures, v) 212 return m 213 } 214 215 func (m *Message) ClearSignatures() *Message { 216 m.signatures = nil 217 return m 218 } 219 220 // LookupSignature looks up a particular signature entry using 221 // the `kid` value 222 func (m Message) LookupSignature(kid string) []*Signature { 223 var sigs []*Signature 224 for _, sig := range m.signatures { 225 if hdr := sig.PublicHeaders(); hdr != nil { 226 hdrKeyID := hdr.KeyID() 227 if hdrKeyID == kid { 228 sigs = append(sigs, sig) 229 continue 230 } 231 } 232 233 if hdr := sig.ProtectedHeaders(); hdr != nil { 234 hdrKeyID := hdr.KeyID() 235 if hdrKeyID == kid { 236 sigs = append(sigs, sig) 237 continue 238 } 239 } 240 } 241 return sigs 242 } 243 244 // This struct is used to first probe for the structure of the 245 // incoming JSON object. We then decide how to parse it 246 // from the fields that are populated. 247 type messageUnmarshalProbe struct { 248 Payload *string `json:"payload"` 249 Signatures []json.RawMessage `json:"signatures,omitempty"` 250 Header Headers `json:"header,omitempty"` 251 Protected *string `json:"protected,omitempty"` 252 Signature *string `json:"signature,omitempty"` 253 } 254 255 func (m *Message) UnmarshalJSON(buf []byte) error { 256 m.payload = nil 257 m.signatures = nil 258 m.b64 = true 259 260 var mup messageUnmarshalProbe 261 mup.Header = NewHeaders() 262 if err := json.Unmarshal(buf, &mup); err != nil { 263 return fmt.Errorf(`failed to unmarshal into temporary structure: %w`, err) 264 } 265 266 b64 := true 267 if mup.Signature == nil { // flattened signature is NOT present 268 if len(mup.Signatures) == 0 { 269 return fmt.Errorf(`required field "signatures" not present`) 270 } 271 272 m.signatures = make([]*Signature, 0, len(mup.Signatures)) 273 for i, rawsig := range mup.Signatures { 274 var sig Signature 275 sig.SetDecodeCtx(m.DecodeCtx()) 276 if err := json.Unmarshal(rawsig, &sig); err != nil { 277 return fmt.Errorf(`failed to unmarshal signature #%d: %w`, i+1, err) 278 } 279 sig.SetDecodeCtx(nil) 280 281 if sig.protected == nil { 282 // Instead of barfing on a nil protected header, use an empty header 283 sig.protected = NewHeaders() 284 } 285 286 if i == 0 { 287 if !getB64Value(sig.protected) { 288 b64 = false 289 } 290 } else { 291 if b64 != getB64Value(sig.protected) { 292 return fmt.Errorf(`b64 value must be the same for all signatures`) 293 } 294 } 295 296 m.signatures = append(m.signatures, &sig) 297 } 298 } else { // .signature is present, it's a flattened structure 299 if len(mup.Signatures) != 0 { 300 return fmt.Errorf(`invalid format ("signatures" and "signature" keys cannot both be present)`) 301 } 302 303 var sig Signature 304 sig.headers = mup.Header 305 if src := mup.Protected; src != nil { 306 decoded, err := base64.DecodeString(*src) 307 if err != nil { 308 return fmt.Errorf(`failed to base64 decode flattened protected headers: %w`, err) 309 } 310 prt := NewHeaders() 311 //nolint:forcetypeassert 312 prt.(*stdHeaders).SetDecodeCtx(m.DecodeCtx()) 313 if err := json.Unmarshal(decoded, prt); err != nil { 314 return fmt.Errorf(`failed to unmarshal flattened protected headers: %w`, err) 315 } 316 //nolint:forcetypeassert 317 prt.(*stdHeaders).SetDecodeCtx(nil) 318 sig.protected = prt 319 } 320 321 if sig.protected == nil { 322 // Instead of barfing on a nil protected header, use an empty header 323 sig.protected = NewHeaders() 324 } 325 326 decoded, err := base64.DecodeString(*mup.Signature) 327 if err != nil { 328 return fmt.Errorf(`failed to base64 decode flattened signature: %w`, err) 329 } 330 sig.signature = decoded 331 332 m.signatures = []*Signature{&sig} 333 b64 = getB64Value(sig.protected) 334 } 335 336 if mup.Payload != nil { 337 if !b64 { // NOT base64 encoded 338 m.payload = []byte(*mup.Payload) 339 } else { 340 decoded, err := base64.DecodeString(*mup.Payload) 341 if err != nil { 342 return fmt.Errorf(`failed to base64 decode payload: %w`, err) 343 } 344 m.payload = decoded 345 } 346 } 347 m.b64 = b64 348 return nil 349 } 350 351 func (m Message) MarshalJSON() ([]byte, error) { 352 if len(m.signatures) == 1 { 353 return m.marshalFlattened() 354 } 355 return m.marshalFull() 356 } 357 358 func (m Message) marshalFlattened() ([]byte, error) { 359 buf := pool.GetBytesBuffer() 360 defer pool.ReleaseBytesBuffer(buf) 361 362 sig := m.signatures[0] 363 364 buf.WriteRune('{') 365 var wrote bool 366 367 if hdr := sig.headers; hdr != nil { 368 hdrjs, err := hdr.MarshalJSON() 369 if err != nil { 370 return nil, fmt.Errorf(`failed to marshal "header" (flattened format): %w`, err) 371 } 372 buf.WriteString(`"header":`) 373 buf.Write(hdrjs) 374 wrote = true 375 } 376 377 if wrote { 378 buf.WriteRune(',') 379 } 380 buf.WriteString(`"payload":"`) 381 buf.WriteString(base64.EncodeToString(m.payload)) 382 buf.WriteRune('"') 383 384 if protected := sig.protected; protected != nil { 385 protectedbuf, err := protected.MarshalJSON() 386 if err != nil { 387 return nil, fmt.Errorf(`failed to marshal "protected" (flattened format): %w`, err) 388 } 389 buf.WriteString(`,"protected":"`) 390 buf.WriteString(base64.EncodeToString(protectedbuf)) 391 buf.WriteRune('"') 392 } 393 394 buf.WriteString(`,"signature":"`) 395 buf.WriteString(base64.EncodeToString(sig.signature)) 396 buf.WriteRune('"') 397 buf.WriteRune('}') 398 399 ret := make([]byte, buf.Len()) 400 copy(ret, buf.Bytes()) 401 return ret, nil 402 } 403 404 func (m Message) marshalFull() ([]byte, error) { 405 buf := pool.GetBytesBuffer() 406 defer pool.ReleaseBytesBuffer(buf) 407 408 buf.WriteString(`{"payload":"`) 409 buf.WriteString(base64.EncodeToString(m.payload)) 410 buf.WriteString(`","signatures":[`) 411 for i, sig := range m.signatures { 412 if i > 0 { 413 buf.WriteRune(',') 414 } 415 416 buf.WriteRune('{') 417 var wrote bool 418 if hdr := sig.headers; hdr != nil { 419 hdrbuf, err := hdr.MarshalJSON() 420 if err != nil { 421 return nil, fmt.Errorf(`failed to marshal "header" for signature #%d: %w`, i+1, err) 422 } 423 buf.WriteString(`"header":`) 424 buf.Write(hdrbuf) 425 wrote = true 426 } 427 428 if protected := sig.protected; protected != nil { 429 protectedbuf, err := protected.MarshalJSON() 430 if err != nil { 431 return nil, fmt.Errorf(`failed to marshal "protected" for signature #%d: %w`, i+1, err) 432 } 433 if wrote { 434 buf.WriteRune(',') 435 } 436 buf.WriteString(`"protected":"`) 437 buf.WriteString(base64.EncodeToString(protectedbuf)) 438 buf.WriteRune('"') 439 wrote = true 440 } 441 442 if len(sig.signature) > 0 { 443 // If InsecureNoSignature is enabled, signature may not exist 444 if wrote { 445 buf.WriteRune(',') 446 } 447 buf.WriteString(`"signature":"`) 448 buf.WriteString(base64.EncodeToString(sig.signature)) 449 buf.WriteString(`"`) 450 } 451 buf.WriteString(`}`) 452 } 453 buf.WriteString(`]}`) 454 455 ret := make([]byte, buf.Len()) 456 copy(ret, buf.Bytes()) 457 return ret, nil 458 } 459 460 // Compact generates a JWS message in compact serialization format from 461 // `*jws.Message` object. The object contain exactly one signature, or 462 // an error is returned. 463 // 464 // If using a detached payload, the payload must already be stored in 465 // the `*jws.Message` object, and the `jws.WithDetached()` option 466 // must be passed to the function. 467 func Compact(msg *Message, options ...CompactOption) ([]byte, error) { 468 if l := len(msg.signatures); l != 1 { 469 return nil, fmt.Errorf(`jws.Compact: cannot serialize message with %d signatures (must be one)`, l) 470 } 471 472 var detached bool 473 for _, option := range options { 474 //nolint:forcetypeassert 475 switch option.Ident() { 476 case identDetached{}: 477 detached = option.Value().(bool) 478 } 479 } 480 481 s := msg.signatures[0] 482 // XXX check if this is correct 483 hdrs := s.ProtectedHeaders() 484 485 hdrbuf, err := json.Marshal(hdrs) 486 if err != nil { 487 return nil, fmt.Errorf(`jws.Compress: failed to marshal headers: %w`, err) 488 } 489 490 buf := pool.GetBytesBuffer() 491 defer pool.ReleaseBytesBuffer(buf) 492 493 buf.WriteString(base64.EncodeToString(hdrbuf)) 494 buf.WriteByte('.') 495 496 if !detached { 497 if getB64Value(hdrs) { 498 encoded := base64.EncodeToString(msg.payload) 499 buf.WriteString(encoded) 500 } else { 501 if bytes.Contains(msg.payload, []byte{'.'}) { 502 return nil, fmt.Errorf(`jws.Compress: payload must not contain a "."`) 503 } 504 buf.Write(msg.payload) 505 } 506 } 507 508 buf.WriteByte('.') 509 buf.WriteString(base64.EncodeToString(s.signature)) 510 ret := make([]byte, buf.Len()) 511 copy(ret, buf.Bytes()) 512 return ret, nil 513 }