github.com/avenga/couper@v1.12.2/handler/endpoint_test.go (about)

     1  package handler_test
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"io"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"os"
    10  	"strings"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/hashicorp/hcl/v2"
    15  	"github.com/hashicorp/hcl/v2/hclsimple"
    16  	"github.com/hashicorp/hcl/v2/hclsyntax"
    17  	logrustest "github.com/sirupsen/logrus/hooks/test"
    18  
    19  	hclbody "github.com/avenga/couper/config/body"
    20  	"github.com/avenga/couper/config/request"
    21  	"github.com/avenga/couper/config/sequence"
    22  	"github.com/avenga/couper/errors"
    23  	"github.com/avenga/couper/eval"
    24  	"github.com/avenga/couper/eval/buffer"
    25  	"github.com/avenga/couper/handler"
    26  	"github.com/avenga/couper/handler/producer"
    27  	"github.com/avenga/couper/handler/transport"
    28  	"github.com/avenga/couper/internal/test"
    29  	"github.com/avenga/couper/logging"
    30  	"github.com/avenga/couper/server/writer"
    31  	"github.com/sirupsen/logrus"
    32  )
    33  
    34  func TestEndpoint_RoundTrip_Eval(t *testing.T) {
    35  	type header map[string]string
    36  
    37  	type testCase struct {
    38  		name       string
    39  		hcl        string
    40  		method     string
    41  		body       io.Reader
    42  		wantHeader header
    43  	}
    44  
    45  	type hclBody struct {
    46  		Inline hcl.Body `hcl:",remain"`
    47  	}
    48  
    49  	origin := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
    50  		if r.Method == http.MethodPost {
    51  			if err := r.ParseForm(); err != nil {
    52  				t.Fatal(err)
    53  			}
    54  		}
    55  
    56  		rw.WriteHeader(http.StatusNoContent)
    57  	}))
    58  	defer origin.Close()
    59  
    60  	log, hook := logrustest.NewNullLogger()
    61  	logger := log.WithContext(context.Background())
    62  
    63  	tests := []testCase{
    64  		{"GET use request.Header", `
    65  		set_response_headers = {
    66  			X-Method = request.method
    67  		}`, http.MethodGet, nil, header{"X-Method": http.MethodGet}},
    68  		{"POST use request.form_body", `
    69  		set_response_headers = {
    70  			X-Method = request.method
    71  			X-Form_Body = request.form_body.foo
    72  		}`, http.MethodPost, strings.NewReader(`foo=bar`), header{
    73  			"X-Method":    http.MethodPost,
    74  			"X-Form_Body": "bar",
    75  		}},
    76  	}
    77  
    78  	evalCtx := eval.NewDefaultContext()
    79  
    80  	for _, tt := range tests {
    81  		t.Run(tt.name, func(subT *testing.T) {
    82  			helper := test.New(subT)
    83  			hook.Reset()
    84  
    85  			var remain hclBody
    86  			err := hclsimple.Decode("test.hcl", []byte(tt.hcl), evalCtx.HCLContext(), &remain)
    87  			helper.Must(err)
    88  
    89  			backend := transport.NewBackend(
    90  				hclbody.NewHCLSyntaxBodyWithStringAttr("origin", "http://"+origin.Listener.Addr().String()),
    91  				&transport.Config{NoProxyFromEnv: true}, nil, logger)
    92  
    93  			ep := handler.NewEndpoint(&handler.EndpointOptions{
    94  				ErrorTemplate: errors.DefaultJSON,
    95  				Context:       remain.Inline.(*hclsyntax.Body),
    96  				ReqBodyLimit:  1024,
    97  				Items:         sequence.List{&sequence.Item{Name: "default"}},
    98  				Producers:     map[string]producer.Roundtrip{"default": &producer.Proxy{Name: "default", RoundTrip: backend}},
    99  			}, logger, nil)
   100  
   101  			req := httptest.NewRequest(tt.method, "http://couper.io", tt.body)
   102  			if tt.body != nil {
   103  				req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
   104  			}
   105  
   106  			helper.Must(eval.SetGetBody(req, buffer.Request, 1024))
   107  			*req = *req.WithContext(evalCtx.WithClientRequest(req))
   108  
   109  			rec := httptest.NewRecorder()
   110  			rw := writer.NewResponseWriter(rec, "") // crucial for working ep due to res.Write()
   111  			ep.ServeHTTP(rw, req)
   112  			rec.Flush()
   113  			res := rec.Result()
   114  
   115  			if res == nil {
   116  				subT.Log(hook.LastEntry().String())
   117  				subT.Errorf("Expected a response")
   118  				return
   119  			}
   120  
   121  			if res.StatusCode != http.StatusNoContent {
   122  				subT.Errorf("Expected StatusNoContent 204, got: %q %d", res.Status, res.StatusCode)
   123  				subT.Log(hook.LastEntry().String())
   124  			}
   125  
   126  			for k, v := range tt.wantHeader {
   127  				if got := res.Header.Get(k); got != v {
   128  					subT.Errorf("Expected value for header %q: %q, got: %q", k, v, got)
   129  					subT.Log(hook.LastEntry().String())
   130  				}
   131  			}
   132  
   133  		})
   134  	}
   135  }
   136  
   137  func TestEndpoint_RoundTripContext_Variables_json_body(t *testing.T) {
   138  	type want struct {
   139  		req test.Header
   140  	}
   141  
   142  	defaultMethods := []string{
   143  		http.MethodGet,
   144  		http.MethodPost,
   145  		http.MethodPut,
   146  		http.MethodPatch,
   147  		http.MethodDelete,
   148  		http.MethodConnect,
   149  		http.MethodOptions,
   150  	}
   151  
   152  	origin := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   153  		// reflect req headers
   154  		for k, v := range r.Header {
   155  			if !strings.HasPrefix(strings.ToLower(k), "x-") {
   156  				continue
   157  			}
   158  			rw.Header()[k] = v
   159  		}
   160  		rw.WriteHeader(http.StatusNoContent)
   161  	}))
   162  	defer origin.Close()
   163  
   164  	tests := []struct {
   165  		name      string
   166  		inlineCtx string
   167  		methods   []string
   168  		header    test.Header
   169  		body      string
   170  		want      want
   171  	}{
   172  		{"method /w body", `
   173  		origin = "` + origin.URL + `"
   174  		set_request_headers = {
   175  			x-test = request.json_body.foo
   176  		}`, defaultMethods, test.Header{"Content-Type": "application/json"}, `{"foo": "bar"}`, want{req: test.Header{"x-test": "bar"}},
   177  		},
   178  		{"method /w body +json content-type", `
   179  		origin = "` + origin.URL + `"
   180  		set_request_headers = {
   181  			x-test = request.json_body.foo
   182  		}`, defaultMethods, test.Header{"Content-Type": "applicAtion/foo+jsOn"}, `{"foo": "bar"}`, want{req: test.Header{"x-test": "bar"}},
   183  		},
   184  		{"method /w body wrong content-type", `
   185  		origin = "` + origin.URL + `"
   186  		set_request_headers = {
   187  			x-test = request.json_body.foo
   188  		}`, defaultMethods, test.Header{"Content-Type": "application/fooson"}, `{"foo": "bar"}`, want{req: test.Header{"x-test": ""}},
   189  		},
   190  		{"method /w body", `
   191  		origin = "` + origin.URL + `"
   192  		set_request_headers = {
   193  			x-test = request.json_body.foo
   194  		}`, []string{http.MethodTrace}, test.Header{"Content-Type": "application/json"}, `{"foo": "bar"}`, want{req: test.Header{"x-test": ""}}},
   195  		{"method /wo body", `
   196  		origin = "` + origin.URL + `"
   197  		set_request_headers = {
   198  			x-test = request.json_body.foo
   199  		}`, append(defaultMethods, http.MethodTrace),
   200  			test.Header{"Content-Type": "application/json"}, "", want{req: test.Header{"x-test": ""}}},
   201  	}
   202  
   203  	log, _ := logrustest.NewNullLogger()
   204  	logger := log.WithContext(context.Background())
   205  
   206  	for _, tt := range tests {
   207  		for _, method := range tt.methods {
   208  			t.Run(method+" "+tt.name, func(subT *testing.T) {
   209  				helper := test.New(subT)
   210  
   211  				backend := transport.NewBackend(
   212  					helper.NewInlineContext(tt.inlineCtx),
   213  					&transport.Config{NoProxyFromEnv: true}, nil, logger)
   214  
   215  				ep := handler.NewEndpoint(&handler.EndpointOptions{
   216  					ErrorTemplate: errors.DefaultJSON,
   217  					Context:       &hclsyntax.Body{},
   218  					ReqBodyLimit:  1024,
   219  					Items:         sequence.List{&sequence.Item{Name: "default"}},
   220  					Producers:     map[string]producer.Roundtrip{"default": &producer.Proxy{Name: "default", RoundTrip: backend}},
   221  				}, logger, nil)
   222  
   223  				var body io.Reader
   224  				if tt.body != "" {
   225  					body = bytes.NewBufferString(tt.body)
   226  				}
   227  				req := httptest.NewRequest(method, "/", body)
   228  				tt.header.Set(req)
   229  
   230  				// normally injected by server/http
   231  				helper.Must(eval.SetGetBody(req, buffer.Request, 1024))
   232  				*req = *req.WithContext(eval.NewDefaultContext().WithClientRequest(req))
   233  
   234  				rec := httptest.NewRecorder()
   235  				rw := writer.NewResponseWriter(rec, "") // crucial for working ep due to res.Write()
   236  				ep.ServeHTTP(rw, req)
   237  				rec.Flush()
   238  				res := rec.Result()
   239  
   240  				for k, v := range tt.want.req {
   241  					if res.Header.Get(k) != v {
   242  						subT.Errorf("want: %q for key %q, got: %q", v, k, res.Header.Get(k))
   243  					}
   244  				}
   245  			})
   246  		}
   247  	}
   248  }
   249  
   250  // TestProxy_SetRoundtripContext_Null_Eval tests the handling with non-existing references or cty.Null evaluations.
   251  func TestEndpoint_RoundTripContext_Null_Eval(t *testing.T) {
   252  	helper := test.New(t)
   253  
   254  	type testCase struct {
   255  		name       string
   256  		remain     string
   257  		ct         string
   258  		expHeaders test.Header
   259  	}
   260  
   261  	clientPayload := []byte(`{ "client": true, "origin": false, "nil": null }`)
   262  	originPayload := []byte(`{ "client": false, "origin": true, "nil": null }`)
   263  
   264  	origin := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   265  		clientData, err := io.ReadAll(r.Body)
   266  		helper.Must(err)
   267  		if !bytes.Equal(clientData, clientPayload) {
   268  			t.Errorf("Expected a request with client payload, got %q", string(clientData))
   269  			rw.WriteHeader(http.StatusInternalServerError)
   270  			return
   271  		}
   272  
   273  		if ct := r.Header.Get("Content-Type"); ct != "" {
   274  			rw.Header().Set("Content-Type", ct)
   275  		} else {
   276  			rw.Header().Set("Content-Type", "application/json")
   277  		}
   278  		_, err = rw.Write(originPayload)
   279  		helper.Must(err)
   280  	}))
   281  
   282  	log, _ := logrustest.NewNullLogger()
   283  	logger := log.WithContext(context.Background())
   284  
   285  	for _, tc := range []testCase{
   286  		{"no eval", `path = "/"`, "", test.Header{}},
   287  		{"json_body client field", `set_response_headers = { "x-client" = "my-val-x-${request.json_body.client}" }`, "",
   288  			test.Header{
   289  				"x-client": "my-val-x-true",
   290  			}},
   291  		{"json_body request/response", `set_response_headers = {
   292  				x-client = "my-val-x-${request.json_body.client}"
   293  				x-client2 = request.body
   294  				x-origin = "my-val-y-${backend_responses.default.json_body.origin}"
   295  				x-origin2 = backend_responses.default.body
   296  			}`, "",
   297  			test.Header{
   298  				"x-client":  "my-val-x-true",
   299  				"x-client2": `{ "client": true, "origin": false, "nil": null }`,
   300  				"x-origin":  "my-val-y-true",
   301  				"x-origin2": `{ "client": false, "origin": true, "nil": null }`,
   302  			}},
   303  		{"json_body request/response json variant", `set_response_headers = {
   304  				x-client = "my-val-x-${request.json_body.client}"
   305  				x-origin = "my-val-y-${backend_responses.default.json_body.origin}"
   306  			}`, "application/foo+json",
   307  			test.Header{
   308  				"x-client": "my-val-x-true",
   309  				"x-origin": "my-val-y-true",
   310  			}},
   311  		{"json_body non existing shared parent", `set_response_headers = {
   312  				x-client = request.json_body.not-there
   313  				x-client-nested = request.json_body.not-there.nested
   314  			}`, "application/foo+json",
   315  			test.Header{
   316  				"x-client":        "",
   317  				"x-client-nested": "",
   318  			}},
   319  		{"json_body non existing field", `set_response_headers = {
   320  "${backend_responses.default.json_body.not-there}" = "my-val-0-${backend_responses.default.json_body.origin}"
   321  "${request.json_body.client}-my-val-a" = "my-val-b-${backend_responses.default.json_body.client}"
   322  }`, "",
   323  			test.Header{"true-my-val-a": "my-val-b-false"}},
   324  		{"json_body null value", `set_response_headers = { "x-null" = "${backend_responses.default.json_body.nil}" }`, "", test.Header{"x-null": ""}},
   325  	} {
   326  		t.Run(tc.name, func(subT *testing.T) {
   327  			h := test.New(subT)
   328  
   329  			backend := transport.NewBackend(
   330  				hclbody.NewHCLSyntaxBodyWithStringAttr("origin", "http://"+origin.Listener.Addr().String()),
   331  				&transport.Config{NoProxyFromEnv: true}, nil, logger)
   332  
   333  			bufOpts := buffer.Must(helper.NewInlineContext(tc.remain))
   334  
   335  			ep := handler.NewEndpoint(&handler.EndpointOptions{
   336  				BufferOpts:    bufOpts,
   337  				Context:       helper.NewInlineContext(tc.remain),
   338  				ErrorTemplate: errors.DefaultJSON,
   339  				ReqBodyLimit:  1024,
   340  				Items:         sequence.List{&sequence.Item{Name: "default"}},
   341  				Producers:     map[string]producer.Roundtrip{"default": &producer.Proxy{Name: "default", RoundTrip: backend}},
   342  			}, logger, nil)
   343  
   344  			req := httptest.NewRequest(http.MethodPost, "http://localhost/", bytes.NewReader(clientPayload))
   345  			helper.Must(eval.SetGetBody(req, bufOpts, 1024))
   346  			if tc.ct != "" {
   347  				req.Header.Set("Content-Type", tc.ct)
   348  			} else {
   349  				req.Header.Set("Content-Type", "application/json")
   350  			}
   351  			req = req.WithContext(eval.NewDefaultContext().WithClientRequest(req))
   352  
   353  			rec := httptest.NewRecorder()
   354  			rw := writer.NewResponseWriter(rec, "") // crucial for working ep due to res.Write()
   355  			ep.ServeHTTP(rw, req)
   356  			rec.Flush()
   357  			res := rec.Result()
   358  
   359  			if res.StatusCode != http.StatusOK {
   360  				subT.Errorf("Expected StatusOK, got: %d", res.StatusCode)
   361  			}
   362  
   363  			originData, err := io.ReadAll(res.Body)
   364  			h.Must(err)
   365  
   366  			if !bytes.Equal(originPayload, originData) {
   367  				subT.Errorf("Expected same origin payload, got:\n%s\nlog message:\n", string(originData))
   368  			}
   369  
   370  			for k, v := range tc.expHeaders {
   371  				if res.Header.Get(k) != v {
   372  					subT.Errorf("%q: Expected header %q value: %q, got: %q", tc.name, k, v, res.Header.Get(k))
   373  				}
   374  			}
   375  		})
   376  
   377  	}
   378  
   379  	origin.Close()
   380  }
   381  
   382  var _ producer.Roundtrip = &mockProducerResult{}
   383  
   384  type mockProducerResult struct {
   385  	rt http.RoundTripper
   386  }
   387  
   388  func (m *mockProducerResult) Produce(r *http.Request) *producer.Result {
   389  	if m == nil || m.rt == nil {
   390  		return nil
   391  	}
   392  
   393  	res, err := m.rt.RoundTrip(r)
   394  	return &producer.Result{
   395  		RoundTripName: "default",
   396  		Beresp:        res,
   397  		Err:           err,
   398  	}
   399  }
   400  
   401  func (m *mockProducerResult) SetDependsOn(ps string) {
   402  }
   403  
   404  func TestEndpoint_ServeHTTP_FaultyDefaultResponse(t *testing.T) {
   405  	log, hook := test.NewLogger()
   406  
   407  	origin := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
   408  		ico, _ := os.ReadFile("testdata/file/favicon.ico")
   409  
   410  		rw.Header().Set("Content-Encoding", "gzip")  // wrong
   411  		rw.Header().Set("Content-Type", "text/html") // wrong
   412  		rw.Header().Set("Cache-Control", "no-cache, no-store, max-age=0")
   413  
   414  		_, err := rw.Write(ico)
   415  		if err != nil {
   416  			t.Error(err)
   417  		}
   418  	}))
   419  	defer origin.Close()
   420  
   421  	rt := transport.NewBackend(
   422  		hclbody.NewHCLSyntaxBodyWithStringAttr("origin", origin.URL), &transport.Config{},
   423  		&transport.BackendOptions{}, log.WithContext(context.Background()))
   424  
   425  	mockProducer := &mockProducerResult{rt}
   426  
   427  	ep := handler.NewEndpoint(&handler.EndpointOptions{
   428  		Context:       &hclsyntax.Body{},
   429  		ErrorTemplate: errors.DefaultJSON,
   430  		Items:         sequence.List{&sequence.Item{Name: "default"}},
   431  		Producers:     map[string]producer.Roundtrip{"default": mockProducer},
   432  	}, log.WithContext(context.Background()), nil)
   433  
   434  	ctx := context.Background()
   435  	req := httptest.NewRequest(http.MethodGet, "http://", nil).WithContext(ctx)
   436  	ctx = eval.NewDefaultContext().WithClientRequest(req)
   437  	ctx = context.WithValue(ctx, request.UID, "test123")
   438  
   439  	rec := httptest.NewRecorder()
   440  	rw := writer.NewResponseWriter(rec, "")
   441  	ep.ServeHTTP(rw, req.Clone(ctx))
   442  	res := rec.Result()
   443  
   444  	if res.StatusCode == 0 {
   445  		t.Errorf("Fatal error: response status is zero")
   446  		if res.Header.Get("Couper-Error") != "internal server error" {
   447  			t.Errorf("Expected internal server error, got: %s", res.Header.Get("Couper-Error"))
   448  		}
   449  	} else if res.StatusCode != http.StatusOK {
   450  		t.Errorf("Expected status ok, got: %v", res.StatusCode)
   451  	}
   452  
   453  	for _, e := range hook.AllEntries() {
   454  		if e.Level != logrus.ErrorLevel {
   455  			continue
   456  		}
   457  		if e.Message != "backend error: body reset: gzip: invalid header" {
   458  			t.Errorf("Unexpected error message: %s", e.Message)
   459  		}
   460  	}
   461  }
   462  
   463  func TestEndpoint_ServeHTTP_Cancel(t *testing.T) {
   464  	log, hook := test.NewLogger()
   465  	slowOrigin := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
   466  		time.Sleep(time.Second * 5)
   467  		rw.WriteHeader(http.StatusNoContent)
   468  	}))
   469  	defer slowOrigin.Close()
   470  
   471  	ctx, cancelFn := context.WithCancel(context.WithValue(context.Background(), request.UID, "test123"))
   472  	ctx = context.WithValue(ctx, request.StartTime, time.Now())
   473  
   474  	rt := transport.NewBackend(
   475  		hclbody.NewHCLSyntaxBodyWithStringAttr("origin", slowOrigin.URL), &transport.Config{},
   476  		&transport.BackendOptions{}, log.WithContext(context.Background()))
   477  
   478  	mockProducer := &mockProducerResult{rt}
   479  
   480  	ep := handler.NewEndpoint(&handler.EndpointOptions{
   481  		Context:       &hclsyntax.Body{},
   482  		ErrorTemplate: errors.DefaultJSON,
   483  		Items:         sequence.List{&sequence.Item{Name: "default"}},
   484  		Producers:     map[string]producer.Roundtrip{"default": mockProducer},
   485  	}, log.WithContext(ctx), nil)
   486  
   487  	req := httptest.NewRequest(http.MethodGet, "https://couper.io/", nil)
   488  	ctx = eval.NewDefaultContext().WithClientRequest(req.WithContext(ctx))
   489  
   490  	start := time.Now()
   491  	go func() {
   492  		time.Sleep(time.Second)
   493  		cancelFn()
   494  	}()
   495  
   496  	rec := httptest.NewRecorder()
   497  	access := logging.NewAccessLog(&logging.Config{}, log)
   498  
   499  	outreq := req.WithContext(ctx)
   500  	ep.ServeHTTP(rec, outreq)
   501  	access.Do(rec, outreq)
   502  	rec.Flush()
   503  
   504  	elapsed := time.Since(start)
   505  	if elapsed > time.Second+(time.Millisecond*50) {
   506  		t.Error("Expected canceled request")
   507  	}
   508  
   509  	for _, e := range hook.AllEntries() {
   510  		if e.Message == "client request error: context canceled" {
   511  			return
   512  		}
   513  	}
   514  
   515  	t.Error("Expected context canceled access log, got:\n")
   516  	for _, e := range hook.AllEntries() {
   517  		println(e.String())
   518  	}
   519  }