github.com/myhau/pulumi/pkg/v3@v3.70.2-0.20221116134521-f2775972e587/backend/httpstate/token_source_test.go (about)

     1  // Copyright 2016-2022, Pulumi Corporation.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package httpstate
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"runtime"
    21  	"sync"
    22  	"testing"
    23  
    24  	"time"
    25  
    26  	"github.com/stretchr/testify/assert"
    27  )
    28  
    29  func TestTokenSource(t *testing.T) {
    30  	if runtime.GOOS == "windows" {
    31  		t.Skip("Flaky on Windows CI workers due to the use of timer+Sleep")
    32  	}
    33  	t.Parallel()
    34  
    35  	ctx := context.TODO()
    36  	dur := 20 * time.Millisecond
    37  	backend := &testTokenBackend{tokens: map[string]time.Time{}}
    38  
    39  	tok0, tok0Expires := backend.NewToken(dur)
    40  	ts, err := newTokenSource(ctx, tok0, tok0Expires, dur, backend.Refresh)
    41  	assert.NoError(t, err)
    42  	defer ts.Close()
    43  
    44  	for i := 0; i < 32; i++ {
    45  		tok, err := ts.GetToken()
    46  		assert.NoError(t, err)
    47  		assert.NoError(t, backend.VerifyToken(tok))
    48  		t.Logf("STEP: %d, TOKEN: %s", i, tok)
    49  
    50  		// tok0 initially
    51  		if i == 0 {
    52  			assert.Equal(t, tok0, tok)
    53  		}
    54  
    55  		// definitely a fresh token by step 16
    56  		// allow some leeway due to time.Sleep concurrency
    57  		if i > 16 {
    58  			assert.NotEqual(t, tok0, tok)
    59  		}
    60  
    61  		time.Sleep(dur / 16)
    62  	}
    63  }
    64  
    65  func TestTokenSourceWithQuicklyExpiringInitialToken(t *testing.T) {
    66  	if runtime.GOOS == "windows" {
    67  		t.Skip("Flaky on Windows CI workers due to the use of timer+Sleep")
    68  	}
    69  	t.Parallel()
    70  
    71  	ctx := context.TODO()
    72  	dur := 20 * time.Millisecond
    73  	backend := &testTokenBackend{tokens: map[string]time.Time{}}
    74  
    75  	tok0, tok0Expires := backend.NewToken(dur / 10)
    76  	ts, err := newTokenSource(ctx, tok0, tok0Expires, dur, backend.Refresh)
    77  	assert.NoError(t, err)
    78  	defer ts.Close()
    79  
    80  	for i := 0; i < 8; i++ {
    81  		tok, err := ts.GetToken()
    82  		assert.NoError(t, err)
    83  		assert.NoError(t, backend.VerifyToken(tok))
    84  		t.Logf("STEP: %d, TOKEN: %s", i, tok)
    85  		time.Sleep(dur / 16)
    86  	}
    87  }
    88  
    89  type testTokenBackend struct {
    90  	mu      sync.Mutex
    91  	counter int
    92  	tokens  map[string]time.Time
    93  }
    94  
    95  func (ts *testTokenBackend) NewToken(duration time.Duration) (string, time.Time) {
    96  	ts.mu.Lock()
    97  	defer ts.mu.Unlock()
    98  	return ts.newTokenInner(duration)
    99  }
   100  
   101  func (ts *testTokenBackend) Refresh(
   102  	ctx context.Context,
   103  	duration time.Duration,
   104  	currentToken string) (string, time.Time, error) {
   105  	ts.mu.Lock()
   106  	defer ts.mu.Unlock()
   107  	if err := ts.verifyTokenInner(currentToken); err != nil {
   108  		return "", time.Time{}, err
   109  	}
   110  	tok, expires := ts.newTokenInner(duration)
   111  	return tok, expires, nil
   112  }
   113  
   114  func (ts *testTokenBackend) TokenName(refreshCount int) string {
   115  	ts.mu.Lock()
   116  	defer ts.mu.Unlock()
   117  	return ts.tokenNameInner(refreshCount)
   118  }
   119  
   120  func (ts *testTokenBackend) VerifyToken(token string) error {
   121  	ts.mu.Lock()
   122  	defer ts.mu.Unlock()
   123  	return ts.verifyTokenInner(token)
   124  }
   125  
   126  func (ts *testTokenBackend) newTokenInner(duration time.Duration) (string, time.Time) {
   127  	now := time.Now()
   128  	ts.counter++
   129  	tok := ts.tokenNameInner(ts.counter)
   130  	expires := now.Add(duration)
   131  	ts.tokens[tok] = now.Add(duration)
   132  	return tok, expires
   133  }
   134  
   135  func (ts *testTokenBackend) tokenNameInner(refreshCount int) string {
   136  	return fmt.Sprintf("token-%d", ts.counter)
   137  }
   138  
   139  func (ts *testTokenBackend) verifyTokenInner(token string) error {
   140  	now := time.Now()
   141  	expires, gotCurrentToken := ts.tokens[token]
   142  	if !gotCurrentToken {
   143  		return fmt.Errorf("Unknown token: %v", token)
   144  	}
   145  
   146  	if now.After(expires) {
   147  		return fmt.Errorf("Expired token %v (%v past expiration)",
   148  			token, now.Sub(expires))
   149  	}
   150  	return nil
   151  }