github.com/abayer/test-infra@v0.0.5/ghproxy/ghcache/coalesce_test.go (about)

     1  /*
     2  Copyright 2018 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package ghcache
    18  
    19  import (
    20  	"bytes"
    21  	"errors"
    22  	"io/ioutil"
    23  	"net/http"
    24  	"net/url"
    25  	"reflect"
    26  	"sync"
    27  	"testing"
    28  	"time"
    29  )
    30  
    31  // testDelegate is a fake upstream transport delegate that logs hits by URI and
    32  // will wait to respond to requests until signaled unless the request has
    33  // a header specifying it should be responded to immediately.
    34  type testDelegate struct {
    35  	beginResponding *sync.Cond
    36  
    37  	hitsLock sync.Mutex
    38  	hits     map[string]int
    39  }
    40  
    41  func (t *testDelegate) RoundTrip(req *http.Request) (*http.Response, error) {
    42  	t.hitsLock.Lock()
    43  	t.hits[req.URL.Path] += 1
    44  	t.hitsLock.Unlock()
    45  
    46  	if req.Header.Get("test-immediate-response") == "" {
    47  		t.beginResponding.L.Lock()
    48  		t.beginResponding.Wait()
    49  		t.beginResponding.L.Unlock()
    50  	}
    51  	return &http.Response{
    52  			Body: ioutil.NopCloser(bytes.NewBufferString("Response")),
    53  		},
    54  		nil
    55  }
    56  
    57  func TestRoundTrip(t *testing.T) {
    58  	// Check that only 1 request goes to upstream if there are concurrent requests.
    59  	delegate := &testDelegate{
    60  		hits:            make(map[string]int),
    61  		beginResponding: sync.NewCond(&sync.Mutex{}),
    62  	}
    63  	coalesce := &requestCoalescer{
    64  		keys:     make(map[string]*responseWaiter),
    65  		delegate: delegate,
    66  	}
    67  	wg := sync.WaitGroup{}
    68  	wg.Add(100)
    69  	for i := 0; i < 100; i++ {
    70  		go func() {
    71  			runRequest(t, coalesce, "/resource1", false)
    72  			wg.Done()
    73  		}()
    74  	}
    75  	// There is a race here. We need to wait for all requests to be made to the
    76  	// coalescer before letting upstream respond, but we don't have a way of
    77  	// knowing when all requests have actually started waiting on the
    78  	// responseWaiter...
    79  	time.Sleep(time.Second * 5)
    80  
    81  	// Check that requests for different resources are not blocked.
    82  	runRequest(t, coalesce, "/resource2", true) // Doesn't return until timeout or success.
    83  	delegate.beginResponding.Broadcast()
    84  
    85  	// Check that non concurrent requests all hit upstream.
    86  	runRequest(t, coalesce, "/resource2", true)
    87  
    88  	wg.Wait()
    89  	expectedHits := map[string]int{"/resource1": 1, "/resource2": 2}
    90  	if !reflect.DeepEqual(delegate.hits, expectedHits) {
    91  		t.Errorf("Unexpected hit count(s). Expected %v, but got %v.", expectedHits, delegate.hits)
    92  	}
    93  }
    94  
    95  func runRequest(t *testing.T, rt http.RoundTripper, uri string, immediate bool) {
    96  	res := make(chan error)
    97  	run := func() {
    98  		u, err := url.Parse("http://foo.com" + uri)
    99  		if err != nil {
   100  			res <- err
   101  		}
   102  		req, err := http.NewRequest(http.MethodGet, u.String(), nil)
   103  		if err != nil {
   104  			res <- err
   105  		}
   106  		if immediate {
   107  			req.Header.Set("test-immediate-response", "true")
   108  		}
   109  		resp, err := rt.RoundTrip(req)
   110  		if err != nil {
   111  			res <- err
   112  		} else if b, err := ioutil.ReadAll(resp.Body); err != nil {
   113  			res <- err
   114  		} else if string(b) != "Response" {
   115  			res <- errors.New("unexpected response value")
   116  		}
   117  		res <- nil
   118  	}
   119  	go run()
   120  	select {
   121  	case <-time.After(time.Second * 10):
   122  		t.Errorf("Request for %q timed out.", uri)
   123  	case err := <-res:
   124  		if err != nil {
   125  			t.Errorf("Request error: %v.", err)
   126  		}
   127  	}
   128  }