sigs.k8s.io/prow@v0.0.0-20240503223140-c5e374dc7eb1/pkg/interrupts/interrupts_test.go (about) 1 /* 2 Copyright 2019 The Kubernetes Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package interrupts 18 19 import ( 20 "context" 21 "crypto/rand" 22 "crypto/rsa" 23 "crypto/tls" 24 "crypto/x509" 25 "crypto/x509/pkix" 26 "encoding/pem" 27 "fmt" 28 "math/big" 29 "net" 30 "net/http" 31 "os" 32 "sync" 33 "syscall" 34 "testing" 35 "time" 36 ) 37 38 // interrupt allows for tests to trigger an interrupt as needed 39 var interrupt = make(chan os.Signal, 1) 40 41 // this init will be executed before that in the code package, 42 // so we can inject our implementation of the interrupt channel 43 func init() { 44 signalsLock.Lock() 45 gracePeriod = time.Second 46 signals = func() <-chan os.Signal { 47 return interrupt 48 } 49 signalsLock.Unlock() 50 } 51 52 // instead of building a mechanism to reset/re-initialize the interrupt 53 // manager which would only be used in testing, we write an integration 54 // test that only fires the mock interrupt once 55 func TestInterrupts(t *testing.T) { 56 // we need to lock around values used to test otherwise the test 57 // goroutine will race with the workers 58 lock := sync.Mutex{} 59 60 ctx := Context() 61 var ctxDone bool 62 go func() { 63 <-ctx.Done() 64 65 lock.Lock() 66 ctxDone = true 67 lock.Unlock() 68 }() 69 70 var workDone bool 71 var workCancelled bool 72 work := func(ctx context.Context) { 73 lock.Lock() 74 workDone = true 75 lock.Unlock() 76 77 <-ctx.Done() 78 79 lock.Lock() 80 workCancelled = true 81 lock.Unlock() 82 } 83 Run(work) 84 85 // we cannot use httptest mocks for the tests here as they expect 86 // to be started by the httptest package itself, not by a downstream 87 // caller like the interrupts library 88 var serverCalled bool 89 var serverCancelled bool 90 listener, err := net.Listen("tcp", "127.0.0.1:") 91 if err != nil { 92 t.Fatalf("could not listen on random port: %v", err) 93 } 94 if err := listener.Close(); err != nil { 95 t.Fatalf("could close listener: %v", err) 96 } 97 server := &http.Server{Addr: listener.Addr().String(), Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) { 98 lock.Lock() 99 serverCalled = true 100 lock.Unlock() 101 })} 102 server.RegisterOnShutdown(func() { 103 lock.Lock() 104 serverCancelled = true 105 lock.Unlock() 106 }) 107 ListenAndServe(server, time.Second) 108 // wait for the server to start 109 time.Sleep(100 * time.Millisecond) 110 if _, err := http.Get("http://" + listener.Addr().String()); err != nil { 111 t.Errorf("could not reach server registered with ListenAndServe(): %v", err) 112 } 113 114 var tlsServerCalled bool 115 var tlsServerCancelled bool 116 tlsListener, err := net.Listen("tcp", "127.0.0.1:") 117 if err != nil { 118 t.Fatalf("could not listen on random port: %v", err) 119 } 120 if err := tlsListener.Close(); err != nil { 121 t.Fatalf("could close listener: %v", err) 122 } 123 tlsServer := &http.Server{Addr: tlsListener.Addr().String(), Handler: http.HandlerFunc(func(http.ResponseWriter, *http.Request) { 124 lock.Lock() 125 tlsServerCalled = true 126 lock.Unlock() 127 })} 128 tlsServer.RegisterOnShutdown(func() { 129 lock.Lock() 130 tlsServerCancelled = true 131 lock.Unlock() 132 }) 133 cert, key, err := generateCerts("127.0.0.1") 134 if err != nil { 135 t.Fatalf("could not generate cert and key for TLS server: %v", err) 136 } 137 ListenAndServeTLS(tlsServer, cert, key, time.Second) 138 client := &http.Client{Transport: &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}} 139 // wait for the server to start 140 time.Sleep(100 * time.Millisecond) 141 if _, err := client.Get("https://" + tlsListener.Addr().String()); err != nil { 142 t.Errorf("could not reach server registered with ListenAndServeTLS(): %v", err) 143 } 144 145 var intervalCalls int 146 interval := func() time.Duration { 147 lock.Lock() 148 intervalCalls++ 149 lock.Unlock() 150 if intervalCalls > 2 { 151 return 10 * time.Hour 152 } 153 return 1 * time.Nanosecond 154 } 155 var tickCalls int 156 tick := func() { 157 lock.Lock() 158 tickCalls++ 159 lock.Unlock() 160 } 161 Tick(tick, interval) 162 // writing a test that functions correctly here without being susceptible 163 // to timing flakes is challenging. Using time.Sleep like this does have 164 // that downside, but the sleep time is many orders of magnitude higher 165 // than the tick intervals and the amount of time taken to execute the 166 // test as well, so it is going to be exceedingly rare that scheduling of 167 // the test process will cause a flake here from timing. The test cannot 168 // use synchronized approaches to waiting here as we do not know how long 169 // we must wait. The test must have enough time to ask for the interval 170 // as many times as we expect it to, but if we only wait for that we fail 171 // to catch the cases where the interval is requested too many times. 172 time.Sleep(100 * time.Millisecond) 173 174 var onInterruptCalled bool 175 OnInterrupt(func() { 176 lock.Lock() 177 onInterruptCalled = true 178 lock.Unlock() 179 }) 180 181 done := sync.WaitGroup{} 182 done.Add(1) 183 go func() { 184 WaitForGracefulShutdown() 185 time.Sleep(1 * time.Millisecond) // Ensure graceful shutdown channel closes 186 done.Done() 187 }() 188 189 if onInterruptCalled { 190 t.Error("work registered with OnInterrupt() was executed before interrupt") 191 } 192 193 // trigger the interrupt 194 interrupt <- syscall.Signal(1) 195 // wait for graceful shutdown to occur 196 done.Wait() 197 198 lock.Lock() 199 if !ctxDone { 200 t.Error("context from Context() was not cancelled on interrupt") 201 } 202 if !workDone { 203 t.Error("work registered with Run() was not executed") 204 } 205 if !workCancelled { 206 t.Error("work registered with Run() was not cancelled on interrupt") 207 } 208 if !serverCalled { 209 t.Error("server registered with ListenAndServe() was not serving") 210 } 211 if !serverCancelled { 212 t.Error("server registered with ListenAndServe() was not cancelled on interrupt") 213 } 214 if !tlsServerCalled { 215 t.Error("server registered with ListenAndServeTLS() was not serving") 216 } 217 if !tlsServerCancelled { 218 t.Error("server registered with ListenAndServeTLS() was not cancelled on interrupt") 219 } 220 if tickCalls != 2 { 221 t.Errorf("work registered with Tick() was called %d times, not %d; interval was requested %d times", tickCalls, 2, intervalCalls) 222 } 223 if !onInterruptCalled { 224 t.Error("work registered with OnInterrupt() was not executed on interrupt") 225 } 226 lock.Unlock() 227 } 228 229 func generateCerts(url string) (string, string, error) { 230 priv, err := rsa.GenerateKey(rand.Reader, 2048) 231 if err != nil { 232 return "", "", fmt.Errorf("failed to generate private key: %w", err) 233 } 234 235 serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) 236 serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) 237 if err != nil { 238 return "", "", fmt.Errorf("failed to generate serial number: %s", err) 239 } 240 241 template := x509.Certificate{ 242 SerialNumber: serialNumber, 243 Subject: pkix.Name{ 244 Organization: []string{"Acme Co"}, 245 }, 246 NotBefore: time.Now(), 247 NotAfter: time.Now().Add(1 * time.Hour), 248 249 KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, 250 ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 251 BasicConstraintsValid: true, 252 253 IPAddresses: []net.IP{net.ParseIP(url)}, 254 } 255 256 derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) 257 if err != nil { 258 return "", "", fmt.Errorf("failed to create certificate: %s", err) 259 } 260 261 certOut, err := os.CreateTemp("", "cert.pem") 262 if err != nil { 263 return "", "", fmt.Errorf("failed to open cert.pem for writing: %s", err) 264 } 265 if err := pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: derBytes}); err != nil { 266 return "", "", fmt.Errorf("failed to write data to cert.pem: %s", err) 267 } 268 if err := certOut.Close(); err != nil { 269 return "", "", fmt.Errorf("error closing cert.pem: %s", err) 270 } 271 272 keyOut, err := os.CreateTemp("", "key.pem") 273 if err != nil { 274 return "", "", fmt.Errorf("failed to open key.pem for writing: %w", err) 275 } 276 privBytes, err := x509.MarshalPKCS8PrivateKey(priv) 277 if err != nil { 278 return "", "", fmt.Errorf("unable to marshal private key: %w", err) 279 } 280 if err := pem.Encode(keyOut, &pem.Block{Type: "PRIVATE KEY", Bytes: privBytes}); err != nil { 281 return "", "", fmt.Errorf("failed to write data to key.pem: %s", err) 282 } 283 if err := keyOut.Close(); err != nil { 284 return "", "", fmt.Errorf("error closing key.pem: %s", err) 285 } 286 if err := os.Chmod(keyOut.Name(), 0600); err != nil { 287 return "", "", fmt.Errorf("could not change permissions on key.pem: %w", err) 288 } 289 return certOut.Name(), keyOut.Name(), nil 290 }