github.com/Schaudge/grailbase@v0.0.0-20240223061707-44c758a471c0/sync/loadingcache/value_test.go (about)

     1  package loadingcache
     2  
     3  import (
     4  	"context"
     5  	"sync"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/Schaudge/grailbase/errors"
    10  	"github.com/stretchr/testify/assert"
    11  	"github.com/stretchr/testify/require"
    12  )
    13  
    14  const recentUnixTimestamp = 1600000000 // 2020-09-13 12:26:40 +0000 UTC
    15  
    16  func TestValueExpiration(t *testing.T) {
    17  	var (
    18  		ctx   = context.Background()
    19  		v     Value
    20  		clock fakeClock
    21  	)
    22  	v.setClock(clock.Now)
    23  
    24  	clock.Set(time.Unix(recentUnixTimestamp, 0))
    25  	var v1 int
    26  	require.NoError(t, v.GetOrLoad(ctx, &v1, func(_ context.Context, opts *LoadOpts) error {
    27  		clock.Add(2 * time.Hour)
    28  		v1 = 1
    29  		opts.CacheFor(time.Hour)
    30  		return nil
    31  	}))
    32  	assert.Equal(t, 1, v1)
    33  
    34  	clock.Add(5 * time.Minute)
    35  	var v2 int
    36  	require.NoError(t, v.GetOrLoad(ctx, &v2, loadFail))
    37  	assert.Equal(t, 1, v1)
    38  	assert.Equal(t, 1, v2)
    39  
    40  	clock.Add(time.Hour)
    41  	var v3 int
    42  	require.NoError(t, v.GetOrLoad(ctx, &v3, func(_ context.Context, opts *LoadOpts) error {
    43  		v3 = 3
    44  		opts.CacheForever()
    45  		return nil
    46  	}))
    47  	assert.Equal(t, 1, v1)
    48  	assert.Equal(t, 1, v2)
    49  	assert.Equal(t, 3, v3)
    50  
    51  	clock.Add(10000 * time.Hour)
    52  	var v4 int
    53  	assert.NoError(t, v.GetOrLoad(ctx, &v4, loadFail))
    54  	assert.Equal(t, 1, v1)
    55  	assert.Equal(t, 1, v2)
    56  	assert.Equal(t, 3, v3)
    57  	assert.Equal(t, 3, v4)
    58  }
    59  
    60  func TestValueExpiration0(t *testing.T) {
    61  	var (
    62  		ctx   = context.Background()
    63  		v     Value
    64  		clock fakeClock
    65  	)
    66  	v.setClock(clock.Now)
    67  
    68  	clock.Set(time.Unix(recentUnixTimestamp, 0))
    69  	var v1 int
    70  	require.NoError(t, v.GetOrLoad(ctx, &v1, func(_ context.Context, opts *LoadOpts) error {
    71  		v1 = 1
    72  		return nil
    73  	}))
    74  	assert.Equal(t, 1, v1)
    75  
    76  	// Run v2 at the same time as v1. It should not get a cached result because v1's cache time was 0.
    77  	var v2 int
    78  	require.NoError(t, v.GetOrLoad(ctx, &v2, func(_ context.Context, opts *LoadOpts) error {
    79  		v2 = 2
    80  		opts.CacheFor(time.Hour)
    81  		return nil
    82  	}))
    83  	assert.Equal(t, 1, v1)
    84  	assert.Equal(t, 2, v2)
    85  }
    86  
    87  func TestValueNil(t *testing.T) {
    88  	var (
    89  		ctx   = context.Background()
    90  		v     *Value
    91  		clock fakeClock
    92  	)
    93  	v.setClock(clock.Now)
    94  
    95  	clock.Set(time.Unix(recentUnixTimestamp, 0))
    96  	var v1 int
    97  	require.NoError(t, v.GetOrLoad(ctx, &v1, func(_ context.Context, opts *LoadOpts) error {
    98  		clock.Add(2 * time.Hour)
    99  		v1 = 1
   100  		opts.CacheForever()
   101  		return nil
   102  	}))
   103  	assert.Equal(t, 1, v1)
   104  
   105  	var v2 int
   106  	assert.Error(t, v.GetOrLoad(ctx, &v2, loadFail))
   107  	assert.Equal(t, 1, v1)
   108  
   109  	clock.Add(time.Hour)
   110  	var v3 int
   111  	require.NoError(t, v.GetOrLoad(ctx, &v3, func(_ context.Context, opts *LoadOpts) error {
   112  		v3 = 3
   113  		opts.CacheForever()
   114  		return nil
   115  	}))
   116  	assert.Equal(t, 1, v1)
   117  	assert.Equal(t, 3, v3)
   118  }
   119  
   120  func TestValueCancellation(t *testing.T) {
   121  	var (
   122  		v     Value
   123  		clock fakeClock
   124  	)
   125  	v.setClock(clock.Now)
   126  	clock.Set(time.Unix(recentUnixTimestamp, 0))
   127  	const cacheDuration = time.Minute
   128  
   129  	type participant struct {
   130  		cancel context.CancelFunc
   131  		// participant waits for these before proceeding.
   132  		waitGet, waitLoad chan<- struct{}
   133  		// participant returns these signals of its progress.
   134  		loadStarted <-chan struct{}
   135  		result      <-chan error
   136  	}
   137  	makeParticipant := func(dst *int, loaded int) participant {
   138  		ctx, cancel := context.WithCancel(context.Background())
   139  		var (
   140  			waitGet     = make(chan struct{})
   141  			waitLoad    = make(chan struct{})
   142  			loadStarted = make(chan struct{})
   143  			result      = make(chan error)
   144  		)
   145  		go func() {
   146  			<-waitGet
   147  			result <- v.GetOrLoad(ctx, dst, func(ctx context.Context, opts *LoadOpts) error {
   148  				close(loadStarted)
   149  				select {
   150  				case <-ctx.Done():
   151  					return ctx.Err()
   152  				case <-waitLoad:
   153  					*dst = loaded
   154  					opts.CacheFor(cacheDuration)
   155  					return nil
   156  				}
   157  			})
   158  		}()
   159  		return participant{cancel, waitGet, waitLoad, loadStarted, result}
   160  	}
   161  
   162  	// Start participant 1 and wait for its cache load to start.
   163  	var v1 int
   164  	p1 := makeParticipant(&v1, 1)
   165  	close(p1.waitGet)
   166  	<-p1.loadStarted
   167  
   168  	// Start participant 2, then cancel its context and wait for its error.
   169  	var v2 int
   170  	p2 := makeParticipant(&v2, 2)
   171  	p2.waitGet <- struct{}{}
   172  	p2.cancel()
   173  	err2 := <-p2.result
   174  	assert.True(t, errors.Is(errors.Canceled, err2), "got: %v", err2)
   175  
   176  	// Start participant 3, then cancel participant 1 and wait for 3 to start loading.
   177  	var v3 int
   178  	p3 := makeParticipant(&v3, 3)
   179  	p3.waitGet <- struct{}{}
   180  	p1.cancel()
   181  	<-p3.loadStarted
   182  	err1 := <-p1.result
   183  	assert.True(t, errors.Is(errors.Canceled, err1), "got: %v", err1)
   184  
   185  	// Start participant 4 later (according to clock).
   186  	var v4 int
   187  	p4 := makeParticipant(&v4, 4)
   188  	clock.Add(time.Second)
   189  	p4.waitGet <- struct{}{}
   190  
   191  	// Let participant 3 finish loading and wait for results.
   192  	close(p3.waitLoad)
   193  	require.NoError(t, <-p3.result)
   194  	require.NoError(t, <-p4.result)
   195  	assert.Equal(t, 3, v3)
   196  	assert.Equal(t, 3, v4) // Got cached result.
   197  
   198  	// Start participant 5 past cache time so it recomputes.
   199  	var v5 int
   200  	p5 := makeParticipant(&v5, 5)
   201  	clock.Add(cacheDuration * 2)
   202  	p5.waitGet <- struct{}{}
   203  	close(p5.waitLoad)
   204  	require.NoError(t, <-p5.result)
   205  	assert.Equal(t, 3, v3)
   206  	assert.Equal(t, 3, v4)
   207  	assert.Equal(t, 5, v5)
   208  }
   209  
   210  type fakeClock struct {
   211  	mu  sync.Mutex
   212  	now time.Time
   213  }
   214  
   215  func (c *fakeClock) Now() time.Time {
   216  	c.mu.Lock()
   217  	defer c.mu.Unlock()
   218  	return c.now
   219  }
   220  
   221  func (c *fakeClock) Set(now time.Time) {
   222  	c.mu.Lock()
   223  	defer c.mu.Unlock()
   224  	c.now = now
   225  }
   226  
   227  func (c *fakeClock) Add(d time.Duration) {
   228  	c.mu.Lock()
   229  	defer c.mu.Unlock()
   230  	c.now = c.now.Add(d)
   231  }
   232  
   233  func loadFail(context.Context, *LoadOpts) error {
   234  	panic("unexpected load")
   235  }