github.com/opcr-io/oras-go/v2@v2.0.0-20231122155130-eb4260d8a0ae/registry/remote/auth/cache_test.go (about)

     1  /*
     2  Copyright The ORAS Authors.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  package auth
    17  
    18  import (
    19  	"context"
    20  	"errors"
    21  	"strconv"
    22  	"sync"
    23  	"sync/atomic"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/opcr-io/oras-go/v2/errdef"
    28  )
    29  
    30  func Test_concurrentCache_GetScheme(t *testing.T) {
    31  	cache := NewCache()
    32  
    33  	// no entry in the cache
    34  	ctx := context.Background()
    35  	registry := "localhost:5000"
    36  	got, err := cache.GetScheme(ctx, registry)
    37  	if want := errdef.ErrNotFound; err != want {
    38  		t.Fatalf("concurrentCache.GetScheme() error = %v, wantErr %v", err, want)
    39  	}
    40  	if got != SchemeUnknown {
    41  		t.Errorf("concurrentCache.GetScheme() = %v, want %v", got, SchemeUnknown)
    42  	}
    43  
    44  	// set an cache entry
    45  	scheme := SchemeBasic
    46  	_, err = cache.Set(ctx, registry, scheme, "", func(c context.Context) (string, error) {
    47  		return "foo", nil
    48  	})
    49  	if err != nil {
    50  		t.Fatalf("failed to set cache: %v", err)
    51  	}
    52  
    53  	// verify cache
    54  	got, err = cache.GetScheme(ctx, registry)
    55  	if err != nil {
    56  		t.Fatalf("concurrentCache.GetScheme() error = %v", err)
    57  	}
    58  	if got != scheme {
    59  		t.Errorf("concurrentCache.GetScheme() = %v, want %v", got, scheme)
    60  	}
    61  
    62  	// set cache entry again
    63  	scheme = SchemeBearer
    64  	_, err = cache.Set(ctx, registry, scheme, "", func(c context.Context) (string, error) {
    65  		return "bar", nil
    66  	})
    67  	if err != nil {
    68  		t.Fatalf("failed to set cache: %v", err)
    69  	}
    70  
    71  	// verify cache
    72  	got, err = cache.GetScheme(ctx, registry)
    73  	if err != nil {
    74  		t.Fatalf("concurrentCache.GetScheme() error = %v", err)
    75  	}
    76  	if got != scheme {
    77  		t.Errorf("concurrentCache.GetScheme() = %v, want %v", got, scheme)
    78  	}
    79  
    80  	// test other registry
    81  	registry = "localhost:5001"
    82  	got, err = cache.GetScheme(ctx, registry)
    83  	if want := errdef.ErrNotFound; err != want {
    84  		t.Fatalf("concurrentCache.GetScheme() error = %v, wantErr %v", err, want)
    85  	}
    86  	if got != SchemeUnknown {
    87  		t.Errorf("concurrentCache.GetScheme() = %v, want %v", got, SchemeUnknown)
    88  	}
    89  }
    90  
    91  func Test_concurrentCache_GetToken(t *testing.T) {
    92  	cache := NewCache()
    93  
    94  	// no entry in the cache
    95  	ctx := context.Background()
    96  	registry := "localhost:5000"
    97  	scheme := SchemeBearer
    98  	key := "1st key"
    99  	got, err := cache.GetToken(ctx, registry, scheme, key)
   100  	if want := errdef.ErrNotFound; err != want {
   101  		t.Fatalf("concurrentCache.GetToken() error = %v, wantErr %v", err, want)
   102  	}
   103  	if got != "" {
   104  		t.Errorf("concurrentCache.GetToken() = %v, want %v", got, "")
   105  	}
   106  
   107  	// set an cache entry
   108  	_, err = cache.Set(ctx, registry, scheme, key, func(c context.Context) (string, error) {
   109  		return "foo", nil
   110  	})
   111  	if err != nil {
   112  		t.Fatalf("failed to set cache: %v", err)
   113  	}
   114  
   115  	// verify cache
   116  	got, err = cache.GetToken(ctx, registry, scheme, key)
   117  	if err != nil {
   118  		t.Fatalf("concurrentCache.GetToken() error = %v", err)
   119  	}
   120  	if want := "foo"; got != want {
   121  		t.Errorf("concurrentCache.GetToken() = %v, want %v", got, want)
   122  	}
   123  
   124  	// set cache entry again
   125  	_, err = cache.Set(ctx, registry, scheme, key, func(c context.Context) (string, error) {
   126  		return "bar", nil
   127  	})
   128  	if err != nil {
   129  		t.Fatalf("failed to set cache: %v", err)
   130  	}
   131  
   132  	// verify cache
   133  	got, err = cache.GetToken(ctx, registry, scheme, key)
   134  	if err != nil {
   135  		t.Fatalf("concurrentCache.GetToken() error = %v", err)
   136  	}
   137  	if want := "bar"; got != want {
   138  		t.Errorf("concurrentCache.GetToken() = %v, want %v", got, want)
   139  	}
   140  
   141  	// test other key
   142  	key = "2nd key"
   143  	got, err = cache.GetToken(ctx, registry, scheme, key)
   144  	if want := errdef.ErrNotFound; err != want {
   145  		t.Fatalf("concurrentCache.GetToken() error = %v, wantErr %v", err, want)
   146  	}
   147  	if got != "" {
   148  		t.Errorf("concurrentCache.GetToken() = %v, want %v", got, "")
   149  	}
   150  
   151  	// set an cache entry
   152  	_, err = cache.Set(ctx, registry, scheme, key, func(c context.Context) (string, error) {
   153  		return "hello world", nil
   154  	})
   155  	if err != nil {
   156  		t.Fatalf("failed to set cache: %v", err)
   157  	}
   158  
   159  	// verify cache
   160  	got, err = cache.GetToken(ctx, registry, scheme, key)
   161  	if err != nil {
   162  		t.Fatalf("concurrentCache.GetToken() error = %v", err)
   163  	}
   164  	if want := "hello world"; got != want {
   165  		t.Errorf("concurrentCache.GetToken() = %v, want %v", got, want)
   166  	}
   167  
   168  	// verify cache of the previous key as keys should not interference each
   169  	// other
   170  	key = "1st key"
   171  	got, err = cache.GetToken(ctx, registry, scheme, key)
   172  	if err != nil {
   173  		t.Fatalf("concurrentCache.GetToken() error = %v", err)
   174  	}
   175  	if want := "bar"; got != want {
   176  		t.Errorf("concurrentCache.GetToken() = %v, want %v", got, want)
   177  	}
   178  
   179  	// test other registry
   180  	registry = "localhost:5001"
   181  	got, err = cache.GetToken(ctx, registry, scheme, key)
   182  	if want := errdef.ErrNotFound; err != want {
   183  		t.Fatalf("concurrentCache.GetToken() error = %v, wantErr %v", err, want)
   184  	}
   185  	if got != "" {
   186  		t.Errorf("concurrentCache.GetToken() = %v, want %v", got, "")
   187  	}
   188  
   189  	// set an cache entry
   190  	_, err = cache.Set(ctx, registry, scheme, key, func(c context.Context) (string, error) {
   191  		return "foobar", nil
   192  	})
   193  	if err != nil {
   194  		t.Fatalf("failed to set cache: %v", err)
   195  	}
   196  
   197  	// verify cache
   198  	got, err = cache.GetToken(ctx, registry, scheme, key)
   199  	if err != nil {
   200  		t.Fatalf("concurrentCache.GetToken() error = %v", err)
   201  	}
   202  	if want := "foobar"; got != want {
   203  		t.Errorf("concurrentCache.GetToken() = %v, want %v", got, want)
   204  	}
   205  
   206  	// verify cache of the previous registry as registries should not
   207  	// interference each other
   208  	registry = "localhost:5000"
   209  	got, err = cache.GetToken(ctx, registry, scheme, key)
   210  	if err != nil {
   211  		t.Fatalf("concurrentCache.GetToken() error = %v", err)
   212  	}
   213  	if want := "bar"; got != want {
   214  		t.Errorf("concurrentCache.GetToken() = %v, want %v", got, want)
   215  	}
   216  
   217  	// test other scheme
   218  	scheme = SchemeBasic
   219  	got, err = cache.GetToken(ctx, registry, scheme, key)
   220  	if want := errdef.ErrNotFound; err != want {
   221  		t.Fatalf("concurrentCache.GetToken() error = %v, wantErr %v", err, want)
   222  	}
   223  	if got != "" {
   224  		t.Errorf("concurrentCache.GetToken() = %v, want %v", got, "")
   225  	}
   226  
   227  	// set an cache entry
   228  	_, err = cache.Set(ctx, registry, scheme, key, func(c context.Context) (string, error) {
   229  		return "new scheme", nil
   230  	})
   231  	if err != nil {
   232  		t.Fatalf("failed to set cache: %v", err)
   233  	}
   234  
   235  	// verify cache
   236  	got, err = cache.GetToken(ctx, registry, scheme, key)
   237  	if err != nil {
   238  		t.Fatalf("concurrentCache.GetToken() error = %v", err)
   239  	}
   240  	if want := "new scheme"; got != want {
   241  		t.Errorf("concurrentCache.GetToken() = %v, want %v", got, want)
   242  	}
   243  
   244  	// cache of the previous scheme should be invalidated due to scheme change.
   245  	got, err = cache.GetToken(ctx, registry, SchemeBearer, key)
   246  	if want := errdef.ErrNotFound; err != want {
   247  		t.Fatalf("concurrentCache.GetToken() error = %v, wantErr %v", err, want)
   248  	}
   249  	if got != "" {
   250  		t.Errorf("concurrentCache.GetToken() = %v, want %v", got, "")
   251  	}
   252  }
   253  
   254  func Test_concurrentCache_Set(t *testing.T) {
   255  	registries := []string{
   256  		"localhost:5000",
   257  		"localhost:5001",
   258  	}
   259  	scheme := SchemeBearer
   260  	keys := []string{
   261  		"foo",
   262  		"bar",
   263  	}
   264  	count := len(registries) * len(keys)
   265  
   266  	ctx := context.Background()
   267  	cache := NewCache()
   268  
   269  	// first round of fetch
   270  	fetch := func(i int) func(context.Context) (string, error) {
   271  		return func(context.Context) (string, error) {
   272  			return strconv.Itoa(i), nil
   273  		}
   274  	}
   275  	var wg sync.WaitGroup
   276  	for i := 0; i < 10; i++ {
   277  		for j := 0; j < count; j++ {
   278  			wg.Add(1)
   279  			go func(i int) {
   280  				defer wg.Done()
   281  				registry := registries[i&1]
   282  				key := keys[(i>>1)&1]
   283  				got, err := cache.Set(ctx, registry, scheme, key, fetch(i))
   284  				if err != nil {
   285  					t.Errorf("concurrentCache.Set() error = %v", err)
   286  				}
   287  				if want := strconv.Itoa(i); got != want {
   288  					t.Errorf("concurrentCache.Set() = %v, want %v", got, want)
   289  				}
   290  			}(j)
   291  		}
   292  	}
   293  	wg.Wait()
   294  
   295  	for i := 0; i < count; i++ {
   296  		registry := registries[i&1]
   297  		key := keys[(i>>1)&1]
   298  
   299  		gotScheme, err := cache.GetScheme(ctx, registry)
   300  		if err != nil {
   301  			t.Fatalf("concurrentCache.GetScheme() error = %v", err)
   302  		}
   303  		if want := scheme; gotScheme != want {
   304  			t.Errorf("concurrentCache.GetScheme() = %v, want %v", gotScheme, want)
   305  		}
   306  
   307  		gotToken, err := cache.GetToken(ctx, registry, scheme, key)
   308  		if err != nil {
   309  			t.Fatalf("concurrentCache.GetToken() error = %v", err)
   310  		}
   311  		if want := strconv.Itoa(i); gotToken != want {
   312  			t.Errorf("concurrentCache.GetToken() = %v, want %v", gotToken, want)
   313  		}
   314  	}
   315  
   316  	// repeated fetch
   317  	fetch = func(i int) func(context.Context) (string, error) {
   318  		return func(context.Context) (string, error) {
   319  			return strconv.Itoa(i) + " repeated", nil
   320  		}
   321  	}
   322  	for i := 0; i < 10; i++ {
   323  		for j := 0; j < count; j++ {
   324  			wg.Add(1)
   325  			go func(i int) {
   326  				defer wg.Done()
   327  				registry := registries[i&1]
   328  				key := keys[(i>>1)&1]
   329  				got, err := cache.Set(ctx, registry, scheme, key, fetch(i))
   330  				if err != nil {
   331  					t.Errorf("concurrentCache.Set() error = %v", err)
   332  				}
   333  				if want := strconv.Itoa(i) + " repeated"; got != want {
   334  					t.Errorf("concurrentCache.Set() = %v, want %v", got, want)
   335  				}
   336  			}(j)
   337  		}
   338  	}
   339  	wg.Wait()
   340  
   341  	for i := 0; i < count; i++ {
   342  		registry := registries[i&1]
   343  		key := keys[(i>>1)&1]
   344  
   345  		gotScheme, err := cache.GetScheme(ctx, registry)
   346  		if err != nil {
   347  			t.Fatalf("concurrentCache.GetScheme() error = %v", err)
   348  		}
   349  		if want := scheme; gotScheme != want {
   350  			t.Errorf("concurrentCache.GetScheme() = %v, want %v", gotScheme, want)
   351  		}
   352  
   353  		gotToken, err := cache.GetToken(ctx, registry, scheme, key)
   354  		if err != nil {
   355  			t.Fatalf("concurrentCache.GetToken() error = %v", err)
   356  		}
   357  		if want := strconv.Itoa(i) + " repeated"; gotToken != want {
   358  			t.Errorf("concurrentCache.GetToken() = %v, want %v", gotToken, want)
   359  		}
   360  	}
   361  }
   362  
   363  func Test_concurrentCache_Set_Fetch_Once(t *testing.T) {
   364  	registries := []string{
   365  		"localhost:5000",
   366  		"localhost:5001",
   367  	}
   368  	schemes := []Scheme{
   369  		SchemeBasic,
   370  		SchemeBearer,
   371  	}
   372  	keys := []string{
   373  		"foo",
   374  		"bar",
   375  	}
   376  	count := make([]int64, len(registries)*len(schemes)*len(keys))
   377  	fetch := func(i int) func(context.Context) (string, error) {
   378  		return func(context.Context) (string, error) {
   379  			time.Sleep(500 * time.Millisecond)
   380  			atomic.AddInt64(&count[i], 1)
   381  			return strconv.Itoa(i), nil
   382  		}
   383  	}
   384  
   385  	ctx := context.Background()
   386  	cache := NewCache()
   387  
   388  	// first round of fetch
   389  	var wg sync.WaitGroup
   390  	for i := 0; i < 10; i++ {
   391  		for j := 0; j < len(count); j++ {
   392  			wg.Add(1)
   393  			go func(i int) {
   394  				defer wg.Done()
   395  				registry := registries[i&1]
   396  				scheme := schemes[(i>>1)&1]
   397  				key := keys[(i>>2)&1]
   398  				got, err := cache.Set(ctx, registry, scheme, key, fetch(i))
   399  				if err != nil {
   400  					t.Errorf("concurrentCache.Set() error = %v", err)
   401  				}
   402  				if want := strconv.Itoa(i); got != want {
   403  					t.Errorf("concurrentCache.Set() = %v, want %v", got, want)
   404  				}
   405  			}(j)
   406  		}
   407  	}
   408  	wg.Wait()
   409  
   410  	for i := 0; i < len(count); i++ {
   411  		if got := count[i]; got != 1 {
   412  			t.Errorf("fetch is called more than once: %d", got)
   413  		}
   414  	}
   415  
   416  	// repeated fetch
   417  	for i := 0; i < 10; i++ {
   418  		for j := 0; j < len(count); j++ {
   419  			wg.Add(1)
   420  			go func(i int) {
   421  				defer wg.Done()
   422  				registry := registries[i&1]
   423  				scheme := schemes[(i>>1)&1]
   424  				key := keys[(i>>2)&1]
   425  				got, err := cache.Set(ctx, registry, scheme, key, fetch(i))
   426  				if err != nil {
   427  					t.Errorf("concurrentCache.Set() error = %v", err)
   428  				}
   429  				if want := strconv.Itoa(i); got != want {
   430  					t.Errorf("concurrentCache.Set() = %v, want %v", got, want)
   431  				}
   432  			}(j)
   433  		}
   434  	}
   435  	wg.Wait()
   436  
   437  	for i := 0; i < len(count); i++ {
   438  		if got := count[i]; got != 2 {
   439  			t.Errorf("fetch is called more than once: %d", got)
   440  		}
   441  	}
   442  }
   443  
   444  func Test_concurrentCache_Set_Fetch_Failure(t *testing.T) {
   445  	registries := []string{
   446  		"localhost:5000",
   447  		"localhost:5001",
   448  	}
   449  	scheme := SchemeBearer
   450  	keys := []string{
   451  		"foo",
   452  		"bar",
   453  	}
   454  	count := len(registries) * len(keys)
   455  
   456  	ctx := context.Background()
   457  	cache := NewCache()
   458  
   459  	// first round of fetch
   460  	fetch := func(i int) func(context.Context) (string, error) {
   461  		return func(context.Context) (string, error) {
   462  			return "", errors.New(strconv.Itoa(i))
   463  		}
   464  	}
   465  	var wg sync.WaitGroup
   466  	for i := 0; i < 10; i++ {
   467  		for j := 0; j < count; j++ {
   468  			wg.Add(1)
   469  			go func(i int) {
   470  				defer wg.Done()
   471  				registry := registries[i&1]
   472  				key := keys[(i>>1)&1]
   473  				_, err := cache.Set(ctx, registry, scheme, key, fetch(i))
   474  				if want := strconv.Itoa(i); err == nil || err.Error() != want {
   475  					t.Errorf("concurrentCache.Set() error = %v, wantErr %v", err, want)
   476  				}
   477  			}(j)
   478  		}
   479  	}
   480  	wg.Wait()
   481  
   482  	for i := 0; i < count; i++ {
   483  		registry := registries[i&1]
   484  		key := keys[(i>>1)&1]
   485  
   486  		_, err := cache.GetScheme(ctx, registry)
   487  		if want := errdef.ErrNotFound; err != want {
   488  			t.Fatalf("concurrentCache.GetScheme() error = %v, wantErr %v", err, want)
   489  		}
   490  
   491  		_, err = cache.GetToken(ctx, registry, scheme, key)
   492  		if want := errdef.ErrNotFound; err != want {
   493  			t.Errorf("concurrentCache.GetToken() error = %v, wantErr %v", err, want)
   494  		}
   495  	}
   496  
   497  	// repeated fetch
   498  	fetch = func(i int) func(context.Context) (string, error) {
   499  		return func(context.Context) (string, error) {
   500  			return strconv.Itoa(i), nil
   501  		}
   502  	}
   503  	for i := 0; i < 10; i++ {
   504  		for j := 0; j < count; j++ {
   505  			wg.Add(1)
   506  			go func(i int) {
   507  				defer wg.Done()
   508  				registry := registries[i&1]
   509  				key := keys[(i>>1)&1]
   510  				got, err := cache.Set(ctx, registry, scheme, key, fetch(i))
   511  				if err != nil {
   512  					t.Errorf("concurrentCache.Set() error = %v", err)
   513  				}
   514  				if want := strconv.Itoa(i); got != want {
   515  					t.Errorf("concurrentCache.Set() = %v, want %v", got, want)
   516  				}
   517  			}(j)
   518  		}
   519  	}
   520  	wg.Wait()
   521  
   522  	for i := 0; i < count; i++ {
   523  		registry := registries[i&1]
   524  		key := keys[(i>>1)&1]
   525  
   526  		gotScheme, err := cache.GetScheme(ctx, registry)
   527  		if err != nil {
   528  			t.Fatalf("concurrentCache.GetScheme() error = %v", err)
   529  		}
   530  		if want := scheme; gotScheme != want {
   531  			t.Errorf("concurrentCache.GetScheme() = %v, want %v", gotScheme, want)
   532  		}
   533  
   534  		gotToken, err := cache.GetToken(ctx, registry, scheme, key)
   535  		if err != nil {
   536  			t.Fatalf("concurrentCache.GetToken() error = %v", err)
   537  		}
   538  		if want := strconv.Itoa(i); gotToken != want {
   539  			t.Errorf("concurrentCache.GetToken() = %v, want %v", gotToken, want)
   540  		}
   541  	}
   542  }