go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/testing/httpmitm/httpmitm_test.go (about) 1 // Copyright 2015 The LUCI Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package httpmitm 16 17 import ( 18 "bytes" 19 "errors" 20 "io" 21 "net/http" 22 "net/http/httptest" 23 "regexp" 24 "testing" 25 26 . "github.com/smartystreets/goconvey/convey" 27 ) 28 29 type record struct { 30 o Origin 31 d string 32 e error 33 } 34 35 // shouldRecord tests whether a captured record (actual) matches the expected result: 36 // 0) Origin 37 // 1) bool (are we expecting an error?) 38 // 2) string (optional). If present, the regexp pattern that must match the data. 39 func shouldRecord(actual any, expected ...any) string { 40 r := actual.(*record) 41 o := expected[0].(Origin) 42 e := expected[1].(bool) 43 44 var re string 45 if len(expected) == 3 { 46 re = expected[2].(string) 47 } 48 49 if err := ShouldEqual(r.o, o); err != "" { 50 return err 51 } 52 if !e { 53 // No error. 54 if err := ShouldBeNil(r.e); err != "" { 55 return err 56 } 57 } else { 58 // Error expected. 59 if err := ShouldNotBeNil(r.e); err != "" { 60 return err 61 } 62 } 63 if re != "" { 64 m, e := regexp.MatchString(re, r.d) 65 if err := ShouldEqual(e, nil); err != "" { 66 return err 67 } 68 if err := ShouldBeTrue(m); err != "" { 69 return err 70 } 71 } 72 return "" 73 } 74 75 func TestTransport(t *testing.T) { 76 t.Parallel() 77 78 Convey(`A debug HTTP client`, t, func() { 79 // Generic callback-based server. Each test will set its callback. 80 var callback func() (string, error) 81 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 82 err := errors.New("No Callback.") 83 if callback != nil { 84 var content string 85 content, err = callback() 86 if err == nil { 87 _, err = w.Write([]byte(content)) 88 } 89 } 90 if err != nil { 91 http.Error(w, err.Error(), http.StatusInternalServerError) 92 } 93 })) 94 defer ts.Close() 95 96 var records []*record 97 client := http.Client{ 98 Transport: &Transport{ 99 Callback: func(o Origin, data []byte, err error) { 100 records = append(records, &record{o, string(data), err}) 101 }, 102 }, 103 } 104 105 Convey(`Successfully fetches content.`, func() { 106 callback = func() (string, error) { 107 return "Hello, client!", nil 108 } 109 resp, err := client.Post(ts.URL, "test", bytes.NewBufferString("DATA")) 110 So(err, ShouldBeNil) 111 defer resp.Body.Close() 112 113 So(len(records), ShouldEqual, 2) 114 So(records[0], shouldRecord, Request, false, 115 "(?s:POST / HTTP/1.1\r\n(.+:.+\r\n)*\r\nDATA)") 116 So(records[1], shouldRecord, Response, false, 117 "(?s:HTTP/1.1 200 OK\r\n(.+:.+\r\n)*\r\nHello, client!)") 118 119 body, err := io.ReadAll(resp.Body) 120 So(err, ShouldBeNil) 121 So(string(body), ShouldEqual, "Hello, client!") 122 So(resp.StatusCode, ShouldEqual, http.StatusOK) 123 }) 124 125 Convey(`Handles HTTP error.`, func() { 126 callback = func() (string, error) { 127 return "", errors.New("Failure!") 128 } 129 resp, err := client.Post(ts.URL, "test", bytes.NewBufferString("DATA")) 130 So(err, ShouldBeNil) 131 defer resp.Body.Close() 132 133 So(len(records), ShouldEqual, 2) 134 So(records[0], shouldRecord, Request, false, 135 "(?s:POST / HTTP/1.1\r\n(.+:.+\r\n)*\r\nDATA)") 136 So(records[1], shouldRecord, Response, false, 137 "(?s:HTTP/1.1 500 Internal Server Error\r\n(.+:.+\r\n)*\r\nFailure!)") 138 139 body, err := io.ReadAll(resp.Body) 140 So(err, ShouldBeNil) 141 So(string(body), ShouldEqual, "Failure!\n") 142 So(resp.StatusCode, ShouldEqual, http.StatusInternalServerError) 143 }) 144 145 Convey(`Handles connection error.`, func() { 146 _, err := client.Get("http+testingfakeprotocol://") 147 So(err, ShouldNotBeNil) 148 149 So(len(records), ShouldEqual, 2) 150 So(records[0], shouldRecord, Request, false, 151 "(?s:GET / HTTP/1.1\r\n(.+:.+\r\n)*)") 152 So(records[1], shouldRecord, Response, true) 153 }) 154 }) 155 }