github.com/masterhung0112/hk_server/v5@v5.0.0-20220302090640-ec71aef15e1c/services/httpservice/client_test.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See LICENSE.txt for license information.
     3  
     4  package httpservice
     5  
     6  import (
     7  	"context"
     8  	"fmt"
     9  	"io/ioutil"
    10  	"net"
    11  	"net/http"
    12  	"net/http/httptest"
    13  	"net/url"
    14  	"strings"
    15  	"testing"
    16  
    17  	"github.com/stretchr/testify/assert"
    18  	"github.com/stretchr/testify/require"
    19  )
    20  
    21  func TestHTTPClient(t *testing.T) {
    22  	mockHTTP := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    23  		w.WriteHeader(http.StatusOK)
    24  	}))
    25  	defer mockHTTP.Close()
    26  
    27  	mockSelfSignedHTTPS := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    28  		w.WriteHeader(http.StatusOK)
    29  	}))
    30  	defer mockSelfSignedHTTPS.Close()
    31  
    32  	t.Run("insecure connections", func(t *testing.T) {
    33  		disableInsecureConnections := false
    34  		enableInsecureConnections := true
    35  
    36  		testCases := []struct {
    37  			description               string
    38  			enableInsecureConnections bool
    39  			url                       string
    40  			expectedAllowed           bool
    41  		}{
    42  			{"allow HTTP even when insecure disabled", disableInsecureConnections, mockHTTP.URL, true},
    43  			{"allow HTTP when insecure enabled", enableInsecureConnections, mockHTTP.URL, true},
    44  			{"reject self-signed HTTPS even when insecure disabled", disableInsecureConnections, mockSelfSignedHTTPS.URL, false},
    45  			{"allow self-signed HTTPS when insecure enabled", enableInsecureConnections, mockSelfSignedHTTPS.URL, true},
    46  		}
    47  
    48  		for _, testCase := range testCases {
    49  			t.Run(testCase.description, func(t *testing.T) {
    50  				c := NewHTTPClient(NewTransport(testCase.enableInsecureConnections, nil, nil))
    51  				if _, err := c.Get(testCase.url); testCase.expectedAllowed {
    52  					require.NoError(t, err)
    53  				} else {
    54  					require.Error(t, err)
    55  				}
    56  
    57  			})
    58  		}
    59  	})
    60  
    61  	t.Run("checks", func(t *testing.T) {
    62  		allowHost := func(_ string) bool { return true }
    63  		rejectHost := func(_ string) bool { return false }
    64  		allowIP := func(_ net.IP) bool { return true }
    65  		rejectIP := func(_ net.IP) bool { return false }
    66  
    67  		testCases := []struct {
    68  			description     string
    69  			allowHost       func(string) bool
    70  			allowIP         func(net.IP) bool
    71  			expectedAllowed bool
    72  		}{
    73  			{"allow with no checks", nil, nil, true},
    74  			{"reject without host check when ip rejected", nil, rejectIP, false},
    75  			{"allow without host check when ip allowed", nil, allowIP, true},
    76  
    77  			{"reject when host rejected since no ip check", rejectHost, nil, false},
    78  			{"reject when host and ip rejected", rejectHost, rejectIP, false},
    79  			{"allow when host rejected since ip allowed", rejectHost, allowIP, true},
    80  
    81  			{"allow when host allowed even without ip check", allowHost, nil, true},
    82  			{"allow when host allowed even if ip rejected", allowHost, rejectIP, true},
    83  			{"allow when host and ip allowed", allowHost, allowIP, true},
    84  		}
    85  		for _, testCase := range testCases {
    86  			t.Run(testCase.description, func(t *testing.T) {
    87  				c := NewHTTPClient(NewTransport(false, testCase.allowHost, testCase.allowIP))
    88  				if _, err := c.Get(mockHTTP.URL); testCase.expectedAllowed {
    89  					require.NoError(t, err)
    90  				} else {
    91  					require.IsType(t, &url.Error{}, err)
    92  					require.Equal(t, AddressForbidden, err.(*url.Error).Err)
    93  				}
    94  			})
    95  		}
    96  	})
    97  }
    98  
    99  func TestHTTPClientWithProxy(t *testing.T) {
   100  	proxy := createProxyServer()
   101  	defer proxy.Close()
   102  
   103  	c := NewHTTPClient(NewTransport(true, nil, nil))
   104  	purl, _ := url.Parse(proxy.URL)
   105  	c.Transport.(*MattermostTransport).Transport.(*http.Transport).Proxy = http.ProxyURL(purl)
   106  
   107  	resp, err := c.Get("http://acme.com")
   108  	require.NoError(t, err)
   109  	defer resp.Body.Close()
   110  
   111  	body, err := ioutil.ReadAll(resp.Body)
   112  	require.NoError(t, err)
   113  	require.Equal(t, "proxy", string(body))
   114  }
   115  
   116  func createProxyServer() *httptest.Server {
   117  	return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   118  		w.WriteHeader(200)
   119  		w.Header().Set("Content-Type", "text/plain; charset=us-ascii")
   120  		fmt.Fprint(w, "proxy")
   121  	}))
   122  }
   123  
   124  func TestDialContextFilter(t *testing.T) {
   125  	for _, tc := range []struct {
   126  		Addr    string
   127  		IsValid bool
   128  	}{
   129  		{
   130  			Addr:    "google.com:80",
   131  			IsValid: true,
   132  		},
   133  		{
   134  			Addr:    "8.8.8.8:53",
   135  			IsValid: true,
   136  		},
   137  		{
   138  			Addr: "127.0.0.1:80",
   139  		},
   140  		{
   141  			Addr:    "10.0.0.1:80",
   142  			IsValid: true,
   143  		},
   144  	} {
   145  		didDial := false
   146  		filter := dialContextFilter(func(ctx context.Context, network, addr string) (net.Conn, error) {
   147  			didDial = true
   148  			return nil, nil
   149  		}, func(host string) bool { return host == "10.0.0.1" }, func(ip net.IP) bool { return !IsReservedIP(ip) })
   150  		_, err := filter(context.Background(), "", tc.Addr)
   151  
   152  		if tc.IsValid {
   153  			require.NoError(t, err)
   154  			require.True(t, didDial)
   155  		} else {
   156  			require.Error(t, err)
   157  			require.Equal(t, err, AddressForbidden)
   158  			require.False(t, didDial)
   159  		}
   160  	}
   161  }
   162  
   163  func TestUserAgentIsSet(t *testing.T) {
   164  	testUserAgent := "test-user-agent"
   165  	defaultUserAgent = testUserAgent
   166  	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
   167  		ua := req.UserAgent()
   168  		assert.NotEqual(t, "", ua, "expected user-agent to be non-empty")
   169  		assert.Equalf(t, testUserAgent, ua, "expected user-agent to be %q but was %q", testUserAgent, ua)
   170  	}))
   171  	defer ts.Close()
   172  	client := NewHTTPClient(NewTransport(true, nil, nil))
   173  	req, err := http.NewRequest("GET", ts.URL, nil)
   174  
   175  	require.NoError(t, err, "NewRequest failed", err)
   176  
   177  	client.Do(req)
   178  }
   179  
   180  func NewHTTPClient(transport http.RoundTripper) *http.Client {
   181  	return &http.Client{
   182  		Transport: transport,
   183  	}
   184  }
   185  
   186  func TestIsReservedIP(t *testing.T) {
   187  	tests := []struct {
   188  		name string
   189  		ip   net.IP
   190  		want bool
   191  	}{
   192  		{"127.8.3.5", net.IPv4(127, 8, 3, 5), true},
   193  		{"192.168.0.1", net.IPv4(192, 168, 0, 1), true},
   194  		{"169.254.0.6", net.IPv4(169, 254, 0, 6), true},
   195  		{"127.120.6.3", net.IPv4(127, 120, 6, 3), true},
   196  		{"8.8.8.8", net.IPv4(8, 8, 8, 8), false},
   197  		{"9.9.9.9", net.IPv4(9, 9, 9, 8), false},
   198  	}
   199  	for _, tt := range tests {
   200  		t.Run(tt.name, func(t *testing.T) {
   201  			got := IsReservedIP(tt.ip)
   202  			assert.Equalf(t, tt.want, got, "IsReservedIP() = %v, want %v", got, tt.want)
   203  		})
   204  	}
   205  }
   206  
   207  func TestIsOwnIP(t *testing.T) {
   208  	tests := []struct {
   209  		name string
   210  		ip   net.IP
   211  		want bool
   212  	}{
   213  		{"127.0.0.1", net.IPv4(127, 0, 0, 1), true},
   214  		{"8.8.8.8", net.IPv4(8, 0, 0, 8), false},
   215  	}
   216  	for _, tt := range tests {
   217  		t.Run(tt.name, func(t *testing.T) {
   218  			got, _ := IsOwnIP(tt.ip)
   219  			assert.Equalf(t, tt.want, got, "IsOwnIP() = %v, want %v for IP %s", got, tt.want, tt.ip.String())
   220  		})
   221  	}
   222  }
   223  
   224  func TestSplitHostnames(t *testing.T) {
   225  	var config string
   226  	var hostnames []string
   227  
   228  	config = ""
   229  	hostnames = strings.FieldsFunc(config, splitFields)
   230  	require.Equal(t, []string{}, hostnames)
   231  
   232  	config = "127.0.0.1 localhost"
   233  	hostnames = strings.FieldsFunc(config, splitFields)
   234  	require.Equal(t, []string{"127.0.0.1", "localhost"}, hostnames)
   235  
   236  	config = "127.0.0.1,localhost"
   237  	hostnames = strings.FieldsFunc(config, splitFields)
   238  	require.Equal(t, []string{"127.0.0.1", "localhost"}, hostnames)
   239  
   240  	config = "127.0.0.1,,localhost"
   241  	hostnames = strings.FieldsFunc(config, splitFields)
   242  	require.Equal(t, []string{"127.0.0.1", "localhost"}, hostnames)
   243  
   244  	config = "127.0.0.1  localhost"
   245  	hostnames = strings.FieldsFunc(config, splitFields)
   246  	require.Equal(t, []string{"127.0.0.1", "localhost"}, hostnames)
   247  
   248  	config = "127.0.0.1 , localhost"
   249  	hostnames = strings.FieldsFunc(config, splitFields)
   250  	require.Equal(t, []string{"127.0.0.1", "localhost"}, hostnames)
   251  
   252  	config = "127.0.0.1  localhost  "
   253  	hostnames = strings.FieldsFunc(config, splitFields)
   254  	require.Equal(t, []string{"127.0.0.1", "localhost"}, hostnames)
   255  
   256  	config = " 127.0.0.1  ,,localhost  , , ,,"
   257  	hostnames = strings.FieldsFunc(config, splitFields)
   258  	require.Equal(t, []string{"127.0.0.1", "localhost"}, hostnames)
   259  
   260  	config = "127.0.0.1 localhost, 192.168.1.0"
   261  	hostnames = strings.FieldsFunc(config, splitFields)
   262  	require.Equal(t, []string{"127.0.0.1", "localhost", "192.168.1.0"}, hostnames)
   263  }