github.com/freiheit-com/kuberpult@v1.24.2-0.20240328135542-315d5630abe6/pkg/setup/setup_test.go (about)

     1  /*This file is part of kuberpult.
     2  
     3  Kuberpult is free software: you can redistribute it and/or modify
     4  it under the terms of the Expat(MIT) License as published by
     5  the Free Software Foundation.
     6  
     7  Kuberpult is distributed in the hope that it will be useful,
     8  but WITHOUT ANY WARRANTY; without even the implied warranty of
     9  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    10  MIT License for more details.
    11  
    12  You should have received a copy of the MIT License
    13  along with kuberpult. If not, see <https://directory.fsf.org/wiki/License:Expat>.
    14  
    15  Copyright 2023 freiheit.com*/
    16  
    17  package setup
    18  
    19  import (
    20  	"context"
    21  	"fmt"
    22  	"io"
    23  	"net/http"
    24  	"net/http/httptest"
    25  	"os"
    26  	"syscall"
    27  	"testing"
    28  	"time"
    29  
    30  	"github.com/freiheit-com/kuberpult/pkg/metrics"
    31  	"github.com/google/go-cmp/cmp"
    32  )
    33  
    34  func TestBasicAuthHandler(t *testing.T) {
    35  	tcs := []struct {
    36  		desc            string
    37  		basicAuthServer *BasicAuth
    38  		requestUser     string
    39  		requestPassword string
    40  
    41  		expectedResponseCode     int
    42  		expectedChainHandlerCall bool
    43  	}{
    44  		{
    45  			desc:                     "returns 401 on wrong auth, wrong username",
    46  			basicAuthServer:          &BasicAuth{"test", "666"},
    47  			requestUser:              "foo",
    48  			requestPassword:          "666",
    49  			expectedResponseCode:     401,
    50  			expectedChainHandlerCall: false,
    51  		},
    52  		{
    53  			desc:                     "returns 401 on wrong auth, wrong password",
    54  			basicAuthServer:          &BasicAuth{"test", "666"},
    55  			requestUser:              "test",
    56  			requestPassword:          "888",
    57  			expectedResponseCode:     401,
    58  			expectedChainHandlerCall: false,
    59  		},
    60  		{
    61  			desc:                     "passes request true, if auth ok",
    62  			basicAuthServer:          &BasicAuth{"test", "666"},
    63  			requestUser:              "test",
    64  			requestPassword:          "666",
    65  			expectedResponseCode:     200,
    66  			expectedChainHandlerCall: true,
    67  		},
    68  	}
    69  
    70  	for _, tc := range tcs {
    71  		t.Run(tc.desc, func(t *testing.T) {
    72  			testChainHandler := &testChainHandler{}
    73  
    74  			testRequest := httptest.NewRequest("GET", "http://example.com/", nil)
    75  			testRequest.SetBasicAuth(tc.requestUser, tc.requestPassword)
    76  
    77  			testResponse := httptest.NewRecorder()
    78  
    79  			handler := NewBasicAuthHandler(tc.basicAuthServer, testChainHandler)
    80  			handler.ServeHTTP(testResponse, testRequest)
    81  
    82  			if tc.expectedChainHandlerCall != testChainHandler.called {
    83  				t.Errorf("expectedChainHandlerCall %t, got %t", tc.expectedChainHandlerCall, testChainHandler.called)
    84  			}
    85  			if tc.expectedResponseCode != testResponse.Code {
    86  				t.Errorf("expectedResponseCode %d, got %d", tc.expectedResponseCode, testResponse.Code)
    87  			}
    88  		})
    89  	}
    90  }
    91  
    92  func TestGracefulShutdown(t *testing.T) {
    93  	tcs := []struct {
    94  		desc   string
    95  		port   string
    96  		termFn func()
    97  		cancel bool
    98  	}{
    99  		{
   100  			desc: "Cleans up on shutdown triggered by the OS",
   101  			port: "8383",
   102  			termFn: func() {
   103  				osSignalChannel <- syscall.SIGTERM
   104  			},
   105  		},
   106  		{
   107  			desc:   "Cleans up on cancelled context",
   108  			port:   "8282",
   109  			cancel: true,
   110  		},
   111  	}
   112  
   113  	for _, tc := range tcs {
   114  		t.Run(tc.desc, func(t *testing.T) {
   115  
   116  			fakeServer := make(chan interface{}, 1)
   117  			backServeHTTP := serveHTTP
   118  			defer func() {
   119  				serveHTTP = backServeHTTP
   120  			}()
   121  			serveHTTP = func(ctx context.Context, httpS *http.Server, port string, cancel context.CancelFunc) {
   122  				for range fakeServer {
   123  				}
   124  			}
   125  
   126  			backShutdownHTTP := shutdownHTTP
   127  			defer func() {
   128  				shutdownHTTP = backShutdownHTTP
   129  			}()
   130  			shutdownHTTP = func(ctx context.Context, httpS *http.Server) error {
   131  				close(fakeServer)
   132  				return nil
   133  			}
   134  
   135  			backOsSignalChannel := osSignalChannel
   136  			osSignalChannel = make(chan os.Signal, 1)
   137  			defer func() {
   138  				osSignalChannel = backOsSignalChannel
   139  			}()
   140  
   141  			cleanShutdownCh := make(chan bool, 1)
   142  
   143  			cfg := ServerConfig{
   144  				HTTP: []HTTPConfig{
   145  					{
   146  						Port:      tc.port,
   147  						Register:  func(*http.ServeMux) {},
   148  						BasicAuth: nil,
   149  						Shutdown: func(ctx context.Context) error {
   150  							<-time.After(200 * time.Millisecond) // Releasing resources (time consuming task)
   151  							cleanShutdownCh <- true
   152  							return nil
   153  						},
   154  					},
   155  				},
   156  			}
   157  
   158  			mainExited := make(chan bool, 1)
   159  			ctx, cancel := context.WithCancel(context.Background())
   160  			go func() {
   161  				Run(ctx, cfg)
   162  				mainExited <- true
   163  			}()
   164  			if tc.cancel {
   165  				cancel()
   166  			}
   167  			if tc.termFn != nil {
   168  				tc.termFn()
   169  			}
   170  			select {
   171  			case <-mainExited:
   172  				t.Errorf("Main goroutine finished before resource cleanup")
   173  			case <-time.After(10 * time.Second):
   174  				t.Errorf("Program didn't finish on shutdown signal")
   175  			case <-cleanShutdownCh: // That's what we expect
   176  			}
   177  		})
   178  	}
   179  
   180  }
   181  
   182  func TestMetrics(t *testing.T) {
   183  	tcs := []struct {
   184  		desc string
   185  		port string
   186  	}{
   187  		{
   188  			desc: "registers metrics server automatically",
   189  			port: "8384",
   190  		},
   191  	}
   192  
   193  	for _, tc := range tcs {
   194  		t.Run(tc.desc, func(t *testing.T) {
   195  			metricAdded := make(chan struct{})
   196  			cfg := ServerConfig{
   197  				HTTP: []HTTPConfig{
   198  					{
   199  						Port:     tc.port,
   200  						Register: func(*http.ServeMux) {},
   201  					},
   202  				},
   203  				Background: []BackgroundTaskConfig{
   204  					{
   205  						Name: "something",
   206  						Run: func(ctx context.Context, hr *HealthReporter) error {
   207  							pv := metrics.FromContext(ctx)
   208  							counter, _ := pv.Meter("something").Int64Counter("something")
   209  							counter.Add(ctx, 1)
   210  							metricAdded <- struct{}{}
   211  							<-ctx.Done()
   212  							return nil
   213  						},
   214  					},
   215  				},
   216  			}
   217  
   218  			mainExited := make(chan bool, 1)
   219  			ctx, cancel := context.WithCancel(context.Background())
   220  			go func() {
   221  				Run(ctx, cfg)
   222  				mainExited <- true
   223  			}()
   224  			<-metricAdded
   225  			var response *http.Response
   226  			for i := 0; i < 10; i = i + 1 {
   227  				res, err := http.Get(fmt.Sprintf("http://localhost:%s/metrics", tc.port))
   228  				if err != nil {
   229  					if i == 9 {
   230  						t.Errorf("error getting metrics: %s", err)
   231  					}
   232  					continue
   233  				}
   234  				response = res
   235  				time.After(time.Second)
   236  			}
   237  			body, _ := io.ReadAll(response.Body)
   238  			expectedBody := `# HELP background_job_ready 
   239  # TYPE background_job_ready gauge
   240  background_job_ready{name="something"} 0
   241  # HELP something_total 
   242  # TYPE something_total counter
   243  something_total 1
   244  `
   245  			if string(body) != expectedBody {
   246  				t.Errorf("got wrong metric body, diff %s", cmp.Diff(string(body), expectedBody))
   247  			}
   248  			cancel()
   249  			<-mainExited
   250  		})
   251  	}
   252  }
   253  
   254  //helper
   255  
   256  type testChainHandler struct {
   257  	called bool
   258  }
   259  
   260  func (h *testChainHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
   261  	h.called = true
   262  }