github.com/lestrrat-go/jwx/v2@v2.0.21/jwe/headers.go (about)

     1  package jwe
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  
     7  	"github.com/lestrrat-go/jwx/v2/internal/base64"
     8  	"github.com/lestrrat-go/jwx/v2/internal/json"
     9  
    10  	"github.com/lestrrat-go/iter/mapiter"
    11  	"github.com/lestrrat-go/jwx/v2/internal/iter"
    12  )
    13  
    14  type isZeroer interface {
    15  	isZero() bool
    16  }
    17  
    18  func (h *stdHeaders) isZero() bool {
    19  	return h.agreementPartyUInfo == nil &&
    20  		h.agreementPartyVInfo == nil &&
    21  		h.algorithm == nil &&
    22  		h.compression == nil &&
    23  		h.contentEncryption == nil &&
    24  		h.contentType == nil &&
    25  		h.critical == nil &&
    26  		h.ephemeralPublicKey == nil &&
    27  		h.jwk == nil &&
    28  		h.jwkSetURL == nil &&
    29  		h.keyID == nil &&
    30  		h.typ == nil &&
    31  		h.x509CertChain == nil &&
    32  		h.x509CertThumbprint == nil &&
    33  		h.x509CertThumbprintS256 == nil &&
    34  		h.x509URL == nil &&
    35  		len(h.privateParams) == 0
    36  }
    37  
    38  // Iterate returns a channel that successively returns all the
    39  // header name and values.
    40  func (h *stdHeaders) Iterate(ctx context.Context) Iterator {
    41  	pairs := h.makePairs()
    42  	ch := make(chan *HeaderPair, len(pairs))
    43  	go func(ctx context.Context, ch chan *HeaderPair, pairs []*HeaderPair) {
    44  		defer close(ch)
    45  		for _, pair := range pairs {
    46  			select {
    47  			case <-ctx.Done():
    48  				return
    49  			case ch <- pair:
    50  			}
    51  		}
    52  	}(ctx, ch, pairs)
    53  	return mapiter.New(ch)
    54  }
    55  
    56  func (h *stdHeaders) Walk(ctx context.Context, visitor Visitor) error {
    57  	return iter.WalkMap(ctx, h, visitor)
    58  }
    59  
    60  func (h *stdHeaders) AsMap(ctx context.Context) (map[string]interface{}, error) {
    61  	return iter.AsMap(ctx, h)
    62  }
    63  
    64  func (h *stdHeaders) Clone(ctx context.Context) (Headers, error) {
    65  	dst := NewHeaders()
    66  	if err := h.Copy(ctx, dst); err != nil {
    67  		return nil, fmt.Errorf(`failed to copy header contents to new object: %w`, err)
    68  	}
    69  	return dst, nil
    70  }
    71  
    72  func (h *stdHeaders) Copy(_ context.Context, dst Headers) error {
    73  	for _, pair := range h.makePairs() {
    74  		//nolint:forcetypeassert
    75  		key := pair.Key.(string)
    76  		if err := dst.Set(key, pair.Value); err != nil {
    77  			return fmt.Errorf(`failed to set header %q: %w`, key, err)
    78  		}
    79  	}
    80  	return nil
    81  }
    82  
    83  func (h *stdHeaders) Merge(ctx context.Context, h2 Headers) (Headers, error) {
    84  	h3 := NewHeaders()
    85  
    86  	if h != nil {
    87  		if err := h.Copy(ctx, h3); err != nil {
    88  			return nil, fmt.Errorf(`failed to copy headers from receiver: %w`, err)
    89  		}
    90  	}
    91  
    92  	if h2 != nil {
    93  		if err := h2.Copy(ctx, h3); err != nil {
    94  			return nil, fmt.Errorf(`failed to copy headers from argument: %w`, err)
    95  		}
    96  	}
    97  
    98  	return h3, nil
    99  }
   100  
   101  func (h *stdHeaders) Encode() ([]byte, error) {
   102  	buf, err := json.Marshal(h)
   103  	if err != nil {
   104  		return nil, fmt.Errorf(`failed to marshal headers to JSON prior to encoding: %w`, err)
   105  	}
   106  
   107  	return base64.Encode(buf), nil
   108  }
   109  
   110  func (h *stdHeaders) Decode(buf []byte) error {
   111  	// base64 json string -> json object representation of header
   112  	decoded, err := base64.Decode(buf)
   113  	if err != nil {
   114  		return fmt.Errorf(`failed to unmarshal base64 encoded buffer: %w`, err)
   115  	}
   116  
   117  	if err := json.Unmarshal(decoded, h); err != nil {
   118  		return fmt.Errorf(`failed to unmarshal buffer: %w`, err)
   119  	}
   120  
   121  	return nil
   122  }