github.com/avenga/couper@v1.12.2/handler/endpoint_test.go (about) 1 package handler_test 2 3 import ( 4 "bytes" 5 "context" 6 "io" 7 "net/http" 8 "net/http/httptest" 9 "os" 10 "strings" 11 "testing" 12 "time" 13 14 "github.com/hashicorp/hcl/v2" 15 "github.com/hashicorp/hcl/v2/hclsimple" 16 "github.com/hashicorp/hcl/v2/hclsyntax" 17 logrustest "github.com/sirupsen/logrus/hooks/test" 18 19 hclbody "github.com/avenga/couper/config/body" 20 "github.com/avenga/couper/config/request" 21 "github.com/avenga/couper/config/sequence" 22 "github.com/avenga/couper/errors" 23 "github.com/avenga/couper/eval" 24 "github.com/avenga/couper/eval/buffer" 25 "github.com/avenga/couper/handler" 26 "github.com/avenga/couper/handler/producer" 27 "github.com/avenga/couper/handler/transport" 28 "github.com/avenga/couper/internal/test" 29 "github.com/avenga/couper/logging" 30 "github.com/avenga/couper/server/writer" 31 "github.com/sirupsen/logrus" 32 ) 33 34 func TestEndpoint_RoundTrip_Eval(t *testing.T) { 35 type header map[string]string 36 37 type testCase struct { 38 name string 39 hcl string 40 method string 41 body io.Reader 42 wantHeader header 43 } 44 45 type hclBody struct { 46 Inline hcl.Body `hcl:",remain"` 47 } 48 49 origin := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 50 if r.Method == http.MethodPost { 51 if err := r.ParseForm(); err != nil { 52 t.Fatal(err) 53 } 54 } 55 56 rw.WriteHeader(http.StatusNoContent) 57 })) 58 defer origin.Close() 59 60 log, hook := logrustest.NewNullLogger() 61 logger := log.WithContext(context.Background()) 62 63 tests := []testCase{ 64 {"GET use request.Header", ` 65 set_response_headers = { 66 X-Method = request.method 67 }`, http.MethodGet, nil, header{"X-Method": http.MethodGet}}, 68 {"POST use request.form_body", ` 69 set_response_headers = { 70 X-Method = request.method 71 X-Form_Body = request.form_body.foo 72 }`, http.MethodPost, strings.NewReader(`foo=bar`), header{ 73 "X-Method": http.MethodPost, 74 "X-Form_Body": "bar", 75 }}, 76 } 77 78 evalCtx := eval.NewDefaultContext() 79 80 for _, tt := range tests { 81 t.Run(tt.name, func(subT *testing.T) { 82 helper := test.New(subT) 83 hook.Reset() 84 85 var remain hclBody 86 err := hclsimple.Decode("test.hcl", []byte(tt.hcl), evalCtx.HCLContext(), &remain) 87 helper.Must(err) 88 89 backend := transport.NewBackend( 90 hclbody.NewHCLSyntaxBodyWithStringAttr("origin", "http://"+origin.Listener.Addr().String()), 91 &transport.Config{NoProxyFromEnv: true}, nil, logger) 92 93 ep := handler.NewEndpoint(&handler.EndpointOptions{ 94 ErrorTemplate: errors.DefaultJSON, 95 Context: remain.Inline.(*hclsyntax.Body), 96 ReqBodyLimit: 1024, 97 Items: sequence.List{&sequence.Item{Name: "default"}}, 98 Producers: map[string]producer.Roundtrip{"default": &producer.Proxy{Name: "default", RoundTrip: backend}}, 99 }, logger, nil) 100 101 req := httptest.NewRequest(tt.method, "http://couper.io", tt.body) 102 if tt.body != nil { 103 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 104 } 105 106 helper.Must(eval.SetGetBody(req, buffer.Request, 1024)) 107 *req = *req.WithContext(evalCtx.WithClientRequest(req)) 108 109 rec := httptest.NewRecorder() 110 rw := writer.NewResponseWriter(rec, "") // crucial for working ep due to res.Write() 111 ep.ServeHTTP(rw, req) 112 rec.Flush() 113 res := rec.Result() 114 115 if res == nil { 116 subT.Log(hook.LastEntry().String()) 117 subT.Errorf("Expected a response") 118 return 119 } 120 121 if res.StatusCode != http.StatusNoContent { 122 subT.Errorf("Expected StatusNoContent 204, got: %q %d", res.Status, res.StatusCode) 123 subT.Log(hook.LastEntry().String()) 124 } 125 126 for k, v := range tt.wantHeader { 127 if got := res.Header.Get(k); got != v { 128 subT.Errorf("Expected value for header %q: %q, got: %q", k, v, got) 129 subT.Log(hook.LastEntry().String()) 130 } 131 } 132 133 }) 134 } 135 } 136 137 func TestEndpoint_RoundTripContext_Variables_json_body(t *testing.T) { 138 type want struct { 139 req test.Header 140 } 141 142 defaultMethods := []string{ 143 http.MethodGet, 144 http.MethodPost, 145 http.MethodPut, 146 http.MethodPatch, 147 http.MethodDelete, 148 http.MethodConnect, 149 http.MethodOptions, 150 } 151 152 origin := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 153 // reflect req headers 154 for k, v := range r.Header { 155 if !strings.HasPrefix(strings.ToLower(k), "x-") { 156 continue 157 } 158 rw.Header()[k] = v 159 } 160 rw.WriteHeader(http.StatusNoContent) 161 })) 162 defer origin.Close() 163 164 tests := []struct { 165 name string 166 inlineCtx string 167 methods []string 168 header test.Header 169 body string 170 want want 171 }{ 172 {"method /w body", ` 173 origin = "` + origin.URL + `" 174 set_request_headers = { 175 x-test = request.json_body.foo 176 }`, defaultMethods, test.Header{"Content-Type": "application/json"}, `{"foo": "bar"}`, want{req: test.Header{"x-test": "bar"}}, 177 }, 178 {"method /w body +json content-type", ` 179 origin = "` + origin.URL + `" 180 set_request_headers = { 181 x-test = request.json_body.foo 182 }`, defaultMethods, test.Header{"Content-Type": "applicAtion/foo+jsOn"}, `{"foo": "bar"}`, want{req: test.Header{"x-test": "bar"}}, 183 }, 184 {"method /w body wrong content-type", ` 185 origin = "` + origin.URL + `" 186 set_request_headers = { 187 x-test = request.json_body.foo 188 }`, defaultMethods, test.Header{"Content-Type": "application/fooson"}, `{"foo": "bar"}`, want{req: test.Header{"x-test": ""}}, 189 }, 190 {"method /w body", ` 191 origin = "` + origin.URL + `" 192 set_request_headers = { 193 x-test = request.json_body.foo 194 }`, []string{http.MethodTrace}, test.Header{"Content-Type": "application/json"}, `{"foo": "bar"}`, want{req: test.Header{"x-test": ""}}}, 195 {"method /wo body", ` 196 origin = "` + origin.URL + `" 197 set_request_headers = { 198 x-test = request.json_body.foo 199 }`, append(defaultMethods, http.MethodTrace), 200 test.Header{"Content-Type": "application/json"}, "", want{req: test.Header{"x-test": ""}}}, 201 } 202 203 log, _ := logrustest.NewNullLogger() 204 logger := log.WithContext(context.Background()) 205 206 for _, tt := range tests { 207 for _, method := range tt.methods { 208 t.Run(method+" "+tt.name, func(subT *testing.T) { 209 helper := test.New(subT) 210 211 backend := transport.NewBackend( 212 helper.NewInlineContext(tt.inlineCtx), 213 &transport.Config{NoProxyFromEnv: true}, nil, logger) 214 215 ep := handler.NewEndpoint(&handler.EndpointOptions{ 216 ErrorTemplate: errors.DefaultJSON, 217 Context: &hclsyntax.Body{}, 218 ReqBodyLimit: 1024, 219 Items: sequence.List{&sequence.Item{Name: "default"}}, 220 Producers: map[string]producer.Roundtrip{"default": &producer.Proxy{Name: "default", RoundTrip: backend}}, 221 }, logger, nil) 222 223 var body io.Reader 224 if tt.body != "" { 225 body = bytes.NewBufferString(tt.body) 226 } 227 req := httptest.NewRequest(method, "/", body) 228 tt.header.Set(req) 229 230 // normally injected by server/http 231 helper.Must(eval.SetGetBody(req, buffer.Request, 1024)) 232 *req = *req.WithContext(eval.NewDefaultContext().WithClientRequest(req)) 233 234 rec := httptest.NewRecorder() 235 rw := writer.NewResponseWriter(rec, "") // crucial for working ep due to res.Write() 236 ep.ServeHTTP(rw, req) 237 rec.Flush() 238 res := rec.Result() 239 240 for k, v := range tt.want.req { 241 if res.Header.Get(k) != v { 242 subT.Errorf("want: %q for key %q, got: %q", v, k, res.Header.Get(k)) 243 } 244 } 245 }) 246 } 247 } 248 } 249 250 // TestProxy_SetRoundtripContext_Null_Eval tests the handling with non-existing references or cty.Null evaluations. 251 func TestEndpoint_RoundTripContext_Null_Eval(t *testing.T) { 252 helper := test.New(t) 253 254 type testCase struct { 255 name string 256 remain string 257 ct string 258 expHeaders test.Header 259 } 260 261 clientPayload := []byte(`{ "client": true, "origin": false, "nil": null }`) 262 originPayload := []byte(`{ "client": false, "origin": true, "nil": null }`) 263 264 origin := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 265 clientData, err := io.ReadAll(r.Body) 266 helper.Must(err) 267 if !bytes.Equal(clientData, clientPayload) { 268 t.Errorf("Expected a request with client payload, got %q", string(clientData)) 269 rw.WriteHeader(http.StatusInternalServerError) 270 return 271 } 272 273 if ct := r.Header.Get("Content-Type"); ct != "" { 274 rw.Header().Set("Content-Type", ct) 275 } else { 276 rw.Header().Set("Content-Type", "application/json") 277 } 278 _, err = rw.Write(originPayload) 279 helper.Must(err) 280 })) 281 282 log, _ := logrustest.NewNullLogger() 283 logger := log.WithContext(context.Background()) 284 285 for _, tc := range []testCase{ 286 {"no eval", `path = "/"`, "", test.Header{}}, 287 {"json_body client field", `set_response_headers = { "x-client" = "my-val-x-${request.json_body.client}" }`, "", 288 test.Header{ 289 "x-client": "my-val-x-true", 290 }}, 291 {"json_body request/response", `set_response_headers = { 292 x-client = "my-val-x-${request.json_body.client}" 293 x-client2 = request.body 294 x-origin = "my-val-y-${backend_responses.default.json_body.origin}" 295 x-origin2 = backend_responses.default.body 296 }`, "", 297 test.Header{ 298 "x-client": "my-val-x-true", 299 "x-client2": `{ "client": true, "origin": false, "nil": null }`, 300 "x-origin": "my-val-y-true", 301 "x-origin2": `{ "client": false, "origin": true, "nil": null }`, 302 }}, 303 {"json_body request/response json variant", `set_response_headers = { 304 x-client = "my-val-x-${request.json_body.client}" 305 x-origin = "my-val-y-${backend_responses.default.json_body.origin}" 306 }`, "application/foo+json", 307 test.Header{ 308 "x-client": "my-val-x-true", 309 "x-origin": "my-val-y-true", 310 }}, 311 {"json_body non existing shared parent", `set_response_headers = { 312 x-client = request.json_body.not-there 313 x-client-nested = request.json_body.not-there.nested 314 }`, "application/foo+json", 315 test.Header{ 316 "x-client": "", 317 "x-client-nested": "", 318 }}, 319 {"json_body non existing field", `set_response_headers = { 320 "${backend_responses.default.json_body.not-there}" = "my-val-0-${backend_responses.default.json_body.origin}" 321 "${request.json_body.client}-my-val-a" = "my-val-b-${backend_responses.default.json_body.client}" 322 }`, "", 323 test.Header{"true-my-val-a": "my-val-b-false"}}, 324 {"json_body null value", `set_response_headers = { "x-null" = "${backend_responses.default.json_body.nil}" }`, "", test.Header{"x-null": ""}}, 325 } { 326 t.Run(tc.name, func(subT *testing.T) { 327 h := test.New(subT) 328 329 backend := transport.NewBackend( 330 hclbody.NewHCLSyntaxBodyWithStringAttr("origin", "http://"+origin.Listener.Addr().String()), 331 &transport.Config{NoProxyFromEnv: true}, nil, logger) 332 333 bufOpts := buffer.Must(helper.NewInlineContext(tc.remain)) 334 335 ep := handler.NewEndpoint(&handler.EndpointOptions{ 336 BufferOpts: bufOpts, 337 Context: helper.NewInlineContext(tc.remain), 338 ErrorTemplate: errors.DefaultJSON, 339 ReqBodyLimit: 1024, 340 Items: sequence.List{&sequence.Item{Name: "default"}}, 341 Producers: map[string]producer.Roundtrip{"default": &producer.Proxy{Name: "default", RoundTrip: backend}}, 342 }, logger, nil) 343 344 req := httptest.NewRequest(http.MethodPost, "http://localhost/", bytes.NewReader(clientPayload)) 345 helper.Must(eval.SetGetBody(req, bufOpts, 1024)) 346 if tc.ct != "" { 347 req.Header.Set("Content-Type", tc.ct) 348 } else { 349 req.Header.Set("Content-Type", "application/json") 350 } 351 req = req.WithContext(eval.NewDefaultContext().WithClientRequest(req)) 352 353 rec := httptest.NewRecorder() 354 rw := writer.NewResponseWriter(rec, "") // crucial for working ep due to res.Write() 355 ep.ServeHTTP(rw, req) 356 rec.Flush() 357 res := rec.Result() 358 359 if res.StatusCode != http.StatusOK { 360 subT.Errorf("Expected StatusOK, got: %d", res.StatusCode) 361 } 362 363 originData, err := io.ReadAll(res.Body) 364 h.Must(err) 365 366 if !bytes.Equal(originPayload, originData) { 367 subT.Errorf("Expected same origin payload, got:\n%s\nlog message:\n", string(originData)) 368 } 369 370 for k, v := range tc.expHeaders { 371 if res.Header.Get(k) != v { 372 subT.Errorf("%q: Expected header %q value: %q, got: %q", tc.name, k, v, res.Header.Get(k)) 373 } 374 } 375 }) 376 377 } 378 379 origin.Close() 380 } 381 382 var _ producer.Roundtrip = &mockProducerResult{} 383 384 type mockProducerResult struct { 385 rt http.RoundTripper 386 } 387 388 func (m *mockProducerResult) Produce(r *http.Request) *producer.Result { 389 if m == nil || m.rt == nil { 390 return nil 391 } 392 393 res, err := m.rt.RoundTrip(r) 394 return &producer.Result{ 395 RoundTripName: "default", 396 Beresp: res, 397 Err: err, 398 } 399 } 400 401 func (m *mockProducerResult) SetDependsOn(ps string) { 402 } 403 404 func TestEndpoint_ServeHTTP_FaultyDefaultResponse(t *testing.T) { 405 log, hook := test.NewLogger() 406 407 origin := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { 408 ico, _ := os.ReadFile("testdata/file/favicon.ico") 409 410 rw.Header().Set("Content-Encoding", "gzip") // wrong 411 rw.Header().Set("Content-Type", "text/html") // wrong 412 rw.Header().Set("Cache-Control", "no-cache, no-store, max-age=0") 413 414 _, err := rw.Write(ico) 415 if err != nil { 416 t.Error(err) 417 } 418 })) 419 defer origin.Close() 420 421 rt := transport.NewBackend( 422 hclbody.NewHCLSyntaxBodyWithStringAttr("origin", origin.URL), &transport.Config{}, 423 &transport.BackendOptions{}, log.WithContext(context.Background())) 424 425 mockProducer := &mockProducerResult{rt} 426 427 ep := handler.NewEndpoint(&handler.EndpointOptions{ 428 Context: &hclsyntax.Body{}, 429 ErrorTemplate: errors.DefaultJSON, 430 Items: sequence.List{&sequence.Item{Name: "default"}}, 431 Producers: map[string]producer.Roundtrip{"default": mockProducer}, 432 }, log.WithContext(context.Background()), nil) 433 434 ctx := context.Background() 435 req := httptest.NewRequest(http.MethodGet, "http://", nil).WithContext(ctx) 436 ctx = eval.NewDefaultContext().WithClientRequest(req) 437 ctx = context.WithValue(ctx, request.UID, "test123") 438 439 rec := httptest.NewRecorder() 440 rw := writer.NewResponseWriter(rec, "") 441 ep.ServeHTTP(rw, req.Clone(ctx)) 442 res := rec.Result() 443 444 if res.StatusCode == 0 { 445 t.Errorf("Fatal error: response status is zero") 446 if res.Header.Get("Couper-Error") != "internal server error" { 447 t.Errorf("Expected internal server error, got: %s", res.Header.Get("Couper-Error")) 448 } 449 } else if res.StatusCode != http.StatusOK { 450 t.Errorf("Expected status ok, got: %v", res.StatusCode) 451 } 452 453 for _, e := range hook.AllEntries() { 454 if e.Level != logrus.ErrorLevel { 455 continue 456 } 457 if e.Message != "backend error: body reset: gzip: invalid header" { 458 t.Errorf("Unexpected error message: %s", e.Message) 459 } 460 } 461 } 462 463 func TestEndpoint_ServeHTTP_Cancel(t *testing.T) { 464 log, hook := test.NewLogger() 465 slowOrigin := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { 466 time.Sleep(time.Second * 5) 467 rw.WriteHeader(http.StatusNoContent) 468 })) 469 defer slowOrigin.Close() 470 471 ctx, cancelFn := context.WithCancel(context.WithValue(context.Background(), request.UID, "test123")) 472 ctx = context.WithValue(ctx, request.StartTime, time.Now()) 473 474 rt := transport.NewBackend( 475 hclbody.NewHCLSyntaxBodyWithStringAttr("origin", slowOrigin.URL), &transport.Config{}, 476 &transport.BackendOptions{}, log.WithContext(context.Background())) 477 478 mockProducer := &mockProducerResult{rt} 479 480 ep := handler.NewEndpoint(&handler.EndpointOptions{ 481 Context: &hclsyntax.Body{}, 482 ErrorTemplate: errors.DefaultJSON, 483 Items: sequence.List{&sequence.Item{Name: "default"}}, 484 Producers: map[string]producer.Roundtrip{"default": mockProducer}, 485 }, log.WithContext(ctx), nil) 486 487 req := httptest.NewRequest(http.MethodGet, "https://couper.io/", nil) 488 ctx = eval.NewDefaultContext().WithClientRequest(req.WithContext(ctx)) 489 490 start := time.Now() 491 go func() { 492 time.Sleep(time.Second) 493 cancelFn() 494 }() 495 496 rec := httptest.NewRecorder() 497 access := logging.NewAccessLog(&logging.Config{}, log) 498 499 outreq := req.WithContext(ctx) 500 ep.ServeHTTP(rec, outreq) 501 access.Do(rec, outreq) 502 rec.Flush() 503 504 elapsed := time.Since(start) 505 if elapsed > time.Second+(time.Millisecond*50) { 506 t.Error("Expected canceled request") 507 } 508 509 for _, e := range hook.AllEntries() { 510 if e.Message == "client request error: context canceled" { 511 return 512 } 513 } 514 515 t.Error("Expected context canceled access log, got:\n") 516 for _, e := range hook.AllEntries() { 517 println(e.String()) 518 } 519 }