github.com/google/osv-scalibr@v0.4.1/clients/datasource/cache_test.go (about)

     1  // Copyright 2025 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package datasource_test
    16  
    17  import (
    18  	"maps"
    19  	"sync"
    20  	"sync/atomic"
    21  	"testing"
    22  
    23  	"github.com/google/osv-scalibr/clients/datasource"
    24  )
    25  
    26  func TestRequestCache(t *testing.T) {
    27  	// Test that RequestCache calls each function exactly once per key.
    28  	requestCache := datasource.NewRequestCache[int, int]()
    29  
    30  	const numKeys = 20
    31  	const requestsPerKey = 50
    32  
    33  	var wg sync.WaitGroup
    34  	var fnCalls [numKeys]int32
    35  
    36  	for i := range numKeys {
    37  		for range requestsPerKey {
    38  			wg.Add(1)
    39  			go func() {
    40  				t.Helper()
    41  				_, _ = requestCache.Get(i, func() (int, error) {
    42  					// Count how many times this function gets called for this key,
    43  					// then return the key as the value.
    44  					atomic.AddInt32(&fnCalls[i], 1)
    45  					return i, nil
    46  				})
    47  				wg.Done()
    48  			}()
    49  		}
    50  	}
    51  
    52  	wg.Wait() // Make sure all the goroutines are finished
    53  
    54  	for i, c := range fnCalls {
    55  		if c != 1 {
    56  			t.Errorf("RequestCache Get(%d) function called %d times", i, c)
    57  		}
    58  	}
    59  
    60  	cacheMap := requestCache.GetMap()
    61  	if len(cacheMap) != numKeys {
    62  		t.Errorf("RequestCache GetMap length was %d, expected %d", len(cacheMap), numKeys)
    63  	}
    64  
    65  	for k, v := range cacheMap {
    66  		if k != v {
    67  			t.Errorf("RequestCache GetMap key %d has unexpected value %d", k, v)
    68  		}
    69  	}
    70  }
    71  
    72  func TestRequestCacheSetMap(t *testing.T) {
    73  	requestCache := datasource.NewRequestCache[string, string]()
    74  	requestCache.SetMap(map[string]string{"foo": "foo1", "bar": "bar2"})
    75  	fn := func() (string, error) { return "CACHE MISS", nil }
    76  
    77  	want := map[string]string{
    78  		"foo": "foo1",
    79  		"bar": "bar2",
    80  		"baz": "CACHE MISS",
    81  		"FOO": "CACHE MISS",
    82  	}
    83  
    84  	for k, v := range want {
    85  		got, err := requestCache.Get(k, fn)
    86  		if err != nil {
    87  			t.Errorf("Get(%v) returned an error: %v", v, err)
    88  		} else if got != v {
    89  			t.Errorf("Get(%v) got: %v, want %v", k, got, v)
    90  		}
    91  	}
    92  
    93  	gotMap := requestCache.GetMap()
    94  	if !maps.Equal(want, gotMap) {
    95  		t.Errorf("GetMap() got %v, want %v", gotMap, want)
    96  	}
    97  }