sigs.k8s.io/prow@v0.0.0-20240503223140-c5e374dc7eb1/pkg/interrupts/interrupts_test.go (about)

     1  /*
     2  Copyright 2019 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package interrupts
    18  
    19  import (
    20  	"context"
    21  	"crypto/rand"
    22  	"crypto/rsa"
    23  	"crypto/tls"
    24  	"crypto/x509"
    25  	"crypto/x509/pkix"
    26  	"encoding/pem"
    27  	"fmt"
    28  	"math/big"
    29  	"net"
    30  	"net/http"
    31  	"os"
    32  	"sync"
    33  	"syscall"
    34  	"testing"
    35  	"time"
    36  )
    37  
    38  // interrupt allows for tests to trigger an interrupt as needed
    39  var interrupt = make(chan os.Signal, 1)
    40  
    41  // this init will be executed before that in the code package,
    42  // so we can inject our implementation of the interrupt channel
    43  func init() {
    44  	signalsLock.Lock()
    45  	gracePeriod = time.Second
    46  	signals = func() <-chan os.Signal {
    47  		return interrupt
    48  	}
    49  	signalsLock.Unlock()
    50  }
    51  
    52  // instead of building a mechanism to reset/re-initialize the interrupt
    53  // manager which would only be used in testing, we write an integration
    54  // test that only fires the mock interrupt once
    55  func TestInterrupts(t *testing.T) {
    56  	// we need to lock around values used to test otherwise the test
    57  	// goroutine will race with the workers
    58  	lock := sync.Mutex{}
    59  
    60  	ctx := Context()
    61  	var ctxDone bool
    62  	go func() {
    63  		<-ctx.Done()
    64  
    65  		lock.Lock()
    66  		ctxDone = true
    67  		lock.Unlock()
    68  	}()
    69  
    70  	var workDone bool
    71  	var workCancelled bool
    72  	work := func(ctx context.Context) {
    73  		lock.Lock()
    74  		workDone = true
    75  		lock.Unlock()
    76  
    77  		<-ctx.Done()
    78  
    79  		lock.Lock()
    80  		workCancelled = true
    81  		lock.Unlock()
    82  	}
    83  	Run(work)
    84  
    85  	// we cannot use httptest mocks for the tests here as they expect
    86  	// to be started by the httptest package itself, not by a downstream
    87  	// caller like the interrupts library
    88  	var serverCalled bool
    89  	var serverCancelled bool
    90  	listener, err := net.Listen("tcp", "127.0.0.1:")
    91  	if err != nil {
    92  		t.Fatalf("could not listen on random port: %v", err)
    93  	}
    94  	if err := listener.Close(); err != nil {
    95  		t.Fatalf("could close listener: %v", err)
    96  	}
    97  	server := &http.Server{Addr: listener.Addr().String(), Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
    98  		lock.Lock()
    99  		serverCalled = true
   100  		lock.Unlock()
   101  	})}
   102  	server.RegisterOnShutdown(func() {
   103  		lock.Lock()
   104  		serverCancelled = true
   105  		lock.Unlock()
   106  	})
   107  	ListenAndServe(server, time.Second)
   108  	// wait for the server to start
   109  	time.Sleep(100 * time.Millisecond)
   110  	if _, err := http.Get("http://" + listener.Addr().String()); err != nil {
   111  		t.Errorf("could not reach server registered with ListenAndServe(): %v", err)
   112  	}
   113  
   114  	var tlsServerCalled bool
   115  	var tlsServerCancelled bool
   116  	tlsListener, err := net.Listen("tcp", "127.0.0.1:")
   117  	if err != nil {
   118  		t.Fatalf("could not listen on random port: %v", err)
   119  	}
   120  	if err := tlsListener.Close(); err != nil {
   121  		t.Fatalf("could close listener: %v", err)
   122  	}
   123  	tlsServer := &http.Server{Addr: tlsListener.Addr().String(), Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) {
   124  		lock.Lock()
   125  		tlsServerCalled = true
   126  		lock.Unlock()
   127  	})}
   128  	tlsServer.RegisterOnShutdown(func() {
   129  		lock.Lock()
   130  		tlsServerCancelled = true
   131  		lock.Unlock()
   132  	})
   133  	cert, key, err := generateCerts("127.0.0.1")
   134  	if err != nil {
   135  		t.Fatalf("could not generate cert and key for TLS server: %v", err)
   136  	}
   137  	ListenAndServeTLS(tlsServer, cert, key, time.Second)
   138  	client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}}
   139  	// wait for the server to start
   140  	time.Sleep(100 * time.Millisecond)
   141  	if _, err := client.Get("https://" + tlsListener.Addr().String()); err != nil {
   142  		t.Errorf("could not reach server registered with ListenAndServeTLS(): %v", err)
   143  	}
   144  
   145  	var intervalCalls int
   146  	interval := func() time.Duration {
   147  		lock.Lock()
   148  		intervalCalls++
   149  		lock.Unlock()
   150  		if intervalCalls > 2 {
   151  			return 10 * time.Hour
   152  		}
   153  		return 1 * time.Nanosecond
   154  	}
   155  	var tickCalls int
   156  	tick := func() {
   157  		lock.Lock()
   158  		tickCalls++
   159  		lock.Unlock()
   160  	}
   161  	Tick(tick, interval)
   162  	// writing a test that functions correctly here without being susceptible
   163  	// to timing flakes is challenging. Using time.Sleep like this does have
   164  	// that downside, but the sleep time is many orders of magnitude higher
   165  	// than the tick intervals and the amount of time taken to execute the
   166  	// test as well, so it is going to be exceedingly rare that scheduling of
   167  	// the test process will cause a flake here from timing. The test cannot
   168  	// use synchronized approaches to waiting here as we do not know how long
   169  	// we must wait. The test must have enough time to ask for the interval
   170  	// as many times as we expect it to, but if we only wait for that we fail
   171  	// to catch the cases where the interval is requested too many times.
   172  	time.Sleep(100 * time.Millisecond)
   173  
   174  	var onInterruptCalled bool
   175  	OnInterrupt(func() {
   176  		lock.Lock()
   177  		onInterruptCalled = true
   178  		lock.Unlock()
   179  	})
   180  
   181  	done := sync.WaitGroup{}
   182  	done.Add(1)
   183  	go func() {
   184  		WaitForGracefulShutdown()
   185  		time.Sleep(1 * time.Millisecond) // Ensure graceful shutdown channel closes
   186  		done.Done()
   187  	}()
   188  
   189  	if onInterruptCalled {
   190  		t.Error("work registered with OnInterrupt() was executed before interrupt")
   191  	}
   192  
   193  	// trigger the interrupt
   194  	interrupt <- syscall.Signal(1)
   195  	// wait for graceful shutdown to occur
   196  	done.Wait()
   197  
   198  	lock.Lock()
   199  	if !ctxDone {
   200  		t.Error("context from Context() was not cancelled on interrupt")
   201  	}
   202  	if !workDone {
   203  		t.Error("work registered with Run() was not executed")
   204  	}
   205  	if !workCancelled {
   206  		t.Error("work registered with Run() was not cancelled on interrupt")
   207  	}
   208  	if !serverCalled {
   209  		t.Error("server registered with ListenAndServe() was not serving")
   210  	}
   211  	if !serverCancelled {
   212  		t.Error("server registered with ListenAndServe() was not cancelled on interrupt")
   213  	}
   214  	if !tlsServerCalled {
   215  		t.Error("server registered with ListenAndServeTLS() was not serving")
   216  	}
   217  	if !tlsServerCancelled {
   218  		t.Error("server registered with ListenAndServeTLS() was not cancelled on interrupt")
   219  	}
   220  	if tickCalls != 2 {
   221  		t.Errorf("work registered with Tick() was called %d times, not %d; interval was requested %d times", tickCalls, 2, intervalCalls)
   222  	}
   223  	if !onInterruptCalled {
   224  		t.Error("work registered with OnInterrupt() was not executed on interrupt")
   225  	}
   226  	lock.Unlock()
   227  }
   228  
   229  func generateCerts(url string) (string, string, error) {
   230  	priv, err := rsa.GenerateKey(rand.Reader, 2048)
   231  	if err != nil {
   232  		return "", "", fmt.Errorf("failed to generate private key: %w", err)
   233  	}
   234  
   235  	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
   236  	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
   237  	if err != nil {
   238  		return "", "", fmt.Errorf("failed to generate serial number: %s", err)
   239  	}
   240  
   241  	template := x509.Certificate{
   242  		SerialNumber: serialNumber,
   243  		Subject: pkix.Name{
   244  			Organization: []string{"Acme Co"},
   245  		},
   246  		NotBefore: time.Now(),
   247  		NotAfter:  time.Now().Add(1 * time.Hour),
   248  
   249  		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
   250  		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
   251  		BasicConstraintsValid: true,
   252  
   253  		IPAddresses: []net.IP{net.ParseIP(url)},
   254  	}
   255  
   256  	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
   257  	if err != nil {
   258  		return "", "", fmt.Errorf("failed to create certificate: %s", err)
   259  	}
   260  
   261  	certOut, err := os.CreateTemp("", "cert.pem")
   262  	if err != nil {
   263  		return "", "", fmt.Errorf("failed to open cert.pem for writing: %s", err)
   264  	}
   265  	if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil {
   266  		return "", "", fmt.Errorf("failed to write data to cert.pem: %s", err)
   267  	}
   268  	if err := certOut.Close(); err != nil {
   269  		return "", "", fmt.Errorf("error closing cert.pem: %s", err)
   270  	}
   271  
   272  	keyOut, err := os.CreateTemp("", "key.pem")
   273  	if err != nil {
   274  		return "", "", fmt.Errorf("failed to open key.pem for writing: %w", err)
   275  	}
   276  	privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
   277  	if err != nil {
   278  		return "", "", fmt.Errorf("unable to marshal private key: %w", err)
   279  	}
   280  	if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil {
   281  		return "", "", fmt.Errorf("failed to write data to key.pem: %s", err)
   282  	}
   283  	if err := keyOut.Close(); err != nil {
   284  		return "", "", fmt.Errorf("error closing key.pem: %s", err)
   285  	}
   286  	if err := os.Chmod(keyOut.Name(), 0600); err != nil {
   287  		return "", "", fmt.Errorf("could not change permissions on key.pem: %w", err)
   288  	}
   289  	return certOut.Name(), keyOut.Name(), nil
   290  }