github.com/stffabi/git-lfs@v2.3.5-0.20180214015214-8eeaa8d88902+incompatible/lfsapi/ntlm_test.go (about)

     1  package lfsapi
     2  
     3  import (
     4  	"encoding/base64"
     5  	"io/ioutil"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"net/url"
     9  	"strings"
    10  	"sync/atomic"
    11  	"testing"
    12  
    13  	"github.com/ThomsonReutersEikon/go-ntlm/ntlm"
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/stretchr/testify/require"
    16  )
    17  
    18  func TestNTLMAuth(t *testing.T) {
    19  	session, err := ntlm.CreateServerSession(ntlm.Version2, ntlm.ConnectionOrientedMode)
    20  	require.Nil(t, err)
    21  	session.SetUserInfo("ntlmuser", "ntlmpass", "NTLMDOMAIN")
    22  
    23  	var called uint32
    24  	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    25  		reqIndex := atomic.LoadUint32(&called)
    26  		atomic.AddUint32(&called, 1)
    27  
    28  		authHeader := req.Header.Get("Authorization")
    29  		t.Logf("REQUEST %d: %s %s", reqIndex, req.Method, req.URL)
    30  		t.Logf("AUTH: %q", authHeader)
    31  
    32  		// assert full body is sent each time
    33  		by, err := ioutil.ReadAll(req.Body)
    34  		req.Body.Close()
    35  		if assert.Nil(t, err) {
    36  			assert.Equal(t, "ntlm", string(by))
    37  		}
    38  
    39  		switch authHeader {
    40  		case "":
    41  			w.Header().Set("Www-Authenticate", "ntlm")
    42  			w.WriteHeader(401)
    43  		case ntlmNegotiateMessage:
    44  			assert.True(t, strings.HasPrefix(req.Header.Get("Authorization"), "NTLM "))
    45  			ch, err := session.GenerateChallengeMessage()
    46  			if !assert.Nil(t, err) {
    47  				t.Logf("challenge gen error: %+v", err)
    48  				w.WriteHeader(500)
    49  				return
    50  			}
    51  			chMsg := base64.StdEncoding.EncodeToString(ch.Bytes())
    52  			w.Header().Set("Www-Authenticate", "ntlm "+chMsg)
    53  			w.WriteHeader(401)
    54  		default: // should be an auth msg
    55  			authHeader := req.Header.Get("Authorization")
    56  			assert.True(t, strings.HasPrefix(strings.ToUpper(authHeader), "NTLM "))
    57  			auth := authHeader[5:] // strip "ntlm " prefix
    58  			val, err := base64.StdEncoding.DecodeString(auth)
    59  			if !assert.Nil(t, err) {
    60  				t.Logf("auth base64 error: %+v", err)
    61  				w.WriteHeader(500)
    62  				return
    63  			}
    64  
    65  			_, err = ntlm.ParseAuthenticateMessage(val, 2)
    66  			if !assert.Nil(t, err) {
    67  				t.Logf("auth parse error: %+v", err)
    68  				w.WriteHeader(500)
    69  				return
    70  			}
    71  			w.WriteHeader(200)
    72  		}
    73  	}))
    74  	defer srv.Close()
    75  
    76  	req, err := http.NewRequest("POST", srv.URL+"/ntlm", NewByteBody([]byte("ntlm")))
    77  	require.Nil(t, err)
    78  
    79  	credHelper := newMockCredentialHelper()
    80  	cli, err := NewClient(NewContext(nil, nil, map[string]string{
    81  		"lfs.url":                         srv.URL + "/ntlm",
    82  		"lfs." + srv.URL + "/ntlm.access": "ntlm",
    83  	}))
    84  	cli.Credentials = credHelper
    85  	require.Nil(t, err)
    86  
    87  	// ntlm support pulls domain and login info from git credentials
    88  	srvURL, err := url.Parse(srv.URL)
    89  	require.Nil(t, err)
    90  	creds := Creds{
    91  		"protocol": srvURL.Scheme,
    92  		"host":     srvURL.Host,
    93  		"username": "ntlmdomain\\ntlmuser",
    94  		"password": "ntlmpass",
    95  	}
    96  	credHelper.Approve(creds)
    97  
    98  	res, err := cli.DoWithAuth("remote", req)
    99  	require.Nil(t, err)
   100  	assert.Equal(t, 200, res.StatusCode)
   101  	assert.True(t, credHelper.IsApproved(creds))
   102  }
   103  
   104  func TestNtlmClientSession(t *testing.T) {
   105  	cli, err := NewClient(nil)
   106  	require.Nil(t, err)
   107  
   108  	creds := Creds{"username": "MOOSEDOMAIN\\canadian", "password": "MooseAntlersYeah"}
   109  	session1, err := cli.ntlmClientSession(creds)
   110  	assert.Nil(t, err)
   111  	assert.NotNil(t, session1)
   112  
   113  	// The second call should ignore creds and give the session we just created.
   114  	badCreds := Creds{"username": "MOOSEDOMAIN\\badusername", "password": "MooseAntlersYeah"}
   115  	session2, err := cli.ntlmClientSession(badCreds)
   116  	assert.Nil(t, err)
   117  	assert.NotNil(t, session2)
   118  	assert.EqualValues(t, session1, session2)
   119  }
   120  
   121  func TestNtlmClientSessionBadCreds(t *testing.T) {
   122  	cli, err := NewClient(nil)
   123  	require.Nil(t, err)
   124  	creds := Creds{"username": "badusername", "password": "MooseAntlersYeah"}
   125  	_, err = cli.ntlmClientSession(creds)
   126  	assert.NotNil(t, err)
   127  }
   128  
   129  func TestNtlmHeaderParseValid(t *testing.T) {
   130  	res := http.Response{}
   131  	res.Header = make(map[string][]string)
   132  	res.Header.Add("Www-Authenticate", "NTLM "+base64.StdEncoding.EncodeToString([]byte("I am a moose")))
   133  	bytes, err := parseChallengeResponse(&res)
   134  	assert.Nil(t, err)
   135  	assert.False(t, strings.HasPrefix(string(bytes), "NTLM"))
   136  }
   137  
   138  func TestNtlmHeaderParseInvalidLength(t *testing.T) {
   139  	res := http.Response{}
   140  	res.Header = make(map[string][]string)
   141  	res.Header.Add("Www-Authenticate", "NTL")
   142  	ret, err := parseChallengeResponse(&res)
   143  	assert.NotNil(t, err)
   144  	assert.Nil(t, ret)
   145  }
   146  
   147  func TestNtlmHeaderParseInvalid(t *testing.T) {
   148  	res := http.Response{}
   149  	res.Header = make(map[string][]string)
   150  	res.Header.Add("Www-Authenticate", base64.StdEncoding.EncodeToString([]byte("NTLM I am a moose")))
   151  	ret, err := parseChallengeResponse(&res)
   152  	assert.NotNil(t, err)
   153  	assert.Nil(t, ret)
   154  }
   155  
   156  func assertRequestsEqual(t *testing.T, req1 *http.Request, req2 *http.Request, req1Body []byte) {
   157  	assert.Equal(t, req1.Method, req2.Method)
   158  
   159  	for k, v := range req1.Header {
   160  		assert.Equal(t, v, req2.Header[k])
   161  	}
   162  
   163  	if req1.Body == nil {
   164  		assert.Nil(t, req2.Body)
   165  	} else {
   166  		bytes2, _ := ioutil.ReadAll(req2.Body)
   167  		assert.Equal(t, req1Body, bytes2)
   168  	}
   169  }