github.com/Tyktechnologies/tyk@v2.9.5+incompatible/gateway/host_checker_test.go (about)

     1  package gateway
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"net"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"net/url"
    10  	"sync"
    11  	"testing"
    12  	"text/template"
    13  	"time"
    14  
    15  	"github.com/TykTechnologies/tyk/apidef"
    16  	"github.com/TykTechnologies/tyk/config"
    17  	"github.com/TykTechnologies/tyk/storage"
    18  	proxyproto "github.com/pires/go-proxyproto"
    19  )
    20  
// sampleUptimeTestAPI is a text/template source for a keyless API definition
// used by the uptime tests in this file. The {{.Host1}} and {{.Host2}}
// placeholders are substituted into both the uptime check list and the
// load-balanced proxy target list.
const sampleUptimeTestAPI = `{
	"api_id": "test",
	"use_keyless": true,
	"version_data": {
		"not_versioned": true,
		"versions": {
			"v1": {"name": "v1"}
		}
	},
	"uptime_tests": {
		"check_list": [
			{
				"url": "{{.Host1}}/get",
				"method": "GET"
			},
			{
				"url": "{{.Host2}}/get",
				"method": "GET"
			}
		]
	},
	"proxy": {
		"listen_path": "/",
		"enable_load_balancing": true,
		"check_host_against_uptime_tests": true,
		"target_list": [
			"{{.Host1}}",
			"{{.Host2}}"
		]
	}
}`
    52  
    53  type testEventHandler struct {
    54  	cb func(config.EventMessage)
    55  }
    56  
    57  func (w *testEventHandler) Init(handlerConf interface{}) error {
    58  	return nil
    59  }
    60  
    61  func (w *testEventHandler) HandleEvent(em config.EventMessage) {
    62  	w.cb(em)
    63  }
    64  
    65  func TestHostChecker(t *testing.T) {
    66  	specTmpl := template.Must(template.New("spec").Parse(sampleUptimeTestAPI))
    67  
    68  	tmplData := struct {
    69  		Host1, Host2 string
    70  	}{
    71  		TestHttpAny,
    72  		testHttpFailureAny,
    73  	}
    74  
    75  	specBuf := &bytes.Buffer{}
    76  	specTmpl.ExecuteTemplate(specBuf, specTmpl.Name(), &tmplData)
    77  
    78  	spec := CreateDefinitionFromString(specBuf.String())
    79  
    80  	// From api_loader.go#processSpec
    81  	sl := apidef.NewHostListFromList(spec.Proxy.Targets)
    82  	spec.Proxy.StructuredTargetList = sl
    83  
    84  	var eventWG sync.WaitGroup
    85  	// Should receive one HostDown event
    86  	eventWG.Add(1)
    87  	cb := func(em config.EventMessage) {
    88  		eventWG.Done()
    89  	}
    90  
    91  	spec.EventPaths = map[apidef.TykEvent][]config.TykEventHandler{
    92  		"HostDown": {&testEventHandler{cb}},
    93  	}
    94  
    95  	apisMu.Lock()
    96  	apisByID = map[string]*APISpec{spec.APIID: spec}
    97  	apisMu.Unlock()
    98  	GlobalHostChecker.checkerMu.Lock()
    99  	GlobalHostChecker.checker.sampleTriggerLimit = 1
   100  	GlobalHostChecker.checkerMu.Unlock()
   101  	defer func() {
   102  		apisMu.Lock()
   103  		apisByID = make(map[string]*APISpec)
   104  		apisMu.Unlock()
   105  		GlobalHostChecker.checkerMu.Lock()
   106  		GlobalHostChecker.checker.sampleTriggerLimit = defaultSampletTriggerLimit
   107  		GlobalHostChecker.checkerMu.Unlock()
   108  	}()
   109  
   110  	SetCheckerHostList()
   111  	GlobalHostChecker.checkerMu.Lock()
   112  	if len(GlobalHostChecker.currentHostList) != 2 {
   113  		t.Error("Should update hosts manager check list", GlobalHostChecker.currentHostList)
   114  	}
   115  
   116  	if len(GlobalHostChecker.checker.newList) != 2 {
   117  		t.Error("Should update host checker check list")
   118  	}
   119  	GlobalHostChecker.checkerMu.Unlock()
   120  
   121  	hostCheckTicker <- struct{}{}
   122  	eventWG.Wait()
   123  
   124  	if GlobalHostChecker.HostDown(TestHttpAny) {
   125  		t.Error("Should not mark as down")
   126  	}
   127  
   128  	if !GlobalHostChecker.HostDown(testHttpFailureAny) {
   129  		t.Error("Should mark as down")
   130  	}
   131  
   132  	// Test it many times concurrently, to simulate concurrent and
   133  	// parallel requests to the API. This will catch bugs in those
   134  	// scenarios, like data races.
   135  	var targetWG sync.WaitGroup
   136  	for i := 0; i < 10; i++ {
   137  		targetWG.Add(1)
   138  		go func() {
   139  			host, err := nextTarget(spec.Proxy.StructuredTargetList, spec)
   140  			if err != nil {
   141  				t.Error("Should return nil error, got", err)
   142  			}
   143  			if host != TestHttpAny {
   144  				t.Error("Should return only active host, got", host)
   145  			}
   146  			targetWG.Done()
   147  		}()
   148  	}
   149  	targetWG.Wait()
   150  
   151  	GlobalHostChecker.checkerMu.Lock()
   152  	if GlobalHostChecker.checker.checkTimeout != defaultTimeout {
   153  		t.Error("Should set defaults", GlobalHostChecker.checker.checkTimeout)
   154  	}
   155  
   156  	redisStore := GlobalHostChecker.store.(*storage.RedisCluster)
   157  	if ttl, _ := redisStore.GetKeyTTL(PoolerHostSentinelKeyPrefix + testHttpFailure); int(ttl) != GlobalHostChecker.checker.checkTimeout*GlobalHostChecker.checker.sampleTriggerLimit {
   158  		t.Error("HostDown expiration key should be checkTimeout + 1", ttl)
   159  	}
   160  	GlobalHostChecker.checkerMu.Unlock()
   161  }
   162  
   163  func TestReverseProxyAllDown(t *testing.T) {
   164  	specTmpl := template.Must(template.New("spec").Parse(sampleUptimeTestAPI))
   165  
   166  	tmplData := struct {
   167  		Host1, Host2 string
   168  	}{
   169  		testHttpFailureAny,
   170  		testHttpFailureAny,
   171  	}
   172  
   173  	specBuf := &bytes.Buffer{}
   174  	specTmpl.ExecuteTemplate(specBuf, specTmpl.Name(), &tmplData)
   175  
   176  	spec := CreateDefinitionFromString(specBuf.String())
   177  
   178  	// From api_loader.go#processSpec
   179  	sl := apidef.NewHostListFromList(spec.Proxy.Targets)
   180  	spec.Proxy.StructuredTargetList = sl
   181  
   182  	var eventWG sync.WaitGroup
   183  	// Should receive one HostDown event
   184  	eventWG.Add(1)
   185  	cb := func(em config.EventMessage) {
   186  		eventWG.Done()
   187  	}
   188  	spec.EventPaths = map[apidef.TykEvent][]config.TykEventHandler{
   189  		"HostDown": {&testEventHandler{cb}},
   190  	}
   191  
   192  	apisMu.Lock()
   193  	apisByID = map[string]*APISpec{spec.APIID: spec}
   194  	apisMu.Unlock()
   195  	GlobalHostChecker.checkerMu.Lock()
   196  	GlobalHostChecker.checker.sampleTriggerLimit = 1
   197  	GlobalHostChecker.checkerMu.Unlock()
   198  	defer func() {
   199  		apisMu.Lock()
   200  		apisByID = make(map[string]*APISpec)
   201  		apisMu.Unlock()
   202  		GlobalHostChecker.checkerMu.Lock()
   203  		GlobalHostChecker.checker.sampleTriggerLimit = defaultSampletTriggerLimit
   204  		GlobalHostChecker.checkerMu.Unlock()
   205  	}()
   206  
   207  	SetCheckerHostList()
   208  
   209  	hostCheckTicker <- struct{}{}
   210  	eventWG.Wait()
   211  
   212  	remote, _ := url.Parse(TestHttpAny)
   213  	proxy := TykNewSingleHostReverseProxy(remote, spec, nil)
   214  
   215  	req := TestReq(t, "GET", "/", nil)
   216  	rec := httptest.NewRecorder()
   217  	proxy.ServeHTTP(rec, req)
   218  	if rec.Code != 503 {
   219  		t.Fatalf("wanted code to be 503, was %d", rec.Code)
   220  	}
   221  }
   222  
   223  type answers struct {
   224  	mu             sync.RWMutex
   225  	ping, fail, up bool
   226  	cancel         func()
   227  }
   228  
   229  func (a *answers) onFail(_ HostHealthReport) {
   230  	defer a.cancel()
   231  	a.mu.Lock()
   232  	a.fail = true
   233  	a.mu.Unlock()
   234  }
   235  
   236  func (a *answers) onPing(_ HostHealthReport) {
   237  	defer a.cancel()
   238  	a.mu.Lock()
   239  	a.ping = true
   240  	a.mu.Unlock()
   241  }
   242  func (a *answers) onUp(_ HostHealthReport) {
   243  	defer a.cancel()
   244  	a.mu.Lock()
   245  	a.up = true
   246  	a.mu.Unlock()
   247  }
   248  
   249  func TestTestCheckerTCPHosts_correct_answers(t *testing.T) {
   250  	l, err := net.Listen("tcp", "127.0.0.1:0")
   251  	if err != nil {
   252  		t.Fatal(err)
   253  	}
   254  	defer l.Close()
   255  	data := HostData{
   256  		CheckURL: l.Addr().String(),
   257  		Protocol: "tcp",
   258  		Commands: []apidef.CheckCommand{
   259  			{
   260  				Name: "send", Message: "ping",
   261  			}, {
   262  				Name: "expect", Message: "pong",
   263  			},
   264  		},
   265  	}
   266  	go func(ls net.Listener) {
   267  		for {
   268  			s, err := ls.Accept()
   269  			if err != nil {
   270  				return
   271  			}
   272  			buf := make([]byte, 4)
   273  			_, err = s.Read(buf)
   274  			if err != nil {
   275  				return
   276  			}
   277  			if string(buf) == "ping" {
   278  				s.Write([]byte("pong"))
   279  			} else {
   280  				s.Write([]byte("unknown"))
   281  			}
   282  		}
   283  	}(l)
   284  	ctx, cancel := context.WithCancel(context.Background())
   285  	hs := &HostUptimeChecker{}
   286  	ans := &answers{cancel: cancel}
   287  	setTestMode(false)
   288  
   289  	hs.Init(1, 1, 0, map[string]HostData{
   290  		l.Addr().String(): data,
   291  	},
   292  		ans.onFail,
   293  		ans.onUp,
   294  		ans.onPing,
   295  	)
   296  	hs.sampleTriggerLimit = 1
   297  	go hs.Start()
   298  	<-ctx.Done()
   299  	hs.Stop()
   300  	setTestMode(true)
   301  	if !(ans.ping && !ans.fail && !ans.up) {
   302  		t.Errorf("expected the host to be up : field:%v up:%v pinged:%v", ans.fail, ans.up, ans.ping)
   303  	}
   304  }
   305  func TestTestCheckerTCPHosts_correct_answers_proxy_protocol(t *testing.T) {
   306  	l, err := net.Listen("tcp", "127.0.0.1:0")
   307  	if err != nil {
   308  		t.Fatal(err)
   309  	}
   310  	defer l.Close()
   311  	data := HostData{
   312  		CheckURL:            l.Addr().String(),
   313  		Protocol:            "tcp",
   314  		EnableProxyProtocol: true,
   315  		Commands: []apidef.CheckCommand{
   316  			{
   317  				Name: "send", Message: "ping",
   318  			}, {
   319  				Name: "expect", Message: "pong",
   320  			},
   321  		},
   322  	}
   323  	go func(ls net.Listener) {
   324  		ls = &proxyproto.Listener{Listener: ls}
   325  		for {
   326  			s, err := ls.Accept()
   327  			if err != nil {
   328  				return
   329  			}
   330  			buf := make([]byte, 4)
   331  			_, err = s.Read(buf)
   332  			if err != nil {
   333  				return
   334  			}
   335  			if string(buf) == "ping" {
   336  				s.Write([]byte("pong"))
   337  			} else {
   338  				s.Write([]byte("unknown"))
   339  			}
   340  		}
   341  	}(l)
   342  	ctx, cancel := context.WithCancel(context.Background())
   343  	hs := &HostUptimeChecker{}
   344  	ans := &answers{cancel: cancel}
   345  	setTestMode(false)
   346  
   347  	hs.Init(1, 1, 0, map[string]HostData{
   348  		l.Addr().String(): data,
   349  	},
   350  		ans.onFail,
   351  		ans.onUp,
   352  		ans.onPing,
   353  	)
   354  	hs.sampleTriggerLimit = 1
   355  	go hs.Start()
   356  	<-ctx.Done()
   357  	hs.Stop()
   358  	setTestMode(true)
   359  	if !(ans.ping && !ans.fail && !ans.up) {
   360  		t.Errorf("expected the host to be up : field:%v up:%v pinged:%v", ans.fail, ans.up, ans.ping)
   361  	}
   362  }
   363  
   364  func TestTestCheckerTCPHosts_correct_wrong_answers(t *testing.T) {
   365  	l, err := net.Listen("tcp", "127.0.0.1:0")
   366  	if err != nil {
   367  		t.Fatal(err)
   368  	}
   369  	defer l.Close()
   370  	data := HostData{
   371  		CheckURL: l.Addr().String(),
   372  		Protocol: "tcp",
   373  		Commands: []apidef.CheckCommand{
   374  			{
   375  				Name: "send", Message: "ping",
   376  			}, {
   377  				Name: "expect", Message: "pong",
   378  			},
   379  		},
   380  	}
   381  	go func(ls net.Listener) {
   382  		for {
   383  			s, err := ls.Accept()
   384  			if err != nil {
   385  				return
   386  			}
   387  			buf := make([]byte, 4)
   388  			_, err = s.Read(buf)
   389  			if err != nil {
   390  				return
   391  			}
   392  			s.Write([]byte("unknown"))
   393  		}
   394  	}(l)
   395  	ctx, cancel := context.WithCancel(context.Background())
   396  	hs := &HostUptimeChecker{}
   397  	failed := false
   398  	setTestMode(false)
   399  	hs.Init(1, 1, 0, map[string]HostData{
   400  		l.Addr().String(): data,
   401  	},
   402  		func(HostHealthReport) {
   403  			failed = true
   404  			cancel()
   405  		},
   406  		func(HostHealthReport) {},
   407  		func(HostHealthReport) {},
   408  	)
   409  	hs.sampleTriggerLimit = 1
   410  	go hs.Start()
   411  	<-ctx.Done()
   412  	hs.Stop()
   413  	setTestMode(true)
   414  	if !failed {
   415  		t.Error("expected the host check to fai")
   416  	}
   417  }
   418  
   419  func TestProxyWhenHostIsDown(t *testing.T) {
   420  
   421  	g := config.Global()
   422  	g.UptimeTests.Config.FailureTriggerSampleSize = 1
   423  	g.UptimeTests.Config.TimeWait = 5
   424  	g.UptimeTests.Config.EnableUptimeAnalytics = true
   425  	config.SetGlobal(g)
   426  
   427  	ts := StartTest()
   428  	defer ts.Close()
   429  
   430  	defer ResetTestConfig()
   431  	l := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   432  	}))
   433  	defer l.Close()
   434  	BuildAndLoadAPI(func(spec *APISpec) {
   435  		spec.Proxy.ListenPath = "/"
   436  		spec.Proxy.EnableLoadBalancing = true
   437  		spec.Proxy.Targets = []string{l.URL}
   438  		spec.Proxy.CheckHostAgainstUptimeTests = true
   439  		spec.UptimeTests.CheckList = []apidef.HostCheckObject{
   440  			{CheckURL: l.URL},
   441  		}
   442  	})
   443  	GlobalHostChecker.checkerMu.Lock()
   444  	GlobalHostChecker.checker.sampleTriggerLimit = 1
   445  	GlobalHostChecker.checkerMu.Unlock()
   446  
   447  	tick := time.NewTicker(time.Millisecond)
   448  	defer tick.Stop()
   449  	x := 0
   450  	get := func() {
   451  		x++
   452  		res, err := http.Get(ts.URL + "/")
   453  		if err == nil {
   454  			res.Body.Close()
   455  		}
   456  		code := http.StatusOK
   457  		if x > 2 {
   458  			code = http.StatusServiceUnavailable
   459  		}
   460  		if res.StatusCode != code {
   461  			t.Errorf("%d: expected %d got %d", x, code, res.StatusCode)
   462  		}
   463  	}
   464  	n := 0
   465  	sentSignal := false
   466  	for {
   467  		select {
   468  		case <-tick.C:
   469  			if sentSignal {
   470  				sentSignal = !sentSignal
   471  				continue
   472  			}
   473  			if n == 2 {
   474  				l.Close()
   475  				hostCheckTicker <- struct{}{}
   476  				n++
   477  				sentSignal = true
   478  				continue
   479  			}
   480  			n++
   481  			if n == 10 {
   482  				return
   483  			}
   484  			get()
   485  		}
   486  	}
   487  }