github.com/Tyktechnologies/tyk@v2.9.5+incompatible/gateway/proxy_muxer_test.go (about) 1 package gateway 2 3 import ( 4 "encoding/json" 5 "fmt" 6 "io/ioutil" 7 "net" 8 "net/http" 9 "net/http/httptest" 10 "reflect" 11 "strconv" 12 "sync/atomic" 13 "testing" 14 15 "github.com/TykTechnologies/tyk/config" 16 ) 17 18 func TestTCPDial_with_service_discovery(t *testing.T) { 19 service1, err := net.Listen("tcp", "127.0.0.1:0") 20 if err != nil { 21 t.Fatal(err) 22 } 23 defer service1.Close() 24 msg := "whois" 25 go func() { 26 for { 27 ls, err := service1.Accept() 28 if err != nil { 29 break 30 } 31 buf := make([]byte, len(msg)) 32 _, err = ls.Read(buf) 33 if err != nil { 34 break 35 } 36 ls.Write([]byte("service1")) 37 } 38 }() 39 service2, err := net.Listen("tcp", "127.0.0.1:0") 40 if err != nil { 41 t.Fatal(err) 42 } 43 defer service1.Close() 44 go func() { 45 for { 46 ls, err := service2.Accept() 47 if err != nil { 48 break 49 } 50 buf := make([]byte, len(msg)) 51 _, err = ls.Read(buf) 52 if err != nil { 53 break 54 } 55 ls.Write([]byte("service2")) 56 } 57 }() 58 var active atomic.Value 59 active.Store(0) 60 sds := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 61 list := []string{ 62 "tcp://" + service1.Addr().String(), 63 "tcp://" + service2.Addr().String(), 64 } 65 idx := active.Load().(int) 66 if idx == 0 { 67 idx = 1 68 } else { 69 idx = 0 70 } 71 active.Store(idx) 72 json.NewEncoder(w).Encode([]interface{}{ 73 map[string]string{ 74 "hostname": list[idx], 75 }, 76 }) 77 })) 78 defer sds.Close() 79 ts := StartTest() 80 defer ts.Close() 81 rp, err := net.Listen("tcp", "127.0.0.1:0") 82 if err != nil { 83 t.Fatal(err) 84 } 85 _, port, err := net.SplitHostPort(rp.Addr().String()) 86 if err != nil { 87 t.Fatal(err) 88 } 89 p, err := strconv.Atoi(port) 90 if err != nil { 91 t.Fatal(err) 92 } 93 EnablePort(p, "tcp") 94 defer ResetTestConfig() 95 address := rp.Addr().String() 96 rp.Close() 97 BuildAndLoadAPI(func(spec *APISpec) { 98 spec.Proxy.ListenPath = "/" 99 spec.Protocol = "tcp" 100 spec.Proxy.ServiceDiscovery.UseDiscoveryService = true 101 spec.Proxy.ServiceDiscovery.EndpointReturnsList = true 102 spec.Proxy.ServiceDiscovery.QueryEndpoint = sds.URL 103 spec.Proxy.ServiceDiscovery.DataPath = "hostname" 104 spec.Proxy.EnableLoadBalancing = true 105 spec.ListenPort = p 106 spec.Proxy.TargetURL = service1.Addr().String() 107 }) 108 109 e := "service1" 110 var result []string 111 112 dial := func() string { 113 l, err := net.Dial("tcp", address) 114 if err != nil { 115 t.Fatal(err) 116 } 117 defer l.Close() 118 _, err = l.Write([]byte("whois")) 119 if err != nil { 120 t.Fatal(err) 121 } 122 buf := make([]byte, len(e)) 123 _, err = l.Read(buf) 124 if err != nil { 125 t.Fatal(err) 126 } 127 return string(buf) 128 } 129 for i := 0; i < 4; i++ { 130 if ServiceCache != nil { 131 ServiceCache.Flush() 132 } 133 result = append(result, dial()) 134 } 135 expect := []string{"service2", "service1", "service2", "service1"} 136 if !reflect.DeepEqual(result, expect) { 137 t.Errorf("expected %#v got %#v", expect, result) 138 } 139 } 140 141 func TestTCP_missing_port(t *testing.T) { 142 ts := StartTest() 143 defer ts.Close() 144 BuildAndLoadAPI(func(spec *APISpec) { 145 spec.Name = "no -listen-port" 146 spec.Protocol = "tcp" 147 }) 148 apisMu.RLock() 149 n := len(apiSpecs) 150 apisMu.RUnlock() 151 if n != 0 { 152 t.Errorf("expected 0 apis to be loaded got %d", n) 153 } 154 } 155 156 // getUnusedPort returns a tcp port that is a vailable for binding. 157 func getUnusedPort() (int, error) { 158 rp, err := net.Listen("tcp", "127.0.0.1:0") 159 if err != nil { 160 return 0, err 161 } 162 defer rp.Close() 163 _, port, err := net.SplitHostPort(rp.Addr().String()) 164 if err != nil { 165 return 0, err 166 } 167 p, err := strconv.Atoi(port) 168 if err != nil { 169 return 0, err 170 } 171 return p, nil 172 } 173 174 func TestCheckPortWhiteList(t *testing.T) { 175 base := config.Global() 176 cases := []struct { 177 name string 178 protocol string 179 port int 180 fail bool 181 wls map[string]config.PortWhiteList 182 }{ 183 {"gw port empty protocol", "", base.ListenPort, true, nil}, 184 {"gw port http protocol", "http", base.ListenPort, false, map[string]config.PortWhiteList{ 185 "http": { 186 Ports: []int{base.ListenPort}, 187 }, 188 }}, 189 {"unknown tls", "tls", base.ListenPort, true, nil}, 190 {"unknown tcp", "tls", base.ListenPort, true, nil}, 191 {"whitelisted tcp", "tcp", base.ListenPort, false, map[string]config.PortWhiteList{ 192 "tcp": { 193 Ports: []int{base.ListenPort}, 194 }, 195 }}, 196 {"whitelisted tls", "tls", base.ListenPort, false, map[string]config.PortWhiteList{ 197 "tls": { 198 Ports: []int{base.ListenPort}, 199 }, 200 }}, 201 {"black listed tcp", "tcp", base.ListenPort, true, map[string]config.PortWhiteList{ 202 "tls": { 203 Ports: []int{base.ListenPort}, 204 }, 205 }}, 206 {"blacklisted tls", "tls", base.ListenPort, true, map[string]config.PortWhiteList{ 207 "tcp": { 208 Ports: []int{base.ListenPort}, 209 }, 210 }}, 211 {"whitelisted tls range", "tls", base.ListenPort, false, map[string]config.PortWhiteList{ 212 "tls": { 213 Ranges: []config.PortRange{ 214 { 215 From: base.ListenPort - 1, 216 To: base.ListenPort + 1, 217 }, 218 }, 219 }, 220 }}, 221 {"whitelisted tcp range", "tcp", base.ListenPort, false, map[string]config.PortWhiteList{ 222 "tcp": { 223 Ranges: []config.PortRange{ 224 { 225 From: base.ListenPort - 1, 226 To: base.ListenPort + 1, 227 }, 228 }, 229 }, 230 }}, 231 {"whitelisted http range", "http", 8090, false, map[string]config.PortWhiteList{ 232 "http": { 233 Ranges: []config.PortRange{ 234 { 235 From: 8000, 236 To: 9000, 237 }, 238 }, 239 }, 240 }}, 241 } 242 for i, tt := range cases { 243 t.Run(tt.name, func(ts *testing.T) { 244 err := CheckPortWhiteList(tt.wls, tt.port, tt.protocol) 245 if tt.fail { 246 if err == nil { 247 ts.Error("expected an error got nil") 248 } 249 } else { 250 if err != nil { 251 ts.Errorf("%d: expected an nil got %v", i, err) 252 } 253 } 254 }) 255 } 256 } 257 258 func TestHTTP_custom_ports(t *testing.T) { 259 ts := StartTest() 260 defer ts.Close() 261 echo := "Hello, world" 262 us := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 263 w.Write([]byte(echo)) 264 })) 265 defer us.Close() 266 port, err := getUnusedPort() 267 if err != nil { 268 t.Fatal(err) 269 } 270 EnablePort(port, "http") 271 BuildAndLoadAPI(func(spec *APISpec) { 272 spec.Proxy.ListenPath = "/" 273 spec.Protocol = "http" 274 spec.ListenPort = port 275 spec.Proxy.TargetURL = us.URL 276 }) 277 s := fmt.Sprintf("http://localhost:%d", port) 278 w, err := http.Get(s) 279 if err != nil { 280 t.Fatal(err) 281 } 282 defer w.Body.Close() 283 b, err := ioutil.ReadAll(w.Body) 284 if err != nil { 285 t.Fatal(err) 286 } 287 bs := string(b) 288 if bs != echo { 289 t.Errorf("expected %s to %s", echo, bs) 290 } 291 }