github.com/letsencrypt/boulder@v0.20251208.0/web/context_test.go (about)

     1  package web
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"fmt"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"strings"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/letsencrypt/boulder/features"
    15  	blog "github.com/letsencrypt/boulder/log"
    16  	"github.com/letsencrypt/boulder/test"
    17  )
    18  
    19  type myHandler struct{}
    20  
    21  func (m myHandler) ServeHTTP(e *RequestEvent, w http.ResponseWriter, r *http.Request) {
    22  	w.WriteHeader(201)
    23  	e.Endpoint = "/endpoint"
    24  	_, _ = w.Write([]byte("hi"))
    25  }
    26  
    27  func TestLogCode(t *testing.T) {
    28  	mockLog := blog.UseMock()
    29  	th := NewTopHandler(mockLog, myHandler{})
    30  	req, err := http.NewRequest("GET", "/thisisignored", &bytes.Reader{})
    31  	if err != nil {
    32  		t.Fatal(err)
    33  	}
    34  	th.ServeHTTP(httptest.NewRecorder(), req)
    35  	expected := `INFO: GET /endpoint 0 201 0 0.0.0.0 JSON={}`
    36  	if len(mockLog.GetAllMatching(expected)) != 1 {
    37  		t.Errorf("Expected exactly one log line matching %q. Got \n%s",
    38  			expected, strings.Join(mockLog.GetAllMatching(".*"), "\n"))
    39  	}
    40  }
    41  
    42  type codeHandler struct{}
    43  
    44  func (ch codeHandler) ServeHTTP(e *RequestEvent, w http.ResponseWriter, r *http.Request) {
    45  	e.Endpoint = "/endpoint"
    46  	_, _ = w.Write([]byte("hi"))
    47  }
    48  
    49  func TestStatusCodeLogging(t *testing.T) {
    50  	mockLog := blog.UseMock()
    51  	th := NewTopHandler(mockLog, codeHandler{})
    52  	req, err := http.NewRequest("GET", "/thisisignored", &bytes.Reader{})
    53  	if err != nil {
    54  		t.Fatal(err)
    55  	}
    56  	th.ServeHTTP(httptest.NewRecorder(), req)
    57  	expected := `INFO: GET /endpoint 0 200 0 0.0.0.0 JSON={}`
    58  	if len(mockLog.GetAllMatching(expected)) != 1 {
    59  		t.Errorf("Expected exactly one log line matching %q. Got \n%s",
    60  			expected, strings.Join(mockLog.GetAllMatching(".*"), "\n"))
    61  	}
    62  }
    63  
    64  func TestOrigin(t *testing.T) {
    65  	mockLog := blog.UseMock()
    66  	th := NewTopHandler(mockLog, myHandler{})
    67  	req, err := http.NewRequest("GET", "/thisisignored", &bytes.Reader{})
    68  	if err != nil {
    69  		t.Fatal(err)
    70  	}
    71  	req.Header.Add("Origin", "https://example.com")
    72  	th.ServeHTTP(httptest.NewRecorder(), req)
    73  	expected := `INFO: GET /endpoint 0 201 0 0.0.0.0 JSON={.*"Origin":"https://example.com"}`
    74  	if len(mockLog.GetAllMatching(expected)) != 1 {
    75  		t.Errorf("Expected exactly one log line matching %q. Got \n%s",
    76  			expected, strings.Join(mockLog.GetAllMatching(".*"), "\n"))
    77  	}
    78  }
    79  
    80  type hostHeaderHandler struct {
    81  	f func(*RequestEvent, http.ResponseWriter, *http.Request)
    82  }
    83  
    84  func (hhh hostHeaderHandler) ServeHTTP(e *RequestEvent, w http.ResponseWriter, r *http.Request) {
    85  	hhh.f(e, w, r)
    86  }
    87  
    88  func TestHostHeaderRewrite(t *testing.T) {
    89  	mockLog := blog.UseMock()
    90  	hhh := hostHeaderHandler{f: func(_ *RequestEvent, _ http.ResponseWriter, r *http.Request) {
    91  		t.Helper()
    92  		test.AssertEquals(t, r.Host, "localhost")
    93  	}}
    94  	th := NewTopHandler(mockLog, &hhh)
    95  
    96  	req, err := http.NewRequest("GET", "/", &bytes.Reader{})
    97  	test.AssertNotError(t, err, "http.NewRequest failed")
    98  	req.Host = "localhost:80"
    99  	fmt.Println("here")
   100  	th.ServeHTTP(httptest.NewRecorder(), req)
   101  
   102  	req, err = http.NewRequest("GET", "/", &bytes.Reader{})
   103  	test.AssertNotError(t, err, "http.NewRequest failed")
   104  	req.Host = "localhost:443"
   105  	req.TLS = &tls.ConnectionState{}
   106  	th.ServeHTTP(httptest.NewRecorder(), req)
   107  
   108  	req, err = http.NewRequest("GET", "/", &bytes.Reader{})
   109  	test.AssertNotError(t, err, "http.NewRequest failed")
   110  	req.Host = "localhost:443"
   111  	req.TLS = nil
   112  	th.ServeHTTP(httptest.NewRecorder(), req)
   113  
   114  	hhh.f = func(_ *RequestEvent, _ http.ResponseWriter, r *http.Request) {
   115  		t.Helper()
   116  		test.AssertEquals(t, r.Host, "localhost:123")
   117  	}
   118  	req, err = http.NewRequest("GET", "/", &bytes.Reader{})
   119  	test.AssertNotError(t, err, "http.NewRequest failed")
   120  	req.Host = "localhost:123"
   121  	th.ServeHTTP(httptest.NewRecorder(), req)
   122  }
   123  
   124  type cancelHandler struct {
   125  	res chan string
   126  }
   127  
   128  func (ch cancelHandler) ServeHTTP(e *RequestEvent, w http.ResponseWriter, r *http.Request) {
   129  	select {
   130  	case <-r.Context().Done():
   131  		ch.res <- r.Context().Err().Error()
   132  	case <-time.After(300 * time.Millisecond):
   133  		ch.res <- "300 ms passed"
   134  	}
   135  }
   136  
   137  func TestPropagateCancel(t *testing.T) {
   138  	mockLog := blog.UseMock()
   139  	res := make(chan string)
   140  	features.Set(features.Config{PropagateCancels: true})
   141  	th := NewTopHandler(mockLog, cancelHandler{res})
   142  	ctx, cancel := context.WithCancel(context.Background())
   143  	go func() {
   144  		req, err := http.NewRequestWithContext(ctx, "GET", "/thisisignored", &bytes.Reader{})
   145  		if err != nil {
   146  			t.Error(err)
   147  		}
   148  		th.ServeHTTP(httptest.NewRecorder(), req)
   149  	}()
   150  	cancel()
   151  	result := <-res
   152  	if result != "context canceled" {
   153  		t.Errorf("expected 'context canceled', got %q", result)
   154  	}
   155  }