sigs.k8s.io/prow@v0.0.0-20240503223140-c5e374dc7eb1/cmd/ghproxy/ghproxy_test.go (about)

     1  /*
     2  Copyright 2020 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package main
    18  
    19  import (
    20  	"bytes"
    21  	"crypto/rand"
    22  	"crypto/rsa"
    23  	"encoding/json"
    24  	"fmt"
    25  	"io"
    26  	"net/http"
    27  	"net/http/httptest"
    28  	"net/url"
    29  	"os"
    30  	"path"
    31  	"testing"
    32  	"time"
    33  
    34  	"github.com/sirupsen/logrus"
    35  
    36  	"sigs.k8s.io/prow/pkg/ghcache"
    37  	"sigs.k8s.io/prow/pkg/github"
    38  )
    39  
    40  func TestDiskCachePruning(t *testing.T) {
    41  	t.Parallel()
    42  	cacheDir := t.TempDir()
    43  	o := &options{
    44  		dir:                 cacheDir,
    45  		maxConcurrency:      25,
    46  		pushGatewayInterval: time.Minute,
    47  		upstreamParsed:      &url.URL{},
    48  		timeout:             30,
    49  	}
    50  
    51  	now := time.Now()
    52  	github.TimeNow = func() time.Time { return now }
    53  
    54  	// Five minutes so the test has sufficient time to finish
    55  	// but also sufficient room until the app token which is
    56  	// always valid for 10 minutes expires.
    57  	expiryDuration := 5 * time.Minute
    58  	roundTripper := func(r *http.Request) (*http.Response, error) {
    59  		t.Logf("got a request for path %s", r.URL.Path)
    60  		switch r.URL.Path {
    61  		case "/app":
    62  			return jsonResponse(github.App{Slug: "app-slug"}, 200)
    63  		case "/app/installations":
    64  			return jsonResponse([]github.AppInstallation{{Account: github.User{Login: "org"}}}, 200)
    65  		case "/app/installations/0/access_tokens":
    66  			return jsonResponse(github.AppInstallationToken{Token: "abc", ExpiresAt: now.Add(expiryDuration)}, 201)
    67  		case "/repos/org/repo/git/refs/dev":
    68  			return jsonResponse(github.GetRefResult{}, 200)
    69  		default:
    70  			return nil, fmt.Errorf("got unexpected request for %s", r.URL.Path)
    71  		}
    72  	}
    73  
    74  	rsaKey, err := rsa.GenerateKey(rand.Reader, 512)
    75  	if err != nil {
    76  		t.Fatalf("Failed to generate RSA key: %v", err)
    77  	}
    78  
    79  	server := httptest.NewServer(proxy(o, httpRoundTripper(roundTripper), time.Hour))
    80  	t.Cleanup(server.Close)
    81  	_, _, client, err := github.NewClientFromOptions(logrus.Fields{}, github.ClientOptions{
    82  		MaxRetries:      1,
    83  		Censor:          func(b []byte) []byte { return b },
    84  		AppID:           "123",
    85  		AppPrivateKey:   func() *rsa.PrivateKey { return rsaKey },
    86  		Bases:           []string{server.URL},
    87  		GraphqlEndpoint: server.URL,
    88  	})
    89  	if err != nil {
    90  		t.Fatalf("failed to construct github client: %v", err)
    91  	}
    92  
    93  	if _, err := client.GetRef("org", "repo", "dev"); err != nil {
    94  		t.Fatalf("GetRef failed: %v", err)
    95  	}
    96  
    97  	numberPartitions, err := getNumberOfCachePartitions(cacheDir)
    98  	if err != nil {
    99  		t.Fatalf("failed to get number of cache paritions: %v", err)
   100  	}
   101  	if numberPartitions != 2 {
   102  		t.Fatalf("expected two cache paritions, one for the app and one for the app installation, got %d", numberPartitions)
   103  	}
   104  
   105  	ghcache.Prune(cacheDir, func() time.Time { return now.Add(expiryDuration).Add(time.Second) })
   106  
   107  	numberPartitions, err = getNumberOfCachePartitions(cacheDir)
   108  	if err != nil {
   109  		t.Fatalf("failed to get number of cache paritions: %v", err)
   110  	}
   111  	if numberPartitions != 1 {
   112  		t.Errorf("expected one cache partition for the app as the one for the installation should be cleaned up, got  %d", numberPartitions)
   113  	}
   114  }
   115  
   116  func getNumberOfCachePartitions(cacheDir string) (int, error) {
   117  	var result int
   118  	for _, suffix := range []string{"temp", "data"} {
   119  		entries, err := os.ReadDir(path.Join(cacheDir, suffix))
   120  		if err != nil {
   121  			return result, fmt.Errorf("faield to list: %w", err)
   122  		}
   123  		if result == 0 {
   124  			result = len(entries)
   125  			continue
   126  		}
   127  		if n := len(entries); n != result {
   128  			return result, fmt.Errorf("temp and datadir don't have the same number of partitions: %d vs %d", result, n)
   129  		}
   130  	}
   131  
   132  	return result, nil
   133  }
   134  
   135  type httpRoundTripper func(*http.Request) (*http.Response, error)
   136  
   137  func (rt httpRoundTripper) RoundTrip(r *http.Request) (*http.Response, error) {
   138  	return rt(r)
   139  }
   140  
   141  func jsonResponse(body interface{}, statusCode int) (*http.Response, error) {
   142  	serialized, err := json.Marshal(body)
   143  	if err != nil {
   144  		return nil, err
   145  	}
   146  	return &http.Response{StatusCode: statusCode, Body: io.NopCloser(bytes.NewBuffer(serialized)), Header: http.Header{}}, nil
   147  }