github.com/lestrrat-go/jwx/v2@v2.0.21/jwk/set.go (about) 1 package jwk 2 3 import ( 4 "bytes" 5 "context" 6 "fmt" 7 "sort" 8 9 "github.com/lestrrat-go/iter/arrayiter" 10 "github.com/lestrrat-go/iter/mapiter" 11 "github.com/lestrrat-go/jwx/v2/internal/json" 12 "github.com/lestrrat-go/jwx/v2/internal/pool" 13 ) 14 15 const keysKey = `keys` // appease linter 16 17 // NewSet creates and empty `jwk.Set` object 18 func NewSet() Set { 19 return &set{ 20 privateParams: make(map[string]interface{}), 21 } 22 } 23 24 func (s *set) Set(n string, v interface{}) error { 25 s.mu.RLock() 26 defer s.mu.RUnlock() 27 28 if n == keysKey { 29 vl, ok := v.([]Key) 30 if !ok { 31 return fmt.Errorf(`value for field "keys" must be []jwk.Key`) 32 } 33 s.keys = vl 34 return nil 35 } 36 37 s.privateParams[n] = v 38 return nil 39 } 40 41 func (s *set) Get(n string) (interface{}, bool) { 42 s.mu.RLock() 43 defer s.mu.RUnlock() 44 45 v, ok := s.privateParams[n] 46 return v, ok 47 } 48 49 func (s *set) Key(idx int) (Key, bool) { 50 s.mu.RLock() 51 defer s.mu.RUnlock() 52 53 if idx >= 0 && idx < len(s.keys) { 54 return s.keys[idx], true 55 } 56 return nil, false 57 } 58 59 func (s *set) Len() int { 60 s.mu.RLock() 61 defer s.mu.RUnlock() 62 63 return len(s.keys) 64 } 65 66 // indexNL is Index(), but without the locking 67 func (s *set) indexNL(key Key) int { 68 for i, k := range s.keys { 69 if k == key { 70 return i 71 } 72 } 73 return -1 74 } 75 76 func (s *set) Index(key Key) int { 77 s.mu.RLock() 78 defer s.mu.RUnlock() 79 80 return s.indexNL(key) 81 } 82 83 func (s *set) AddKey(key Key) error { 84 s.mu.Lock() 85 defer s.mu.Unlock() 86 87 if i := s.indexNL(key); i > -1 { 88 return fmt.Errorf(`(jwk.Set).AddKey: key already exists`) 89 } 90 s.keys = append(s.keys, key) 91 return nil 92 } 93 94 func (s *set) Remove(name string) error { 95 s.mu.Lock() 96 defer s.mu.Unlock() 97 98 delete(s.privateParams, name) 99 return nil 100 } 101 102 func (s *set) RemoveKey(key Key) error { 103 s.mu.Lock() 104 defer s.mu.Unlock() 105 106 for i, k := range s.keys { 107 if k == key { 108 switch i { 109 case 0: 110 s.keys = s.keys[1:] 111 case len(s.keys) - 1: 112 s.keys = s.keys[:i] 113 default: 114 s.keys = append(s.keys[:i], s.keys[i+1:]...) 115 } 116 return nil 117 } 118 } 119 return fmt.Errorf(`(jwk.Set).RemoveKey: specified key does not exist in set`) 120 } 121 122 func (s *set) Clear() error { 123 s.mu.Lock() 124 defer s.mu.Unlock() 125 126 s.keys = nil 127 s.privateParams = make(map[string]interface{}) 128 return nil 129 } 130 131 func (s *set) Keys(ctx context.Context) KeyIterator { 132 ch := make(chan *KeyPair, s.Len()) 133 go iterate(ctx, s.keys, ch) 134 return arrayiter.New(ch) 135 } 136 137 func iterate(ctx context.Context, keys []Key, ch chan *KeyPair) { 138 defer close(ch) 139 140 for i, key := range keys { 141 pair := &KeyPair{Index: i, Value: key} 142 select { 143 case <-ctx.Done(): 144 return 145 case ch <- pair: 146 } 147 } 148 } 149 150 func (s *set) MarshalJSON() ([]byte, error) { 151 s.mu.RLock() 152 defer s.mu.RUnlock() 153 154 buf := pool.GetBytesBuffer() 155 defer pool.ReleaseBytesBuffer(buf) 156 enc := json.NewEncoder(buf) 157 158 fields := []string{keysKey} 159 for k := range s.privateParams { 160 fields = append(fields, k) 161 } 162 sort.Strings(fields) 163 164 buf.WriteByte('{') 165 for i, field := range fields { 166 if i > 0 { 167 buf.WriteByte(',') 168 } 169 fmt.Fprintf(buf, `%q:`, field) 170 if field != keysKey { 171 if err := enc.Encode(s.privateParams[field]); err != nil { 172 return nil, fmt.Errorf(`failed to marshal field %q: %w`, field, err) 173 } 174 } else { 175 buf.WriteByte('[') 176 for j, k := range s.keys { 177 if j > 0 { 178 buf.WriteByte(',') 179 } 180 if err := enc.Encode(k); err != nil { 181 return nil, fmt.Errorf(`failed to marshal key #%d: %w`, i, err) 182 } 183 } 184 buf.WriteByte(']') 185 } 186 } 187 buf.WriteByte('}') 188 189 ret := make([]byte, buf.Len()) 190 copy(ret, buf.Bytes()) 191 return ret, nil 192 } 193 194 func (s *set) UnmarshalJSON(data []byte) error { 195 s.mu.Lock() 196 defer s.mu.Unlock() 197 198 s.privateParams = make(map[string]interface{}) 199 s.keys = nil 200 201 var options []ParseOption 202 var ignoreParseError bool 203 if dc := s.dc; dc != nil { 204 if localReg := dc.Registry(); localReg != nil { 205 options = append(options, withLocalRegistry(localReg)) 206 } 207 ignoreParseError = dc.IgnoreParseError() 208 } 209 210 var sawKeysField bool 211 dec := json.NewDecoder(bytes.NewReader(data)) 212 LOOP: 213 for { 214 tok, err := dec.Token() 215 if err != nil { 216 return fmt.Errorf(`error reading token: %w`, err) 217 } 218 219 switch tok := tok.(type) { 220 case json.Delim: 221 // Assuming we're doing everything correctly, we should ONLY 222 // get either '{' or '}' here. 223 if tok == '}' { // End of object 224 break LOOP 225 } else if tok != '{' { 226 return fmt.Errorf(`expected '{', but got '%c'`, tok) 227 } 228 case string: 229 switch tok { 230 case "keys": 231 sawKeysField = true 232 var list []json.RawMessage 233 if err := dec.Decode(&list); err != nil { 234 return fmt.Errorf(`failed to decode "keys": %w`, err) 235 } 236 237 for i, keysrc := range list { 238 key, err := ParseKey(keysrc, options...) 239 if err != nil { 240 if !ignoreParseError { 241 return fmt.Errorf(`failed to decode key #%d in "keys": %w`, i, err) 242 } 243 continue 244 } 245 s.keys = append(s.keys, key) 246 } 247 default: 248 var v interface{} 249 if err := dec.Decode(&v); err != nil { 250 return fmt.Errorf(`failed to decode value for key %q: %w`, tok, err) 251 } 252 s.privateParams[tok] = v 253 } 254 } 255 } 256 257 // This is really silly, but we can only detect the 258 // lack of the "keys" field after going through the 259 // entire object once 260 // Not checking for len(s.keys) == 0, because it could be 261 // an empty key set 262 if !sawKeysField { 263 key, err := ParseKey(data, options...) 264 if err != nil { 265 return fmt.Errorf(`failed to parse sole key in key set`) 266 } 267 s.keys = append(s.keys, key) 268 } 269 return nil 270 } 271 272 func (s *set) LookupKeyID(kid string) (Key, bool) { 273 s.mu.RLock() 274 defer s.mu.RUnlock() 275 276 n := s.Len() 277 for i := 0; i < n; i++ { 278 key, ok := s.Key(i) 279 if !ok { 280 return nil, false 281 } 282 if key.KeyID() == kid { 283 return key, true 284 } 285 } 286 return nil, false 287 } 288 289 func (s *set) DecodeCtx() DecodeCtx { 290 s.mu.RLock() 291 defer s.mu.RUnlock() 292 return s.dc 293 } 294 295 func (s *set) SetDecodeCtx(dc DecodeCtx) { 296 s.mu.Lock() 297 defer s.mu.Unlock() 298 s.dc = dc 299 } 300 301 func (s *set) Clone() (Set, error) { 302 s2 := &set{} 303 304 s.mu.RLock() 305 defer s.mu.RUnlock() 306 307 s2.keys = make([]Key, len(s.keys)) 308 copy(s2.keys, s.keys) 309 return s2, nil 310 } 311 312 func (s *set) makePairs() []*HeaderPair { 313 pairs := make([]*HeaderPair, 0, len(s.privateParams)) 314 for k, v := range s.privateParams { 315 pairs = append(pairs, &HeaderPair{Key: k, Value: v}) 316 } 317 sort.Slice(pairs, func(i, j int) bool { 318 //nolint:forcetypeassert 319 return pairs[i].Key.(string) < pairs[j].Key.(string) 320 }) 321 return pairs 322 } 323 324 func (s *set) Iterate(ctx context.Context) HeaderIterator { 325 pairs := s.makePairs() 326 ch := make(chan *HeaderPair, len(pairs)) 327 go func(ctx context.Context, ch chan *HeaderPair, pairs []*HeaderPair) { 328 defer close(ch) 329 for _, pair := range pairs { 330 select { 331 case <-ctx.Done(): 332 return 333 case ch <- pair: 334 } 335 } 336 }(ctx, ch, pairs) 337 return mapiter.New(ch) 338 }