github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/handlers/canonical_test.go (about)

     1  package handlers
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"log"
     7  	"net/url"
     8  	"strings"
     9  	"testing"
    10  
    11  	http "github.com/hxx258456/ccgo/gmhttp"
    12  	"github.com/hxx258456/ccgo/gmhttp/httptest"
    13  )
    14  
    15  func TestCleanHost(t *testing.T) {
    16  	tests := []struct {
    17  		in, want string
    18  	}{
    19  		{"www.google.com", "www.google.com"},
    20  		{"www.google.com foo", "www.google.com"},
    21  		{"www.google.com/foo", "www.google.com"},
    22  		{" first character is a space", ""},
    23  	}
    24  	for _, tt := range tests {
    25  		got := cleanHost(tt.in)
    26  		if tt.want != got {
    27  			t.Errorf("cleanHost(%q) = %q, want %q", tt.in, got, tt.want)
    28  		}
    29  	}
    30  }
    31  
    32  func TestCanonicalHost(t *testing.T) {
    33  	gorilla := "http://www.gorillatoolkit.org"
    34  
    35  	rr := httptest.NewRecorder()
    36  	r := newRequest("GET", "http://www.example.com/")
    37  
    38  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
    39  
    40  	// Test a re-direct: should return a 302 Found.
    41  	CanonicalHost(gorilla, http.StatusFound)(testHandler).ServeHTTP(rr, r)
    42  
    43  	if rr.Code != http.StatusFound {
    44  		t.Fatalf("bad status: got %v want %v", rr.Code, http.StatusFound)
    45  	}
    46  
    47  	if rr.Header().Get("Location") != gorilla+r.URL.Path {
    48  		t.Fatalf("bad re-direct: got %q want %q", rr.Header().Get("Location"), gorilla+r.URL.Path)
    49  	}
    50  
    51  }
    52  
    53  func TestKeepsQueryString(t *testing.T) {
    54  	google := "https://www.google.com"
    55  
    56  	rr := httptest.NewRecorder()
    57  	querystring := url.Values{"q": {"golang"}, "format": {"json"}}.Encode()
    58  	r := newRequest("GET", "http://www.example.com/search?"+querystring)
    59  
    60  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
    61  	CanonicalHost(google, http.StatusFound)(testHandler).ServeHTTP(rr, r)
    62  
    63  	want := google + r.URL.Path + "?" + querystring
    64  	if rr.Header().Get("Location") != want {
    65  		t.Fatalf("bad re-direct: got %q want %q", rr.Header().Get("Location"), want)
    66  	}
    67  }
    68  
    69  func TestBadDomain(t *testing.T) {
    70  	rr := httptest.NewRecorder()
    71  	r := newRequest("GET", "http://www.example.com/")
    72  
    73  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
    74  
    75  	// Test a bad domain - should return 200 OK.
    76  	CanonicalHost("%", http.StatusFound)(testHandler).ServeHTTP(rr, r)
    77  
    78  	if rr.Code != http.StatusOK {
    79  		t.Fatalf("bad status: got %v want %v", rr.Code, http.StatusOK)
    80  	}
    81  }
    82  
    83  func TestEmptyHost(t *testing.T) {
    84  	rr := httptest.NewRecorder()
    85  	r := newRequest("GET", "http://www.example.com/")
    86  
    87  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
    88  
    89  	// Test a domain that returns an empty url.Host from url.Parse.
    90  	CanonicalHost("hello.com", http.StatusFound)(testHandler).ServeHTTP(rr, r)
    91  
    92  	if rr.Code != http.StatusOK {
    93  		t.Fatalf("bad status: got %v want %v", rr.Code, http.StatusOK)
    94  	}
    95  }
    96  
    97  func TestHeaderWrites(t *testing.T) {
    98  	gorilla := "http://www.gorillatoolkit.org"
    99  
   100  	testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   101  		w.WriteHeader(200)
   102  	})
   103  
   104  	// Catch the log output to ensure we don't write multiple headers.
   105  	var b bytes.Buffer
   106  	buf := bufio.NewWriter(&b)
   107  	tl := log.New(buf, "test: ", log.Lshortfile)
   108  
   109  	srv := httptest.NewServer(
   110  		CanonicalHost(gorilla, http.StatusFound)(testHandler))
   111  	defer srv.Close()
   112  	srv.Config.ErrorLog = tl
   113  
   114  	_, err := http.Get(srv.URL)
   115  	if err != nil {
   116  		t.Fatal(err)
   117  	}
   118  
   119  	err = buf.Flush()
   120  	if err != nil {
   121  		t.Fatal(err)
   122  	}
   123  
   124  	// We rely on the error not changing: net/http does not export it.
   125  	if strings.Contains(b.String(), "multiple response.WriteHeader calls") {
   126  		t.Fatalf("re-direct did not return early: multiple header writes")
   127  	}
   128  }