cuelang.org/go@v0.10.1/internal/httplog/client_test.go (about)

     1  package httplog
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"io"
     8  	"log/slog"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"net/url"
    12  	"strings"
    13  	"testing"
    14  
    15  	"github.com/go-quicktest/qt"
    16  )
    17  
    18  func TestTransportWithSlog(t *testing.T) {
    19  	seq.Store(0)
    20  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    21  		w.Write([]byte("hello"))
    22  	}))
    23  	var buf strings.Builder
    24  
    25  	client := &http.Client{
    26  		Transport: Transport(&TransportConfig{
    27  			Logger: SlogLogger{
    28  				Logger: slog.New(slog.NewJSONHandler(&buf, nil)),
    29  			},
    30  		}),
    31  	}
    32  	resp, err := client.Get(srv.URL + "/foo/bar?foo=bar")
    33  	qt.Assert(t, qt.IsNil(err))
    34  	data, err := io.ReadAll(resp.Body)
    35  	qt.Assert(t, qt.IsNil(err))
    36  	resp.Body.Close()
    37  	qt.Assert(t, qt.Equals(string(data), "hello"))
    38  
    39  	qt.Assert(t, qt.Matches(buf.String(), `
    40  {"time":"\d\d\d\d-[^"]+","level":"INFO","msg":"http client->","info":{"id":1,"method":"GET","url":"http://[^/]+/foo/bar\?foo=REDACTED","contentLength":0,"header":{}}}
    41  {"time":"\d\d\d\d-[^"]+","level":"INFO","msg":"http client<-","info":{"id":1,"method":"GET","url":"http://[^/]+/foo/bar\?foo=REDACTED","statusCode":200,"header":{"Content-Length":\["5"\],"Content-Type":\["text/plain; charset=utf-8"\],"Date":\[.*\]},"body":"hello"}}
    42  `[1:]))
    43  }
    44  
    45  func TestQueryParamAllowList(t *testing.T) {
    46  	seq.Store(10)
    47  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {}))
    48  
    49  	var recorder logRecorder
    50  	client := &http.Client{
    51  		Transport: Transport(&TransportConfig{
    52  			Logger: &recorder,
    53  		}),
    54  	}
    55  	ctx := ContextWithAllowedURLQueryParams(context.Background(),
    56  		func(key string) bool {
    57  			return key == "x1"
    58  		},
    59  	)
    60  	req, err := http.NewRequestWithContext(ctx, "GET", srv.URL+"/foo/bar?x1=ok&x2=redact1&x2=redact2", nil)
    61  	qt.Assert(t, qt.IsNil(err))
    62  	resp, err := client.Do(req)
    63  	qt.Assert(t, qt.IsNil(err))
    64  	resp.Body.Close()
    65  	req, err = http.NewRequestWithContext(ctx, "GET", srv.URL+"/foo/bar?x1=ok1&x1=ok2", nil)
    66  	qt.Assert(t, qt.IsNil(err))
    67  	resp, err = client.Do(req)
    68  	qt.Assert(t, qt.IsNil(err))
    69  	resp.Body.Close()
    70  	qt.Assert(t, qt.DeepEquals(recorder.EventKinds, []EventKind{
    71  		KindClientSendRequest,
    72  		KindClientRecvResponse,
    73  		KindClientSendRequest,
    74  		KindClientRecvResponse,
    75  	}))
    76  	qt.Assert(t, qt.DeepEquals(recorder.Events, []RequestOrResponse{
    77  		&Request{
    78  			ID:     11,
    79  			Method: "GET",
    80  			URL:    "http://localhost/foo/bar?x1=ok&x2=REDACTED&x2=REDACTED",
    81  			Header: http.Header{},
    82  		},
    83  		&Response{
    84  			ID:         11,
    85  			Method:     "GET",
    86  			URL:        "http://localhost/foo/bar?x1=ok&x2=REDACTED&x2=REDACTED",
    87  			StatusCode: http.StatusOK,
    88  			Header: http.Header{
    89  				"Content-Length": {"0"},
    90  				"Date":           {"now"},
    91  			},
    92  		},
    93  		&Request{
    94  			ID:     12,
    95  			Method: "GET",
    96  			URL:    "http://localhost/foo/bar?x1=ok1&x1=ok2",
    97  			Header: http.Header{},
    98  		},
    99  		&Response{
   100  			ID:         12,
   101  			Method:     "GET",
   102  			URL:        "http://localhost/foo/bar?x1=ok1&x1=ok2",
   103  			StatusCode: http.StatusOK,
   104  			Header: http.Header{
   105  				"Content-Length": {"0"},
   106  				"Date":           {"now"},
   107  			},
   108  		},
   109  	}))
   110  }
   111  
   112  func TestAuthorizationHeaderRedacted(t *testing.T) {
   113  	seq.Store(10)
   114  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {}))
   115  
   116  	var recorder logRecorder
   117  	client := &http.Client{
   118  		Transport: Transport(&TransportConfig{
   119  			Logger: &recorder,
   120  		}),
   121  	}
   122  	ctx := context.Background()
   123  
   124  	req, err := http.NewRequestWithContext(ctx, "GET", srv.URL+"/foo", nil)
   125  	qt.Assert(t, qt.IsNil(err))
   126  	req.SetBasicAuth("someuser", "somepassword")
   127  	req.Header.Add("Authorization", "Bearer sensitive-info")
   128  	req.Header.Add("Authorization", "othertoken")
   129  
   130  	resp, err := client.Do(req)
   131  	qt.Assert(t, qt.IsNil(err))
   132  	resp.Body.Close()
   133  	qt.Assert(t, qt.DeepEquals(recorder.EventKinds, []EventKind{
   134  		KindClientSendRequest,
   135  		KindClientRecvResponse,
   136  	}))
   137  	qt.Assert(t, qt.DeepEquals(recorder.Events, []RequestOrResponse{
   138  		&Request{
   139  			ID:     11,
   140  			Method: "GET",
   141  			URL:    "http://localhost/foo",
   142  			Header: http.Header{
   143  				"Authorization": {
   144  					"Basic REDACTED",
   145  					"Bearer REDACTED",
   146  					"REDACTED",
   147  				},
   148  			},
   149  		},
   150  		&Response{
   151  			ID:         11,
   152  			Method:     "GET",
   153  			URL:        "http://localhost/foo",
   154  			StatusCode: http.StatusOK,
   155  			Header: http.Header{
   156  				"Content-Length": {"0"},
   157  				"Date":           {"now"},
   158  			},
   159  		},
   160  	}))
   161  }
   162  
   163  func TestIncludeAllQueryParams(t *testing.T) {
   164  	seq.Store(10)
   165  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {}))
   166  
   167  	var recorder logRecorder
   168  	client := &http.Client{
   169  		Transport: Transport(&TransportConfig{
   170  			Logger:                &recorder,
   171  			IncludeAllQueryParams: true,
   172  		}),
   173  	}
   174  	ctx := context.Background()
   175  	req, err := http.NewRequestWithContext(ctx, "GET", srv.URL+"/foo/bar?x1=ok&x2=redact1&x2=redact2", nil)
   176  	qt.Assert(t, qt.IsNil(err))
   177  	resp, err := client.Do(req)
   178  	qt.Assert(t, qt.IsNil(err))
   179  	resp.Body.Close()
   180  	qt.Assert(t, qt.DeepEquals(recorder.EventKinds, []EventKind{KindClientSendRequest, KindClientRecvResponse}))
   181  	qt.Assert(t, qt.DeepEquals(recorder.Events, []RequestOrResponse{
   182  		&Request{
   183  			ID:     11,
   184  			Method: "GET",
   185  			URL:    "http://localhost/foo/bar?x1=ok&x2=redact1&x2=redact2",
   186  			Header: http.Header{},
   187  		},
   188  		&Response{
   189  			ID:         11,
   190  			Method:     "GET",
   191  			URL:        "http://localhost/foo/bar?x1=ok&x2=redact1&x2=redact2",
   192  			StatusCode: http.StatusOK,
   193  			Header: http.Header{
   194  				"Content-Length": {"0"},
   195  				"Date":           {"now"},
   196  			},
   197  		},
   198  	}))
   199  }
   200  
   201  func TestOmitBody(t *testing.T) {
   202  	seq.Store(10)
   203  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   204  		w.Write([]byte("response body"))
   205  	}))
   206  
   207  	var recorder logRecorder
   208  	client := &http.Client{
   209  		Transport: Transport(&TransportConfig{
   210  			Logger: &recorder,
   211  		}),
   212  	}
   213  	ctx := context.Background()
   214  	ctx = RedactRequestBody(ctx, "not keen on request bodies")
   215  	ctx = RedactResponseBody(ctx, "response bodies are right out")
   216  	req, err := http.NewRequestWithContext(ctx, "GET", srv.URL+"/foo/bar", strings.NewReader("request body"))
   217  	qt.Assert(t, qt.IsNil(err))
   218  	resp, err := client.Do(req)
   219  	qt.Assert(t, qt.IsNil(err))
   220  	resp.Body.Close()
   221  	qt.Assert(t, qt.DeepEquals(recorder.EventKinds, []EventKind{KindClientSendRequest, KindClientRecvResponse}))
   222  	qt.Assert(t, qt.DeepEquals(recorder.Events, []RequestOrResponse{
   223  		&Request{
   224  			ID:            11,
   225  			Method:        "GET",
   226  			ContentLength: 12,
   227  			URL:           "http://localhost/foo/bar",
   228  			Header:        http.Header{},
   229  			BodyData: BodyData{
   230  				BodyRedactedBecause: "not keen on request bodies",
   231  			},
   232  		},
   233  		&Response{
   234  			ID:         11,
   235  			Method:     "GET",
   236  			URL:        "http://localhost/foo/bar",
   237  			StatusCode: http.StatusOK,
   238  			Header: http.Header{
   239  				"Content-Length": {"13"},
   240  				"Content-Type":   {"text/plain; charset=utf-8"},
   241  				"Date":           {"now"},
   242  			},
   243  			BodyData: BodyData{
   244  				BodyRedactedBecause: "response bodies are right out",
   245  			},
   246  		},
   247  	}))
   248  }
   249  
   250  func TestLongBody(t *testing.T) {
   251  	seq.Store(10)
   252  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   253  		data, err := io.ReadAll(req.Body)
   254  		qt.Check(t, qt.IsNil(err))
   255  		qt.Check(t, qt.Equals(string(data), strings.Repeat("a", 30)))
   256  		w.Write(bytes.Repeat([]byte("b"), 20))
   257  	}))
   258  
   259  	var recorder logRecorder
   260  	client := &http.Client{
   261  		Transport: Transport(&TransportConfig{
   262  			Logger:      &recorder,
   263  			MaxBodySize: 10,
   264  		}),
   265  	}
   266  	ctx := context.Background()
   267  	req, err := http.NewRequestWithContext(ctx, "GET", srv.URL+"/foo/bar", strings.NewReader(strings.Repeat("a", 30)))
   268  	qt.Assert(t, qt.IsNil(err))
   269  	resp, err := client.Do(req)
   270  	qt.Assert(t, qt.IsNil(err))
   271  	data, err := io.ReadAll(resp.Body)
   272  	qt.Assert(t, qt.IsNil(err))
   273  	qt.Assert(t, qt.Equals(string(data), strings.Repeat("b", 20)))
   274  	resp.Body.Close()
   275  	qt.Assert(t, qt.DeepEquals(recorder.EventKinds, []EventKind{KindClientSendRequest, KindClientRecvResponse}))
   276  	qt.Assert(t, qt.DeepEquals(recorder.Events, []RequestOrResponse{
   277  		&Request{
   278  			ID:            11,
   279  			Method:        "GET",
   280  			ContentLength: 30,
   281  			URL:           "http://localhost/foo/bar",
   282  			Header:        http.Header{},
   283  			BodyData: BodyData{
   284  				Body:          strings.Repeat("a", 10),
   285  				BodyTruncated: true,
   286  			},
   287  		},
   288  		&Response{
   289  			ID:         11,
   290  			Method:     "GET",
   291  			URL:        "http://localhost/foo/bar",
   292  			StatusCode: http.StatusOK,
   293  			Header: http.Header{
   294  				"Content-Length": {"20"},
   295  				"Content-Type":   {"text/plain; charset=utf-8"},
   296  				"Date":           {"now"},
   297  			},
   298  			BodyData: BodyData{
   299  				Body:          strings.Repeat("b", 10),
   300  				BodyTruncated: true,
   301  			},
   302  		},
   303  	}))
   304  }
   305  
   306  func TestRoundTripError(t *testing.T) {
   307  	seq.Store(10)
   308  
   309  	var recorder logRecorder
   310  	client := &http.Client{
   311  		Transport: Transport(&TransportConfig{
   312  			Transport: errorTransport{},
   313  			Logger:    &recorder,
   314  		}),
   315  	}
   316  	ctx := context.Background()
   317  	req, err := http.NewRequestWithContext(ctx, "GET", "http://localhost:1234/foo/bar", nil)
   318  	qt.Assert(t, qt.IsNil(err))
   319  	_, err = client.Do(req)
   320  	qt.Assert(t, qt.ErrorMatches(err, `Get "http://localhost:1234/foo/bar": error in RoundTrip`))
   321  	qt.Assert(t, qt.DeepEquals(recorder.EventKinds, []EventKind{KindClientSendRequest, KindClientRecvResponse}))
   322  	qt.Assert(t, qt.DeepEquals(recorder.Events, []RequestOrResponse{
   323  		&Request{
   324  			ID:     11,
   325  			Method: "GET",
   326  			URL:    "http://localhost/foo/bar",
   327  			Header: http.Header{},
   328  		},
   329  		&Response{
   330  			ID:     11,
   331  			Method: "GET",
   332  			URL:    "http://localhost/foo/bar",
   333  			Error:  "error in RoundTrip",
   334  		},
   335  	}))
   336  }
   337  
   338  func TestBodyBinaryData(t *testing.T) {
   339  	seq.Store(10)
   340  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   341  		data, err := io.ReadAll(req.Body)
   342  		qt.Check(t, qt.IsNil(err))
   343  		qt.Check(t, qt.Equals(string(data), "\xff"))
   344  		w.Write([]byte{0xff})
   345  	}))
   346  
   347  	var recorder logRecorder
   348  	client := &http.Client{
   349  		Transport: Transport(&TransportConfig{
   350  			Logger: &recorder,
   351  		}),
   352  	}
   353  	ctx := context.Background()
   354  	req, err := http.NewRequestWithContext(ctx, "GET", srv.URL+"/foo/bar", bytes.NewReader([]byte{0xff}))
   355  	qt.Assert(t, qt.IsNil(err))
   356  	resp, err := client.Do(req)
   357  	qt.Assert(t, qt.IsNil(err))
   358  	data, err := io.ReadAll(resp.Body)
   359  	qt.Assert(t, qt.IsNil(err))
   360  	qt.Assert(t, qt.Equals(string(data), "\xff"))
   361  	resp.Body.Close()
   362  	qt.Assert(t, qt.DeepEquals(recorder.EventKinds, []EventKind{KindClientSendRequest, KindClientRecvResponse}))
   363  	qt.Assert(t, qt.DeepEquals(recorder.Events, []RequestOrResponse{
   364  		&Request{
   365  			ID:            11,
   366  			Method:        "GET",
   367  			ContentLength: 1,
   368  			URL:           "http://localhost/foo/bar",
   369  			Header:        http.Header{},
   370  			BodyData: BodyData{
   371  				Body64: []byte("\xff"),
   372  			},
   373  		},
   374  		&Response{
   375  			ID:         11,
   376  			Method:     "GET",
   377  			URL:        "http://localhost/foo/bar",
   378  			StatusCode: http.StatusOK,
   379  			Header: http.Header{
   380  				"Content-Length": {"1"},
   381  				"Content-Type":   {"text/plain; charset=utf-8"},
   382  				"Date":           {"now"},
   383  			},
   384  			BodyData: BodyData{
   385  				Body64: []byte{0xff},
   386  			},
   387  		},
   388  	}))
   389  }
   390  
   391  type logRecorder struct {
   392  	EventKinds []EventKind
   393  	Events     []RequestOrResponse
   394  }
   395  
   396  func (r *logRecorder) Log(ctx context.Context, kind EventKind, event RequestOrResponse) {
   397  	field := urlField(event)
   398  	// Sanitize the host so we don't need to worry about localhost ports.
   399  	u, err := url.Parse(*field)
   400  	if err != nil {
   401  		panic(err)
   402  	}
   403  	u.Host = "localhost"
   404  	*field = u.String()
   405  
   406  	if _, ok := headerField(event)["Date"]; ok {
   407  		headerField(event)["Date"] = []string{"now"}
   408  	}
   409  
   410  	r.EventKinds = append(r.EventKinds, kind)
   411  	r.Events = append(r.Events, event)
   412  }
   413  
   414  type errorTransport struct{}
   415  
   416  func (errorTransport) RoundTrip(req *http.Request) (*http.Response, error) {
   417  	if req.Body != nil {
   418  		req.Body.Close()
   419  	}
   420  	return nil, fmt.Errorf("error in RoundTrip")
   421  }
   422  
   423  func urlField(event RequestOrResponse) *string {
   424  	switch event := event.(type) {
   425  	case *Request:
   426  		return &event.URL
   427  	case *Response:
   428  		return &event.URL
   429  	}
   430  	panic("unreachable")
   431  }
   432  
   433  func headerField(event RequestOrResponse) http.Header {
   434  	switch event := event.(type) {
   435  	case *Request:
   436  		return event.Header
   437  	case *Response:
   438  		return event.Header
   439  	}
   440  	panic("unreachable")
   441  }