github.com/masterhung0112/hk_server/v5@v5.0.0-20220302090640-ec71aef15e1c/services/httpservice/client_test.go (about) 1 // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. 2 // See LICENSE.txt for license information. 3 4 package httpservice 5 6 import ( 7 "context" 8 "fmt" 9 "io/ioutil" 10 "net" 11 "net/http" 12 "net/http/httptest" 13 "net/url" 14 "strings" 15 "testing" 16 17 "github.com/stretchr/testify/assert" 18 "github.com/stretchr/testify/require" 19 ) 20 21 func TestHTTPClient(t *testing.T) { 22 mockHTTP := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 23 w.WriteHeader(http.StatusOK) 24 })) 25 defer mockHTTP.Close() 26 27 mockSelfSignedHTTPS := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 28 w.WriteHeader(http.StatusOK) 29 })) 30 defer mockSelfSignedHTTPS.Close() 31 32 t.Run("insecure connections", func(t *testing.T) { 33 disableInsecureConnections := false 34 enableInsecureConnections := true 35 36 testCases := []struct { 37 description string 38 enableInsecureConnections bool 39 url string 40 expectedAllowed bool 41 }{ 42 {"allow HTTP even when insecure disabled", disableInsecureConnections, mockHTTP.URL, true}, 43 {"allow HTTP when insecure enabled", enableInsecureConnections, mockHTTP.URL, true}, 44 {"reject self-signed HTTPS even when insecure disabled", disableInsecureConnections, mockSelfSignedHTTPS.URL, false}, 45 {"allow self-signed HTTPS when insecure enabled", enableInsecureConnections, mockSelfSignedHTTPS.URL, true}, 46 } 47 48 for _, testCase := range testCases { 49 t.Run(testCase.description, func(t *testing.T) { 50 c := NewHTTPClient(NewTransport(testCase.enableInsecureConnections, nil, nil)) 51 if _, err := c.Get(testCase.url); testCase.expectedAllowed { 52 require.NoError(t, err) 53 } else { 54 require.Error(t, err) 55 } 56 57 }) 58 } 59 }) 60 61 t.Run("checks", func(t *testing.T) { 62 allowHost := func(_ string) bool { return true } 63 rejectHost := func(_ string) bool { return false } 64 allowIP := func(_ net.IP) bool { return true } 65 rejectIP := func(_ net.IP) bool { return false } 66 67 testCases := []struct { 68 description string 69 allowHost func(string) bool 70 allowIP func(net.IP) bool 71 expectedAllowed bool 72 }{ 73 {"allow with no checks", nil, nil, true}, 74 {"reject without host check when ip rejected", nil, rejectIP, false}, 75 {"allow without host check when ip allowed", nil, allowIP, true}, 76 77 {"reject when host rejected since no ip check", rejectHost, nil, false}, 78 {"reject when host and ip rejected", rejectHost, rejectIP, false}, 79 {"allow when host rejected since ip allowed", rejectHost, allowIP, true}, 80 81 {"allow when host allowed even without ip check", allowHost, nil, true}, 82 {"allow when host allowed even if ip rejected", allowHost, rejectIP, true}, 83 {"allow when host and ip allowed", allowHost, allowIP, true}, 84 } 85 for _, testCase := range testCases { 86 t.Run(testCase.description, func(t *testing.T) { 87 c := NewHTTPClient(NewTransport(false, testCase.allowHost, testCase.allowIP)) 88 if _, err := c.Get(mockHTTP.URL); testCase.expectedAllowed { 89 require.NoError(t, err) 90 } else { 91 require.IsType(t, &url.Error{}, err) 92 require.Equal(t, AddressForbidden, err.(*url.Error).Err) 93 } 94 }) 95 } 96 }) 97 } 98 99 func TestHTTPClientWithProxy(t *testing.T) { 100 proxy := createProxyServer() 101 defer proxy.Close() 102 103 c := NewHTTPClient(NewTransport(true, nil, nil)) 104 purl, _ := url.Parse(proxy.URL) 105 c.Transport.(*MattermostTransport).Transport.(*http.Transport).Proxy = http.ProxyURL(purl) 106 107 resp, err := c.Get("http://acme.com") 108 require.NoError(t, err) 109 defer resp.Body.Close() 110 111 body, err := ioutil.ReadAll(resp.Body) 112 require.NoError(t, err) 113 require.Equal(t, "proxy", string(body)) 114 } 115 116 func createProxyServer() *httptest.Server { 117 return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 118 w.WriteHeader(200) 119 w.Header().Set("Content-Type", "text/plain; charset=us-ascii") 120 fmt.Fprint(w, "proxy") 121 })) 122 } 123 124 func TestDialContextFilter(t *testing.T) { 125 for _, tc := range []struct { 126 Addr string 127 IsValid bool 128 }{ 129 { 130 Addr: "google.com:80", 131 IsValid: true, 132 }, 133 { 134 Addr: "8.8.8.8:53", 135 IsValid: true, 136 }, 137 { 138 Addr: "127.0.0.1:80", 139 }, 140 { 141 Addr: "10.0.0.1:80", 142 IsValid: true, 143 }, 144 } { 145 didDial := false 146 filter := dialContextFilter(func(ctx context.Context, network, addr string) (net.Conn, error) { 147 didDial = true 148 return nil, nil 149 }, func(host string) bool { return host == "10.0.0.1" }, func(ip net.IP) bool { return !IsReservedIP(ip) }) 150 _, err := filter(context.Background(), "", tc.Addr) 151 152 if tc.IsValid { 153 require.NoError(t, err) 154 require.True(t, didDial) 155 } else { 156 require.Error(t, err) 157 require.Equal(t, err, AddressForbidden) 158 require.False(t, didDial) 159 } 160 } 161 } 162 163 func TestUserAgentIsSet(t *testing.T) { 164 testUserAgent := "test-user-agent" 165 defaultUserAgent = testUserAgent 166 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { 167 ua := req.UserAgent() 168 assert.NotEqual(t, "", ua, "expected user-agent to be non-empty") 169 assert.Equalf(t, testUserAgent, ua, "expected user-agent to be %q but was %q", testUserAgent, ua) 170 })) 171 defer ts.Close() 172 client := NewHTTPClient(NewTransport(true, nil, nil)) 173 req, err := http.NewRequest("GET", ts.URL, nil) 174 175 require.NoError(t, err, "NewRequest failed", err) 176 177 client.Do(req) 178 } 179 180 func NewHTTPClient(transport http.RoundTripper) *http.Client { 181 return &http.Client{ 182 Transport: transport, 183 } 184 } 185 186 func TestIsReservedIP(t *testing.T) { 187 tests := []struct { 188 name string 189 ip net.IP 190 want bool 191 }{ 192 {"127.8.3.5", net.IPv4(127, 8, 3, 5), true}, 193 {"192.168.0.1", net.IPv4(192, 168, 0, 1), true}, 194 {"169.254.0.6", net.IPv4(169, 254, 0, 6), true}, 195 {"127.120.6.3", net.IPv4(127, 120, 6, 3), true}, 196 {"8.8.8.8", net.IPv4(8, 8, 8, 8), false}, 197 {"9.9.9.9", net.IPv4(9, 9, 9, 8), false}, 198 } 199 for _, tt := range tests { 200 t.Run(tt.name, func(t *testing.T) { 201 got := IsReservedIP(tt.ip) 202 assert.Equalf(t, tt.want, got, "IsReservedIP() = %v, want %v", got, tt.want) 203 }) 204 } 205 } 206 207 func TestIsOwnIP(t *testing.T) { 208 tests := []struct { 209 name string 210 ip net.IP 211 want bool 212 }{ 213 {"127.0.0.1", net.IPv4(127, 0, 0, 1), true}, 214 {"8.8.8.8", net.IPv4(8, 0, 0, 8), false}, 215 } 216 for _, tt := range tests { 217 t.Run(tt.name, func(t *testing.T) { 218 got, _ := IsOwnIP(tt.ip) 219 assert.Equalf(t, tt.want, got, "IsOwnIP() = %v, want %v for IP %s", got, tt.want, tt.ip.String()) 220 }) 221 } 222 } 223 224 func TestSplitHostnames(t *testing.T) { 225 var config string 226 var hostnames []string 227 228 config = "" 229 hostnames = strings.FieldsFunc(config, splitFields) 230 require.Equal(t, []string{}, hostnames) 231 232 config = "127.0.0.1 localhost" 233 hostnames = strings.FieldsFunc(config, splitFields) 234 require.Equal(t, []string{"127.0.0.1", "localhost"}, hostnames) 235 236 config = "127.0.0.1,localhost" 237 hostnames = strings.FieldsFunc(config, splitFields) 238 require.Equal(t, []string{"127.0.0.1", "localhost"}, hostnames) 239 240 config = "127.0.0.1,,localhost" 241 hostnames = strings.FieldsFunc(config, splitFields) 242 require.Equal(t, []string{"127.0.0.1", "localhost"}, hostnames) 243 244 config = "127.0.0.1 localhost" 245 hostnames = strings.FieldsFunc(config, splitFields) 246 require.Equal(t, []string{"127.0.0.1", "localhost"}, hostnames) 247 248 config = "127.0.0.1 , localhost" 249 hostnames = strings.FieldsFunc(config, splitFields) 250 require.Equal(t, []string{"127.0.0.1", "localhost"}, hostnames) 251 252 config = "127.0.0.1 localhost " 253 hostnames = strings.FieldsFunc(config, splitFields) 254 require.Equal(t, []string{"127.0.0.1", "localhost"}, hostnames) 255 256 config = " 127.0.0.1 ,,localhost , , ,," 257 hostnames = strings.FieldsFunc(config, splitFields) 258 require.Equal(t, []string{"127.0.0.1", "localhost"}, hostnames) 259 260 config = "127.0.0.1 localhost, 192.168.1.0" 261 hostnames = strings.FieldsFunc(config, splitFields) 262 require.Equal(t, []string{"127.0.0.1", "localhost", "192.168.1.0"}, hostnames) 263 }