github.com/Tyktechnologies/tyk@v2.9.5+incompatible/gateway/proxy_muxer_test.go (about)

     1  package gateway
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"io/ioutil"
     7  	"net"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"reflect"
    11  	"strconv"
    12  	"sync/atomic"
    13  	"testing"
    14  
    15  	"github.com/TykTechnologies/tyk/config"
    16  )
    17  
    18  func TestTCPDial_with_service_discovery(t *testing.T) {
    19  	service1, err := net.Listen("tcp", "127.0.0.1:0")
    20  	if err != nil {
    21  		t.Fatal(err)
    22  	}
    23  	defer service1.Close()
    24  	msg := "whois"
    25  	go func() {
    26  		for {
    27  			ls, err := service1.Accept()
    28  			if err != nil {
    29  				break
    30  			}
    31  			buf := make([]byte, len(msg))
    32  			_, err = ls.Read(buf)
    33  			if err != nil {
    34  				break
    35  			}
    36  			ls.Write([]byte("service1"))
    37  		}
    38  	}()
    39  	service2, err := net.Listen("tcp", "127.0.0.1:0")
    40  	if err != nil {
    41  		t.Fatal(err)
    42  	}
    43  	defer service1.Close()
    44  	go func() {
    45  		for {
    46  			ls, err := service2.Accept()
    47  			if err != nil {
    48  				break
    49  			}
    50  			buf := make([]byte, len(msg))
    51  			_, err = ls.Read(buf)
    52  			if err != nil {
    53  				break
    54  			}
    55  			ls.Write([]byte("service2"))
    56  		}
    57  	}()
    58  	var active atomic.Value
    59  	active.Store(0)
    60  	sds := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    61  		list := []string{
    62  			"tcp://" + service1.Addr().String(),
    63  			"tcp://" + service2.Addr().String(),
    64  		}
    65  		idx := active.Load().(int)
    66  		if idx == 0 {
    67  			idx = 1
    68  		} else {
    69  			idx = 0
    70  		}
    71  		active.Store(idx)
    72  		json.NewEncoder(w).Encode([]interface{}{
    73  			map[string]string{
    74  				"hostname": list[idx],
    75  			},
    76  		})
    77  	}))
    78  	defer sds.Close()
    79  	ts := StartTest()
    80  	defer ts.Close()
    81  	rp, err := net.Listen("tcp", "127.0.0.1:0")
    82  	if err != nil {
    83  		t.Fatal(err)
    84  	}
    85  	_, port, err := net.SplitHostPort(rp.Addr().String())
    86  	if err != nil {
    87  		t.Fatal(err)
    88  	}
    89  	p, err := strconv.Atoi(port)
    90  	if err != nil {
    91  		t.Fatal(err)
    92  	}
    93  	EnablePort(p, "tcp")
    94  	defer ResetTestConfig()
    95  	address := rp.Addr().String()
    96  	rp.Close()
    97  	BuildAndLoadAPI(func(spec *APISpec) {
    98  		spec.Proxy.ListenPath = "/"
    99  		spec.Protocol = "tcp"
   100  		spec.Proxy.ServiceDiscovery.UseDiscoveryService = true
   101  		spec.Proxy.ServiceDiscovery.EndpointReturnsList = true
   102  		spec.Proxy.ServiceDiscovery.QueryEndpoint = sds.URL
   103  		spec.Proxy.ServiceDiscovery.DataPath = "hostname"
   104  		spec.Proxy.EnableLoadBalancing = true
   105  		spec.ListenPort = p
   106  		spec.Proxy.TargetURL = service1.Addr().String()
   107  	})
   108  
   109  	e := "service1"
   110  	var result []string
   111  
   112  	dial := func() string {
   113  		l, err := net.Dial("tcp", address)
   114  		if err != nil {
   115  			t.Fatal(err)
   116  		}
   117  		defer l.Close()
   118  		_, err = l.Write([]byte("whois"))
   119  		if err != nil {
   120  			t.Fatal(err)
   121  		}
   122  		buf := make([]byte, len(e))
   123  		_, err = l.Read(buf)
   124  		if err != nil {
   125  			t.Fatal(err)
   126  		}
   127  		return string(buf)
   128  	}
   129  	for i := 0; i < 4; i++ {
   130  		if ServiceCache != nil {
   131  			ServiceCache.Flush()
   132  		}
   133  		result = append(result, dial())
   134  	}
   135  	expect := []string{"service2", "service1", "service2", "service1"}
   136  	if !reflect.DeepEqual(result, expect) {
   137  		t.Errorf("expected %#v got %#v", expect, result)
   138  	}
   139  }
   140  
   141  func TestTCP_missing_port(t *testing.T) {
   142  	ts := StartTest()
   143  	defer ts.Close()
   144  	BuildAndLoadAPI(func(spec *APISpec) {
   145  		spec.Name = "no -listen-port"
   146  		spec.Protocol = "tcp"
   147  	})
   148  	apisMu.RLock()
   149  	n := len(apiSpecs)
   150  	apisMu.RUnlock()
   151  	if n != 0 {
   152  		t.Errorf("expected 0 apis to be loaded got %d", n)
   153  	}
   154  }
   155  
   156  // getUnusedPort returns a tcp port that is a vailable for binding.
   157  func getUnusedPort() (int, error) {
   158  	rp, err := net.Listen("tcp", "127.0.0.1:0")
   159  	if err != nil {
   160  		return 0, err
   161  	}
   162  	defer rp.Close()
   163  	_, port, err := net.SplitHostPort(rp.Addr().String())
   164  	if err != nil {
   165  		return 0, err
   166  	}
   167  	p, err := strconv.Atoi(port)
   168  	if err != nil {
   169  		return 0, err
   170  	}
   171  	return p, nil
   172  }
   173  
   174  func TestCheckPortWhiteList(t *testing.T) {
   175  	base := config.Global()
   176  	cases := []struct {
   177  		name     string
   178  		protocol string
   179  		port     int
   180  		fail     bool
   181  		wls      map[string]config.PortWhiteList
   182  	}{
   183  		{"gw port empty protocol", "", base.ListenPort, true, nil},
   184  		{"gw port http protocol", "http", base.ListenPort, false, map[string]config.PortWhiteList{
   185  			"http": {
   186  				Ports: []int{base.ListenPort},
   187  			},
   188  		}},
   189  		{"unknown tls", "tls", base.ListenPort, true, nil},
   190  		{"unknown tcp", "tls", base.ListenPort, true, nil},
   191  		{"whitelisted tcp", "tcp", base.ListenPort, false, map[string]config.PortWhiteList{
   192  			"tcp": {
   193  				Ports: []int{base.ListenPort},
   194  			},
   195  		}},
   196  		{"whitelisted tls", "tls", base.ListenPort, false, map[string]config.PortWhiteList{
   197  			"tls": {
   198  				Ports: []int{base.ListenPort},
   199  			},
   200  		}},
   201  		{"black listed tcp", "tcp", base.ListenPort, true, map[string]config.PortWhiteList{
   202  			"tls": {
   203  				Ports: []int{base.ListenPort},
   204  			},
   205  		}},
   206  		{"blacklisted tls", "tls", base.ListenPort, true, map[string]config.PortWhiteList{
   207  			"tcp": {
   208  				Ports: []int{base.ListenPort},
   209  			},
   210  		}},
   211  		{"whitelisted tls range", "tls", base.ListenPort, false, map[string]config.PortWhiteList{
   212  			"tls": {
   213  				Ranges: []config.PortRange{
   214  					{
   215  						From: base.ListenPort - 1,
   216  						To:   base.ListenPort + 1,
   217  					},
   218  				},
   219  			},
   220  		}},
   221  		{"whitelisted tcp range", "tcp", base.ListenPort, false, map[string]config.PortWhiteList{
   222  			"tcp": {
   223  				Ranges: []config.PortRange{
   224  					{
   225  						From: base.ListenPort - 1,
   226  						To:   base.ListenPort + 1,
   227  					},
   228  				},
   229  			},
   230  		}},
   231  		{"whitelisted http range", "http", 8090, false, map[string]config.PortWhiteList{
   232  			"http": {
   233  				Ranges: []config.PortRange{
   234  					{
   235  						From: 8000,
   236  						To:   9000,
   237  					},
   238  				},
   239  			},
   240  		}},
   241  	}
   242  	for i, tt := range cases {
   243  		t.Run(tt.name, func(ts *testing.T) {
   244  			err := CheckPortWhiteList(tt.wls, tt.port, tt.protocol)
   245  			if tt.fail {
   246  				if err == nil {
   247  					ts.Error("expected an error got nil")
   248  				}
   249  			} else {
   250  				if err != nil {
   251  					ts.Errorf("%d: expected an nil got %v", i, err)
   252  				}
   253  			}
   254  		})
   255  	}
   256  }
   257  
   258  func TestHTTP_custom_ports(t *testing.T) {
   259  	ts := StartTest()
   260  	defer ts.Close()
   261  	echo := "Hello, world"
   262  	us := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   263  		w.Write([]byte(echo))
   264  	}))
   265  	defer us.Close()
   266  	port, err := getUnusedPort()
   267  	if err != nil {
   268  		t.Fatal(err)
   269  	}
   270  	EnablePort(port, "http")
   271  	BuildAndLoadAPI(func(spec *APISpec) {
   272  		spec.Proxy.ListenPath = "/"
   273  		spec.Protocol = "http"
   274  		spec.ListenPort = port
   275  		spec.Proxy.TargetURL = us.URL
   276  	})
   277  	s := fmt.Sprintf("http://localhost:%d", port)
   278  	w, err := http.Get(s)
   279  	if err != nil {
   280  		t.Fatal(err)
   281  	}
   282  	defer w.Body.Close()
   283  	b, err := ioutil.ReadAll(w.Body)
   284  	if err != nil {
   285  		t.Fatal(err)
   286  	}
   287  	bs := string(b)
   288  	if bs != echo {
   289  		t.Errorf("expected %s to %s", echo, bs)
   290  	}
   291  }