github.com/lestrrat-go/jwx/v2@v2.0.21/jwe/message.go (about) 1 package jwe 2 3 import ( 4 "context" 5 "fmt" 6 "sort" 7 "strings" 8 9 "github.com/lestrrat-go/jwx/v2/internal/base64" 10 "github.com/lestrrat-go/jwx/v2/internal/json" 11 "github.com/lestrrat-go/jwx/v2/internal/pool" 12 ) 13 14 // NewRecipient creates a Recipient object 15 func NewRecipient() Recipient { 16 return &stdRecipient{ 17 headers: NewHeaders(), 18 } 19 } 20 21 func (r *stdRecipient) SetHeaders(h Headers) error { 22 r.headers = h 23 return nil 24 } 25 26 func (r *stdRecipient) SetEncryptedKey(v []byte) error { 27 r.encryptedKey = v 28 return nil 29 } 30 31 func (r *stdRecipient) Headers() Headers { 32 return r.headers 33 } 34 35 func (r *stdRecipient) EncryptedKey() []byte { 36 return r.encryptedKey 37 } 38 39 type recipientMarshalProxy struct { 40 Headers Headers `json:"header"` 41 EncryptedKey string `json:"encrypted_key"` 42 } 43 44 func (r *stdRecipient) UnmarshalJSON(buf []byte) error { 45 var proxy recipientMarshalProxy 46 proxy.Headers = NewHeaders() 47 if err := json.Unmarshal(buf, &proxy); err != nil { 48 return fmt.Errorf(`failed to unmarshal json into recipient: %w`, err) 49 } 50 51 r.headers = proxy.Headers 52 decoded, err := base64.DecodeString(proxy.EncryptedKey) 53 if err != nil { 54 return fmt.Errorf(`failed to decode "encrypted_key": %w`, err) 55 } 56 r.encryptedKey = decoded 57 return nil 58 } 59 60 func (r *stdRecipient) MarshalJSON() ([]byte, error) { 61 buf := pool.GetBytesBuffer() 62 defer pool.ReleaseBytesBuffer(buf) 63 64 buf.WriteString(`{"header":`) 65 hdrbuf, err := r.headers.MarshalJSON() 66 if err != nil { 67 return nil, fmt.Errorf(`failed to marshal recipient header: %w`, err) 68 } 69 buf.Write(hdrbuf) 70 buf.WriteString(`,"encrypted_key":"`) 71 buf.WriteString(base64.EncodeToString(r.encryptedKey)) 72 buf.WriteString(`"}`) 73 74 ret := make([]byte, buf.Len()) 75 copy(ret, buf.Bytes()) 76 return ret, nil 77 } 78 79 // NewMessage creates a new message 80 func NewMessage() *Message { 81 return &Message{} 82 } 83 84 func (m *Message) AuthenticatedData() []byte { 85 return m.authenticatedData 86 } 87 88 func (m *Message) CipherText() []byte { 89 return m.cipherText 90 } 91 92 func (m *Message) InitializationVector() []byte { 93 return m.initializationVector 94 } 95 96 func (m *Message) Tag() []byte { 97 return m.tag 98 } 99 100 func (m *Message) ProtectedHeaders() Headers { 101 return m.protectedHeaders 102 } 103 104 func (m *Message) Recipients() []Recipient { 105 return m.recipients 106 } 107 108 func (m *Message) UnprotectedHeaders() Headers { 109 return m.unprotectedHeaders 110 } 111 112 const ( 113 AuthenticatedDataKey = "aad" 114 CipherTextKey = "ciphertext" 115 CountKey = "p2c" 116 InitializationVectorKey = "iv" 117 ProtectedHeadersKey = "protected" 118 RecipientsKey = "recipients" 119 SaltKey = "p2s" 120 TagKey = "tag" 121 UnprotectedHeadersKey = "unprotected" 122 HeadersKey = "header" 123 EncryptedKeyKey = "encrypted_key" 124 ) 125 126 func (m *Message) Set(k string, v interface{}) error { 127 switch k { 128 case AuthenticatedDataKey: 129 buf, ok := v.([]byte) 130 if !ok { 131 return fmt.Errorf(`invalid value %T for %s key`, v, AuthenticatedDataKey) 132 } 133 m.authenticatedData = buf 134 case CipherTextKey: 135 buf, ok := v.([]byte) 136 if !ok { 137 return fmt.Errorf(`invalid value %T for %s key`, v, CipherTextKey) 138 } 139 m.cipherText = buf 140 case InitializationVectorKey: 141 buf, ok := v.([]byte) 142 if !ok { 143 return fmt.Errorf(`invalid value %T for %s key`, v, InitializationVectorKey) 144 } 145 m.initializationVector = buf 146 case ProtectedHeadersKey: 147 cv, ok := v.(Headers) 148 if !ok { 149 return fmt.Errorf(`invalid value %T for %s key`, v, ProtectedHeadersKey) 150 } 151 m.protectedHeaders = cv 152 case RecipientsKey: 153 cv, ok := v.([]Recipient) 154 if !ok { 155 return fmt.Errorf(`invalid value %T for %s key`, v, RecipientsKey) 156 } 157 m.recipients = cv 158 case TagKey: 159 buf, ok := v.([]byte) 160 if !ok { 161 return fmt.Errorf(`invalid value %T for %s key`, v, TagKey) 162 } 163 m.tag = buf 164 case UnprotectedHeadersKey: 165 cv, ok := v.(Headers) 166 if !ok { 167 return fmt.Errorf(`invalid value %T for %s key`, v, UnprotectedHeadersKey) 168 } 169 m.unprotectedHeaders = cv 170 default: 171 if m.unprotectedHeaders == nil { 172 m.unprotectedHeaders = NewHeaders() 173 } 174 return m.unprotectedHeaders.Set(k, v) 175 } 176 return nil 177 } 178 179 type messageMarshalProxy struct { 180 AuthenticatedData string `json:"aad,omitempty"` 181 CipherText string `json:"ciphertext"` 182 InitializationVector string `json:"iv,omitempty"` 183 ProtectedHeaders json.RawMessage `json:"protected"` 184 Recipients []json.RawMessage `json:"recipients,omitempty"` 185 Tag string `json:"tag,omitempty"` 186 UnprotectedHeaders Headers `json:"unprotected,omitempty"` 187 188 // For flattened structure. Headers is NOT a Headers type, 189 // so that we can detect its presence by checking proxy.Headers != nil 190 Headers json.RawMessage `json:"header,omitempty"` 191 EncryptedKey string `json:"encrypted_key,omitempty"` 192 } 193 194 type jsonKV struct { 195 Key string 196 Value string 197 } 198 199 func (m *Message) MarshalJSON() ([]byte, error) { 200 // This is slightly convoluted, but we need to encode the 201 // protected headers, so we do it by hand 202 buf := pool.GetBytesBuffer() 203 defer pool.ReleaseBytesBuffer(buf) 204 enc := json.NewEncoder(buf) 205 206 var fields []jsonKV 207 208 if cipherText := m.CipherText(); len(cipherText) > 0 { 209 buf.Reset() 210 if err := enc.Encode(base64.EncodeToString(cipherText)); err != nil { 211 return nil, fmt.Errorf(`failed to encode %s field: %w`, CipherTextKey, err) 212 } 213 fields = append(fields, jsonKV{ 214 Key: CipherTextKey, 215 Value: strings.TrimSpace(buf.String()), 216 }) 217 } 218 219 if iv := m.InitializationVector(); len(iv) > 0 { 220 buf.Reset() 221 if err := enc.Encode(base64.EncodeToString(iv)); err != nil { 222 return nil, fmt.Errorf(`failed to encode %s field: %w`, InitializationVectorKey, err) 223 } 224 fields = append(fields, jsonKV{ 225 Key: InitializationVectorKey, 226 Value: strings.TrimSpace(buf.String()), 227 }) 228 } 229 230 var encodedProtectedHeaders []byte 231 if h := m.ProtectedHeaders(); h != nil { 232 v, err := h.Encode() 233 if err != nil { 234 return nil, fmt.Errorf(`failed to encode protected headers: %w`, err) 235 } 236 237 encodedProtectedHeaders = v 238 if len(encodedProtectedHeaders) <= 2 { // '{}' 239 encodedProtectedHeaders = nil 240 } else { 241 fields = append(fields, jsonKV{ 242 Key: ProtectedHeadersKey, 243 Value: fmt.Sprintf("%q", encodedProtectedHeaders), 244 }) 245 } 246 } 247 248 if aad := m.AuthenticatedData(); len(aad) > 0 { 249 aad = base64.Encode(aad) 250 if encodedProtectedHeaders != nil { 251 tmp := append(encodedProtectedHeaders, '.') 252 aad = append(tmp, aad...) 253 } 254 255 buf.Reset() 256 if err := enc.Encode(aad); err != nil { 257 return nil, fmt.Errorf(`failed to encode %s field: %w`, AuthenticatedDataKey, err) 258 } 259 fields = append(fields, jsonKV{ 260 Key: AuthenticatedDataKey, 261 Value: strings.TrimSpace(buf.String()), 262 }) 263 } 264 265 if recipients := m.Recipients(); len(recipients) > 0 { 266 if len(recipients) == 1 { // Use flattened format 267 if hdrs := recipients[0].Headers(); hdrs != nil { 268 buf.Reset() 269 if err := enc.Encode(hdrs); err != nil { 270 return nil, fmt.Errorf(`failed to encode %s field: %w`, HeadersKey, err) 271 } 272 fields = append(fields, jsonKV{ 273 Key: HeadersKey, 274 Value: strings.TrimSpace(buf.String()), 275 }) 276 } 277 278 if ek := recipients[0].EncryptedKey(); len(ek) > 0 { 279 buf.Reset() 280 if err := enc.Encode(base64.EncodeToString(ek)); err != nil { 281 return nil, fmt.Errorf(`failed to encode %s field: %w`, EncryptedKeyKey, err) 282 } 283 fields = append(fields, jsonKV{ 284 Key: EncryptedKeyKey, 285 Value: strings.TrimSpace(buf.String()), 286 }) 287 } 288 } else { 289 buf.Reset() 290 if err := enc.Encode(recipients); err != nil { 291 return nil, fmt.Errorf(`failed to encode %s field: %w`, RecipientsKey, err) 292 } 293 fields = append(fields, jsonKV{ 294 Key: RecipientsKey, 295 Value: strings.TrimSpace(buf.String()), 296 }) 297 } 298 } 299 300 if tag := m.Tag(); len(tag) > 0 { 301 buf.Reset() 302 if err := enc.Encode(base64.EncodeToString(tag)); err != nil { 303 return nil, fmt.Errorf(`failed to encode %s field: %w`, TagKey, err) 304 } 305 fields = append(fields, jsonKV{ 306 Key: TagKey, 307 Value: strings.TrimSpace(buf.String()), 308 }) 309 } 310 311 if h := m.UnprotectedHeaders(); h != nil { 312 unprotected, err := json.Marshal(h) 313 if err != nil { 314 return nil, fmt.Errorf(`failed to encode unprotected headers: %w`, err) 315 } 316 317 if len(unprotected) > 2 { 318 fields = append(fields, jsonKV{ 319 Key: UnprotectedHeadersKey, 320 Value: fmt.Sprintf("%q", unprotected), 321 }) 322 } 323 } 324 325 sort.Slice(fields, func(i, j int) bool { 326 return fields[i].Key < fields[j].Key 327 }) 328 buf.Reset() 329 fmt.Fprintf(buf, `{`) 330 for i, kv := range fields { 331 if i > 0 { 332 fmt.Fprintf(buf, `,`) 333 } 334 fmt.Fprintf(buf, `%q:%s`, kv.Key, kv.Value) 335 } 336 fmt.Fprintf(buf, `}`) 337 338 ret := make([]byte, buf.Len()) 339 copy(ret, buf.Bytes()) 340 return ret, nil 341 } 342 343 func (m *Message) UnmarshalJSON(buf []byte) error { 344 var proxy messageMarshalProxy 345 proxy.UnprotectedHeaders = NewHeaders() 346 347 if err := json.Unmarshal(buf, &proxy); err != nil { 348 return fmt.Errorf(`failed to unmashal JSON into message: %w`, err) 349 } 350 351 // Get the string value 352 var protectedHeadersStr string 353 if err := json.Unmarshal(proxy.ProtectedHeaders, &protectedHeadersStr); err != nil { 354 return fmt.Errorf(`failed to decode protected headers (1): %w`, err) 355 } 356 357 // It's now in _quoted_ base64 string. Decode it 358 protectedHeadersRaw, err := base64.DecodeString(protectedHeadersStr) 359 if err != nil { 360 return fmt.Errorf(`failed to base64 decoded protected headers buffer: %w`, err) 361 } 362 363 h := NewHeaders() 364 if err := json.Unmarshal(protectedHeadersRaw, h); err != nil { 365 return fmt.Errorf(`failed to decode protected headers (2): %w`, err) 366 } 367 368 // if this were a flattened message, we would see a "header" and "ciphertext" 369 // field. TODO: do both of these conditions need to meet, or just one? 370 if proxy.Headers != nil || len(proxy.EncryptedKey) > 0 { 371 recipient := NewRecipient() 372 hdrs := NewHeaders() 373 if err := json.Unmarshal(proxy.Headers, hdrs); err != nil { 374 return fmt.Errorf(`failed to decode headers field: %w`, err) 375 } 376 377 if err := recipient.SetHeaders(hdrs); err != nil { 378 return fmt.Errorf(`failed to set new headers: %w`, err) 379 } 380 381 if v := proxy.EncryptedKey; len(v) > 0 { 382 buf, err := base64.DecodeString(v) 383 if err != nil { 384 return fmt.Errorf(`failed to decode encrypted key: %w`, err) 385 } 386 if err := recipient.SetEncryptedKey(buf); err != nil { 387 return fmt.Errorf(`failed to set encrypted key: %w`, err) 388 } 389 } 390 391 m.recipients = append(m.recipients, recipient) 392 } else { 393 for i, recipientbuf := range proxy.Recipients { 394 recipient := NewRecipient() 395 if err := json.Unmarshal(recipientbuf, recipient); err != nil { 396 return fmt.Errorf(`failed to decode recipient at index %d: %w`, i, err) 397 } 398 399 m.recipients = append(m.recipients, recipient) 400 } 401 } 402 403 if src := proxy.AuthenticatedData; len(src) > 0 { 404 v, err := base64.DecodeString(src) 405 if err != nil { 406 return fmt.Errorf(`failed to decode "aad": %w`, err) 407 } 408 m.authenticatedData = v 409 } 410 411 if src := proxy.CipherText; len(src) > 0 { 412 v, err := base64.DecodeString(src) 413 if err != nil { 414 return fmt.Errorf(`failed to decode "ciphertext": %w`, err) 415 } 416 m.cipherText = v 417 } 418 419 if src := proxy.InitializationVector; len(src) > 0 { 420 v, err := base64.DecodeString(src) 421 if err != nil { 422 return fmt.Errorf(`failed to decode "iv": %w`, err) 423 } 424 m.initializationVector = v 425 } 426 427 if src := proxy.Tag; len(src) > 0 { 428 v, err := base64.DecodeString(src) 429 if err != nil { 430 return fmt.Errorf(`failed to decode "tag": %w`, err) 431 } 432 m.tag = v 433 } 434 435 m.protectedHeaders = h 436 if m.storeProtectedHeaders { 437 // this is later used for decryption 438 m.rawProtectedHeaders = base64.Encode(protectedHeadersRaw) 439 } 440 441 if iz, ok := proxy.UnprotectedHeaders.(isZeroer); ok { 442 if !iz.isZero() { 443 m.unprotectedHeaders = proxy.UnprotectedHeaders 444 } 445 } 446 447 if len(m.recipients) == 0 { 448 if err := m.makeDummyRecipient(proxy.EncryptedKey, m.protectedHeaders); err != nil { 449 return fmt.Errorf(`failed to setup recipient: %w`, err) 450 } 451 } 452 453 return nil 454 } 455 456 func (m *Message) makeDummyRecipient(enckeybuf string, protected Headers) error { 457 // Recipients in this case should not contain the content encryption key, 458 // so move that out 459 hdrs, err := protected.Clone(context.TODO()) 460 if err != nil { 461 return fmt.Errorf(`failed to clone headers: %w`, err) 462 } 463 464 if err := hdrs.Remove(ContentEncryptionKey); err != nil { 465 return fmt.Errorf(`failed to remove %#v from public header: %w`, ContentEncryptionKey, err) 466 } 467 468 enckey, err := base64.DecodeString(enckeybuf) 469 if err != nil { 470 return fmt.Errorf(`failed to decode encrypted key: %w`, err) 471 } 472 473 if err := m.Set(RecipientsKey, []Recipient{ 474 &stdRecipient{ 475 headers: hdrs, 476 encryptedKey: enckey, 477 }, 478 }); err != nil { 479 return fmt.Errorf(`failed to set %s: %w`, RecipientsKey, err) 480 } 481 return nil 482 } 483 484 // Compact generates a JWE message in compact serialization format from a 485 // `*jwe.Message` object. The object contain exactly one recipient, or 486 // an error is returned. 487 // 488 // This function currently does not take any options, but the function 489 // signature contains `options` for possible future expansion of the API 490 func Compact(m *Message, _ ...CompactOption) ([]byte, error) { 491 if len(m.recipients) != 1 { 492 return nil, fmt.Errorf(`wrong number of recipients for compact serialization`) 493 } 494 495 recipient := m.recipients[0] 496 497 // The protected header must be a merge between the message-wide 498 // protected header AND the recipient header 499 500 // There's something wrong if m.protectedHeaders is nil, but 501 // it could happen 502 if m.protectedHeaders == nil { 503 return nil, fmt.Errorf(`invalid protected header`) 504 } 505 506 ctx := context.TODO() 507 hcopy, err := m.protectedHeaders.Clone(ctx) 508 if err != nil { 509 return nil, fmt.Errorf(`failed to copy protected header: %w`, err) 510 } 511 hcopy, err = hcopy.Merge(ctx, m.unprotectedHeaders) 512 if err != nil { 513 return nil, fmt.Errorf(`failed to merge unprotected header: %w`, err) 514 } 515 hcopy, err = hcopy.Merge(ctx, recipient.Headers()) 516 if err != nil { 517 return nil, fmt.Errorf(`failed to merge recipient header: %w`, err) 518 } 519 520 protected, err := hcopy.Encode() 521 if err != nil { 522 return nil, fmt.Errorf(`failed to encode header: %w`, err) 523 } 524 525 encryptedKey := base64.Encode(recipient.EncryptedKey()) 526 iv := base64.Encode(m.initializationVector) 527 cipher := base64.Encode(m.cipherText) 528 tag := base64.Encode(m.tag) 529 530 buf := pool.GetBytesBuffer() 531 defer pool.ReleaseBytesBuffer(buf) 532 533 buf.Grow(len(protected) + len(encryptedKey) + len(iv) + len(cipher) + len(tag) + 4) 534 buf.Write(protected) 535 buf.WriteByte('.') 536 buf.Write(encryptedKey) 537 buf.WriteByte('.') 538 buf.Write(iv) 539 buf.WriteByte('.') 540 buf.Write(cipher) 541 buf.WriteByte('.') 542 buf.Write(tag) 543 544 result := make([]byte, buf.Len()) 545 copy(result, buf.Bytes()) 546 return result, nil 547 }