github.com/lestrrat-go/jwx/v2@v2.0.21/jwk/set.go (about)

     1  package jwk
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"sort"
     8  
     9  	"github.com/lestrrat-go/iter/arrayiter"
    10  	"github.com/lestrrat-go/iter/mapiter"
    11  	"github.com/lestrrat-go/jwx/v2/internal/json"
    12  	"github.com/lestrrat-go/jwx/v2/internal/pool"
    13  )
    14  
    15  const keysKey = `keys` // appease linter
    16  
    17  // NewSet creates and empty `jwk.Set` object
    18  func NewSet() Set {
    19  	return &set{
    20  		privateParams: make(map[string]interface{}),
    21  	}
    22  }
    23  
    24  func (s *set) Set(n string, v interface{}) error {
    25  	s.mu.RLock()
    26  	defer s.mu.RUnlock()
    27  
    28  	if n == keysKey {
    29  		vl, ok := v.([]Key)
    30  		if !ok {
    31  			return fmt.Errorf(`value for field "keys" must be []jwk.Key`)
    32  		}
    33  		s.keys = vl
    34  		return nil
    35  	}
    36  
    37  	s.privateParams[n] = v
    38  	return nil
    39  }
    40  
    41  func (s *set) Get(n string) (interface{}, bool) {
    42  	s.mu.RLock()
    43  	defer s.mu.RUnlock()
    44  
    45  	v, ok := s.privateParams[n]
    46  	return v, ok
    47  }
    48  
    49  func (s *set) Key(idx int) (Key, bool) {
    50  	s.mu.RLock()
    51  	defer s.mu.RUnlock()
    52  
    53  	if idx >= 0 && idx < len(s.keys) {
    54  		return s.keys[idx], true
    55  	}
    56  	return nil, false
    57  }
    58  
    59  func (s *set) Len() int {
    60  	s.mu.RLock()
    61  	defer s.mu.RUnlock()
    62  
    63  	return len(s.keys)
    64  }
    65  
    66  // indexNL is Index(), but without the locking
    67  func (s *set) indexNL(key Key) int {
    68  	for i, k := range s.keys {
    69  		if k == key {
    70  			return i
    71  		}
    72  	}
    73  	return -1
    74  }
    75  
    76  func (s *set) Index(key Key) int {
    77  	s.mu.RLock()
    78  	defer s.mu.RUnlock()
    79  
    80  	return s.indexNL(key)
    81  }
    82  
    83  func (s *set) AddKey(key Key) error {
    84  	s.mu.Lock()
    85  	defer s.mu.Unlock()
    86  
    87  	if i := s.indexNL(key); i > -1 {
    88  		return fmt.Errorf(`(jwk.Set).AddKey: key already exists`)
    89  	}
    90  	s.keys = append(s.keys, key)
    91  	return nil
    92  }
    93  
    94  func (s *set) Remove(name string) error {
    95  	s.mu.Lock()
    96  	defer s.mu.Unlock()
    97  
    98  	delete(s.privateParams, name)
    99  	return nil
   100  }
   101  
   102  func (s *set) RemoveKey(key Key) error {
   103  	s.mu.Lock()
   104  	defer s.mu.Unlock()
   105  
   106  	for i, k := range s.keys {
   107  		if k == key {
   108  			switch i {
   109  			case 0:
   110  				s.keys = s.keys[1:]
   111  			case len(s.keys) - 1:
   112  				s.keys = s.keys[:i]
   113  			default:
   114  				s.keys = append(s.keys[:i], s.keys[i+1:]...)
   115  			}
   116  			return nil
   117  		}
   118  	}
   119  	return fmt.Errorf(`(jwk.Set).RemoveKey: specified key does not exist in set`)
   120  }
   121  
   122  func (s *set) Clear() error {
   123  	s.mu.Lock()
   124  	defer s.mu.Unlock()
   125  
   126  	s.keys = nil
   127  	s.privateParams = make(map[string]interface{})
   128  	return nil
   129  }
   130  
   131  func (s *set) Keys(ctx context.Context) KeyIterator {
   132  	ch := make(chan *KeyPair, s.Len())
   133  	go iterate(ctx, s.keys, ch)
   134  	return arrayiter.New(ch)
   135  }
   136  
   137  func iterate(ctx context.Context, keys []Key, ch chan *KeyPair) {
   138  	defer close(ch)
   139  
   140  	for i, key := range keys {
   141  		pair := &KeyPair{Index: i, Value: key}
   142  		select {
   143  		case <-ctx.Done():
   144  			return
   145  		case ch <- pair:
   146  		}
   147  	}
   148  }
   149  
   150  func (s *set) MarshalJSON() ([]byte, error) {
   151  	s.mu.RLock()
   152  	defer s.mu.RUnlock()
   153  
   154  	buf := pool.GetBytesBuffer()
   155  	defer pool.ReleaseBytesBuffer(buf)
   156  	enc := json.NewEncoder(buf)
   157  
   158  	fields := []string{keysKey}
   159  	for k := range s.privateParams {
   160  		fields = append(fields, k)
   161  	}
   162  	sort.Strings(fields)
   163  
   164  	buf.WriteByte('{')
   165  	for i, field := range fields {
   166  		if i > 0 {
   167  			buf.WriteByte(',')
   168  		}
   169  		fmt.Fprintf(buf, `%q:`, field)
   170  		if field != keysKey {
   171  			if err := enc.Encode(s.privateParams[field]); err != nil {
   172  				return nil, fmt.Errorf(`failed to marshal field %q: %w`, field, err)
   173  			}
   174  		} else {
   175  			buf.WriteByte('[')
   176  			for j, k := range s.keys {
   177  				if j > 0 {
   178  					buf.WriteByte(',')
   179  				}
   180  				if err := enc.Encode(k); err != nil {
   181  					return nil, fmt.Errorf(`failed to marshal key #%d: %w`, i, err)
   182  				}
   183  			}
   184  			buf.WriteByte(']')
   185  		}
   186  	}
   187  	buf.WriteByte('}')
   188  
   189  	ret := make([]byte, buf.Len())
   190  	copy(ret, buf.Bytes())
   191  	return ret, nil
   192  }
   193  
   194  func (s *set) UnmarshalJSON(data []byte) error {
   195  	s.mu.Lock()
   196  	defer s.mu.Unlock()
   197  
   198  	s.privateParams = make(map[string]interface{})
   199  	s.keys = nil
   200  
   201  	var options []ParseOption
   202  	var ignoreParseError bool
   203  	if dc := s.dc; dc != nil {
   204  		if localReg := dc.Registry(); localReg != nil {
   205  			options = append(options, withLocalRegistry(localReg))
   206  		}
   207  		ignoreParseError = dc.IgnoreParseError()
   208  	}
   209  
   210  	var sawKeysField bool
   211  	dec := json.NewDecoder(bytes.NewReader(data))
   212  LOOP:
   213  	for {
   214  		tok, err := dec.Token()
   215  		if err != nil {
   216  			return fmt.Errorf(`error reading token: %w`, err)
   217  		}
   218  
   219  		switch tok := tok.(type) {
   220  		case json.Delim:
   221  			// Assuming we're doing everything correctly, we should ONLY
   222  			// get either '{' or '}' here.
   223  			if tok == '}' { // End of object
   224  				break LOOP
   225  			} else if tok != '{' {
   226  				return fmt.Errorf(`expected '{', but got '%c'`, tok)
   227  			}
   228  		case string:
   229  			switch tok {
   230  			case "keys":
   231  				sawKeysField = true
   232  				var list []json.RawMessage
   233  				if err := dec.Decode(&list); err != nil {
   234  					return fmt.Errorf(`failed to decode "keys": %w`, err)
   235  				}
   236  
   237  				for i, keysrc := range list {
   238  					key, err := ParseKey(keysrc, options...)
   239  					if err != nil {
   240  						if !ignoreParseError {
   241  							return fmt.Errorf(`failed to decode key #%d in "keys": %w`, i, err)
   242  						}
   243  						continue
   244  					}
   245  					s.keys = append(s.keys, key)
   246  				}
   247  			default:
   248  				var v interface{}
   249  				if err := dec.Decode(&v); err != nil {
   250  					return fmt.Errorf(`failed to decode value for key %q: %w`, tok, err)
   251  				}
   252  				s.privateParams[tok] = v
   253  			}
   254  		}
   255  	}
   256  
   257  	// This is really silly, but we can only detect the
   258  	// lack of the "keys" field after going through the
   259  	// entire object once
   260  	// Not checking for len(s.keys) == 0, because it could be
   261  	// an empty key set
   262  	if !sawKeysField {
   263  		key, err := ParseKey(data, options...)
   264  		if err != nil {
   265  			return fmt.Errorf(`failed to parse sole key in key set`)
   266  		}
   267  		s.keys = append(s.keys, key)
   268  	}
   269  	return nil
   270  }
   271  
   272  func (s *set) LookupKeyID(kid string) (Key, bool) {
   273  	s.mu.RLock()
   274  	defer s.mu.RUnlock()
   275  
   276  	n := s.Len()
   277  	for i := 0; i < n; i++ {
   278  		key, ok := s.Key(i)
   279  		if !ok {
   280  			return nil, false
   281  		}
   282  		if key.KeyID() == kid {
   283  			return key, true
   284  		}
   285  	}
   286  	return nil, false
   287  }
   288  
   289  func (s *set) DecodeCtx() DecodeCtx {
   290  	s.mu.RLock()
   291  	defer s.mu.RUnlock()
   292  	return s.dc
   293  }
   294  
   295  func (s *set) SetDecodeCtx(dc DecodeCtx) {
   296  	s.mu.Lock()
   297  	defer s.mu.Unlock()
   298  	s.dc = dc
   299  }
   300  
   301  func (s *set) Clone() (Set, error) {
   302  	s2 := &set{}
   303  
   304  	s.mu.RLock()
   305  	defer s.mu.RUnlock()
   306  
   307  	s2.keys = make([]Key, len(s.keys))
   308  	copy(s2.keys, s.keys)
   309  	return s2, nil
   310  }
   311  
   312  func (s *set) makePairs() []*HeaderPair {
   313  	pairs := make([]*HeaderPair, 0, len(s.privateParams))
   314  	for k, v := range s.privateParams {
   315  		pairs = append(pairs, &HeaderPair{Key: k, Value: v})
   316  	}
   317  	sort.Slice(pairs, func(i, j int) bool {
   318  		//nolint:forcetypeassert
   319  		return pairs[i].Key.(string) < pairs[j].Key.(string)
   320  	})
   321  	return pairs
   322  }
   323  
   324  func (s *set) Iterate(ctx context.Context) HeaderIterator {
   325  	pairs := s.makePairs()
   326  	ch := make(chan *HeaderPair, len(pairs))
   327  	go func(ctx context.Context, ch chan *HeaderPair, pairs []*HeaderPair) {
   328  		defer close(ch)
   329  		for _, pair := range pairs {
   330  			select {
   331  			case <-ctx.Done():
   332  				return
   333  			case ch <- pair:
   334  			}
   335  		}
   336  	}(ctx, ch, pairs)
   337  	return mapiter.New(ch)
   338  }