github.com/saracen/git-lfs@v2.5.2+incompatible/lfsapi/client_test.go (about)

     1  package lfsapi
     2  
     3  import (
     4  	"encoding/base64"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"strings"
    11  	"sync/atomic"
    12  	"testing"
    13  
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/stretchr/testify/require"
    16  )
    17  
    18  type redirectTest struct {
    19  	Test string
    20  }
    21  
    22  func TestClientRedirect(t *testing.T) {
    23  	var srv3Https, srv3Http string
    24  
    25  	var called1 uint32
    26  	var called2 uint32
    27  	var called3 uint32
    28  	srv3 := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    29  		atomic.AddUint32(&called3, 1)
    30  		t.Logf("srv3 req %s %s", r.Method, r.URL.Path)
    31  		assert.Equal(t, "POST", r.Method)
    32  
    33  		switch r.URL.Path {
    34  		case "/upgrade":
    35  			assert.Equal(t, "auth", r.Header.Get("Authorization"))
    36  			assert.Equal(t, "1", r.Header.Get("A"))
    37  			w.Header().Set("Location", srv3Https+"/upgraded")
    38  			w.WriteHeader(301)
    39  		case "/upgraded":
    40  			// Since srv3 listens on both a TLS-enabled socket and a
    41  			// TLS-disabled one, they are two different hosts.
    42  			// Ensure that, even though this is a "secure" upgrade,
    43  			// the authorization header is stripped.
    44  			assert.Equal(t, "", r.Header.Get("Authorization"))
    45  			assert.Equal(t, "1", r.Header.Get("A"))
    46  
    47  		case "/downgrade":
    48  			assert.Equal(t, "auth", r.Header.Get("Authorization"))
    49  			assert.Equal(t, "1", r.Header.Get("A"))
    50  			w.Header().Set("Location", srv3Http+"/404")
    51  			w.WriteHeader(301)
    52  
    53  		default:
    54  			w.WriteHeader(404)
    55  		}
    56  	}))
    57  
    58  	srv2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    59  		atomic.AddUint32(&called2, 1)
    60  		t.Logf("srv2 req %s %s", r.Method, r.URL.Path)
    61  		assert.Equal(t, "POST", r.Method)
    62  
    63  		switch r.URL.Path {
    64  		case "/ok":
    65  			assert.Equal(t, "", r.Header.Get("Authorization"))
    66  			assert.Equal(t, "1", r.Header.Get("A"))
    67  			body := &redirectTest{}
    68  			err := json.NewDecoder(r.Body).Decode(body)
    69  			assert.Nil(t, err)
    70  			assert.Equal(t, "External", body.Test)
    71  
    72  			w.WriteHeader(200)
    73  		default:
    74  			w.WriteHeader(404)
    75  		}
    76  	}))
    77  
    78  	srv1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    79  		atomic.AddUint32(&called1, 1)
    80  		t.Logf("srv1 req %s %s", r.Method, r.URL.Path)
    81  		assert.Equal(t, "POST", r.Method)
    82  
    83  		switch r.URL.Path {
    84  		case "/local":
    85  			w.Header().Set("Location", "/ok")
    86  			w.WriteHeader(307)
    87  		case "/external":
    88  			w.Header().Set("Location", srv2.URL+"/ok")
    89  			w.WriteHeader(307)
    90  		case "/ok":
    91  			assert.Equal(t, "auth", r.Header.Get("Authorization"))
    92  			assert.Equal(t, "1", r.Header.Get("A"))
    93  			body := &redirectTest{}
    94  			err := json.NewDecoder(r.Body).Decode(body)
    95  			assert.Nil(t, err)
    96  			assert.Equal(t, "Local", body.Test)
    97  
    98  			w.WriteHeader(200)
    99  		default:
   100  			w.WriteHeader(404)
   101  		}
   102  	}))
   103  	defer srv1.Close()
   104  	defer srv2.Close()
   105  	defer srv3.Close()
   106  
   107  	srv3InsecureListener, err := net.Listen("tcp", "127.0.0.1:0")
   108  	require.Nil(t, err)
   109  
   110  	go http.Serve(srv3InsecureListener, srv3.Config.Handler)
   111  	defer srv3InsecureListener.Close()
   112  
   113  	srv3Https = srv3.URL
   114  	srv3Http = fmt.Sprintf("http://%s", srv3InsecureListener.Addr().String())
   115  
   116  	c, err := NewClient(NewContext(nil, nil, map[string]string{
   117  		fmt.Sprintf("http.%s.sslverify", srv3Https):  "false",
   118  		fmt.Sprintf("http.%s/.sslverify", srv3Https): "false",
   119  		fmt.Sprintf("http.%s.sslverify", srv3Http):   "false",
   120  		fmt.Sprintf("http.%s/.sslverify", srv3Http):  "false",
   121  		fmt.Sprintf("http.sslverify"):                "false",
   122  	}))
   123  	require.Nil(t, err)
   124  
   125  	// local redirect
   126  	req, err := http.NewRequest("POST", srv1.URL+"/local", nil)
   127  	require.Nil(t, err)
   128  	req.Header.Set("Authorization", "auth")
   129  	req.Header.Set("A", "1")
   130  
   131  	require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "Local"}))
   132  
   133  	res, err := c.Do(req)
   134  	require.Nil(t, err)
   135  	assert.Equal(t, 200, res.StatusCode)
   136  	assert.EqualValues(t, 2, called1)
   137  	assert.EqualValues(t, 0, called2)
   138  
   139  	// external redirect
   140  	req, err = http.NewRequest("POST", srv1.URL+"/external", nil)
   141  	require.Nil(t, err)
   142  	req.Header.Set("Authorization", "auth")
   143  	req.Header.Set("A", "1")
   144  
   145  	require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "External"}))
   146  
   147  	res, err = c.Do(req)
   148  	require.Nil(t, err)
   149  	assert.Equal(t, 200, res.StatusCode)
   150  	assert.EqualValues(t, 3, called1)
   151  	assert.EqualValues(t, 1, called2)
   152  
   153  	// http -> https (secure upgrade)
   154  
   155  	req, err = http.NewRequest("POST", srv3Http+"/upgrade", nil)
   156  	require.Nil(t, err)
   157  	req.Header.Set("Authorization", "auth")
   158  	req.Header.Set("A", "1")
   159  
   160  	require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "http->https"}))
   161  
   162  	res, err = c.Do(req)
   163  	require.Nil(t, err)
   164  	assert.Equal(t, 200, res.StatusCode)
   165  	assert.EqualValues(t, 2, atomic.LoadUint32(&called3))
   166  
   167  	// https -> http (insecure downgrade)
   168  
   169  	req, err = http.NewRequest("POST", srv3Https+"/downgrade", nil)
   170  	require.Nil(t, err)
   171  	req.Header.Set("Authorization", "auth")
   172  	req.Header.Set("A", "1")
   173  
   174  	require.Nil(t, MarshalToRequest(req, &redirectTest{Test: "https->http"}))
   175  
   176  	_, err = c.Do(req)
   177  	assert.EqualError(t, err, "lfsapi/client: refusing insecure redirect, https->http")
   178  }
   179  
   180  func TestClientRedirectReauthenticate(t *testing.T) {
   181  	var srv1, srv2 *httptest.Server
   182  	var called1, called2 uint32
   183  	var creds1, creds2 Creds
   184  
   185  	srv1 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   186  		atomic.AddUint32(&called1, 1)
   187  
   188  		if hdr := r.Header.Get("Authorization"); len(hdr) > 0 {
   189  			parts := strings.SplitN(hdr, " ", 2)
   190  			typ, b64 := parts[0], parts[1]
   191  
   192  			auth, err := base64.URLEncoding.DecodeString(b64)
   193  			assert.Nil(t, err)
   194  			assert.Equal(t, "Basic", typ)
   195  			assert.Equal(t, "user1:pass1", string(auth))
   196  
   197  			http.Redirect(w, r, srv2.URL+r.URL.Path, http.StatusMovedPermanently)
   198  			return
   199  		}
   200  		w.WriteHeader(http.StatusUnauthorized)
   201  	}))
   202  
   203  	srv2 = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   204  		atomic.AddUint32(&called2, 1)
   205  
   206  		parts := strings.SplitN(r.Header.Get("Authorization"), " ", 2)
   207  		typ, b64 := parts[0], parts[1]
   208  
   209  		auth, err := base64.URLEncoding.DecodeString(b64)
   210  		assert.Nil(t, err)
   211  		assert.Equal(t, "Basic", typ)
   212  		assert.Equal(t, "user2:pass2", string(auth))
   213  	}))
   214  
   215  	// Change the URL of srv2 to make it appears as if it is a different
   216  	// host.
   217  	srv2.URL = strings.Replace(srv2.URL, "127.0.0.1", "0.0.0.0", 1)
   218  
   219  	creds1 = Creds(map[string]string{
   220  		"protocol": "http",
   221  		"host":     strings.TrimPrefix(srv1.URL, "http://"),
   222  
   223  		"username": "user1",
   224  		"password": "pass1",
   225  	})
   226  	creds2 = Creds(map[string]string{
   227  		"protocol": "http",
   228  		"host":     strings.TrimPrefix(srv2.URL, "http://"),
   229  
   230  		"username": "user2",
   231  		"password": "pass2",
   232  	})
   233  
   234  	defer srv1.Close()
   235  	defer srv2.Close()
   236  
   237  	c, err := NewClient(NewContext(nil, nil, nil))
   238  	creds := newCredentialCacher()
   239  	creds.Approve(creds1)
   240  	creds.Approve(creds2)
   241  	c.Credentials = creds
   242  
   243  	req, err := http.NewRequest("GET", srv1.URL, nil)
   244  	require.Nil(t, err)
   245  
   246  	_, err = c.DoWithAuth("", req)
   247  	assert.Nil(t, err)
   248  
   249  	// called1 is 2 since LFS tries an unauthenticated request first
   250  	assert.EqualValues(t, 2, called1)
   251  	assert.EqualValues(t, 1, called2)
   252  }
   253  
   254  func TestNewClient(t *testing.T) {
   255  	c, err := NewClient(NewContext(nil, nil, map[string]string{
   256  		"lfs.dialtimeout":         "151",
   257  		"lfs.keepalive":           "152",
   258  		"lfs.tlstimeout":          "153",
   259  		"lfs.concurrenttransfers": "154",
   260  	}))
   261  
   262  	require.Nil(t, err)
   263  	assert.Equal(t, 151, c.DialTimeout)
   264  	assert.Equal(t, 152, c.KeepaliveTimeout)
   265  	assert.Equal(t, 153, c.TLSTimeout)
   266  	assert.Equal(t, 154, c.ConcurrentTransfers)
   267  }
   268  
   269  func TestNewClientWithGitSSLVerify(t *testing.T) {
   270  	c, err := NewClient(nil)
   271  	assert.Nil(t, err)
   272  	assert.False(t, c.SkipSSLVerify)
   273  
   274  	for _, value := range []string{"true", "1", "t"} {
   275  		c, err = NewClient(NewContext(nil, nil, map[string]string{
   276  			"http.sslverify": value,
   277  		}))
   278  		t.Logf("http.sslverify: %q", value)
   279  		assert.Nil(t, err)
   280  		assert.False(t, c.SkipSSLVerify)
   281  	}
   282  
   283  	for _, value := range []string{"false", "0", "f"} {
   284  		c, err = NewClient(NewContext(nil, nil, map[string]string{
   285  			"http.sslverify": value,
   286  		}))
   287  		t.Logf("http.sslverify: %q", value)
   288  		assert.Nil(t, err)
   289  		assert.True(t, c.SkipSSLVerify)
   290  	}
   291  }
   292  
   293  func TestNewClientWithOSSSLVerify(t *testing.T) {
   294  	c, err := NewClient(nil)
   295  	assert.Nil(t, err)
   296  	assert.False(t, c.SkipSSLVerify)
   297  
   298  	for _, value := range []string{"false", "0", "f"} {
   299  		c, err = NewClient(NewContext(nil, map[string]string{
   300  			"GIT_SSL_NO_VERIFY": value,
   301  		}, nil))
   302  		t.Logf("GIT_SSL_NO_VERIFY: %q", value)
   303  		assert.Nil(t, err)
   304  		assert.False(t, c.SkipSSLVerify)
   305  	}
   306  
   307  	for _, value := range []string{"true", "1", "t"} {
   308  		c, err = NewClient(NewContext(nil, map[string]string{
   309  			"GIT_SSL_NO_VERIFY": value,
   310  		}, nil))
   311  		t.Logf("GIT_SSL_NO_VERIFY: %q", value)
   312  		assert.Nil(t, err)
   313  		assert.True(t, c.SkipSSLVerify)
   314  	}
   315  }
   316  
   317  func TestNewRequest(t *testing.T) {
   318  	tests := [][]string{
   319  		{"https://example.com", "a", "https://example.com/a"},
   320  		{"https://example.com/", "a", "https://example.com/a"},
   321  		{"https://example.com/a", "b", "https://example.com/a/b"},
   322  		{"https://example.com/a/", "b", "https://example.com/a/b"},
   323  	}
   324  
   325  	for _, test := range tests {
   326  		c, err := NewClient(NewContext(nil, nil, map[string]string{
   327  			"lfs.url": test[0],
   328  		}))
   329  		require.Nil(t, err)
   330  
   331  		req, err := c.NewRequest("POST", c.Endpoints.Endpoint("", ""), test[1], nil)
   332  		require.Nil(t, err)
   333  		assert.Equal(t, "POST", req.Method)
   334  		assert.Equal(t, test[2], req.URL.String(), fmt.Sprintf("endpoint: %s, suffix: %s, expected: %s", test[0], test[1], test[2]))
   335  	}
   336  }
   337  
   338  func TestNewRequestWithBody(t *testing.T) {
   339  	c, err := NewClient(NewContext(nil, nil, map[string]string{
   340  		"lfs.url": "https://example.com",
   341  	}))
   342  	require.Nil(t, err)
   343  
   344  	body := struct {
   345  		Test string
   346  	}{Test: "test"}
   347  	req, err := c.NewRequest("POST", c.Endpoints.Endpoint("", ""), "body", body)
   348  	require.Nil(t, err)
   349  
   350  	assert.NotNil(t, req.Body)
   351  	assert.Equal(t, "15", req.Header.Get("Content-Length"))
   352  	assert.EqualValues(t, 15, req.ContentLength)
   353  }
   354  
   355  func TestMarshalToRequest(t *testing.T) {
   356  	req, err := http.NewRequest("POST", "https://foo/bar", nil)
   357  	require.Nil(t, err)
   358  
   359  	assert.Nil(t, req.Body)
   360  	assert.Equal(t, "", req.Header.Get("Content-Length"))
   361  	assert.EqualValues(t, 0, req.ContentLength)
   362  
   363  	body := struct {
   364  		Test string
   365  	}{Test: "test"}
   366  	require.Nil(t, MarshalToRequest(req, body))
   367  
   368  	assert.NotNil(t, req.Body)
   369  	assert.Equal(t, "15", req.Header.Get("Content-Length"))
   370  	assert.EqualValues(t, 15, req.ContentLength)
   371  }