// github.com/snowflakedb/gosnowflake@v1.9.0/authokta_test.go

     1  // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"net/http"
     9  	"net/url"
    10  	"strconv"
    11  	"testing"
    12  	"time"
    13  )
    14  
    15  func TestUnitPostBackURL(t *testing.T) {
    16  	c := `<html><form id="1" action="https&#x3a;&#x2f;&#x2f;abc.com&#x2f;"></form></html>`
    17  	pbURL, err := postBackURL([]byte(c))
    18  	if err != nil {
    19  		t.Fatalf("failed to get URL. err: %v, %v", err, c)
    20  	}
    21  	if pbURL.String() != "https://abc.com/" {
    22  		t.Errorf("failed to get URL. got: %v, %v", pbURL, c)
    23  	}
    24  	c = `<html></html>`
    25  	_, err = postBackURL([]byte(c))
    26  	if err == nil {
    27  		t.Fatalf("should have failed")
    28  	}
    29  	c = `<html><form id="1"/></html>`
    30  	_, err = postBackURL([]byte(c))
    31  	if err == nil {
    32  		t.Fatalf("should have failed")
    33  	}
    34  	c = `<html><form id="1" action="https&#x3a;&#x2f;&#x2f;abc.com&#x2f;/></html>`
    35  	_, err = postBackURL([]byte(c))
    36  	if err == nil {
    37  		t.Fatalf("should have failed")
    38  	}
    39  }
    40  
    41  func getTestError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) {
    42  	return &http.Response{
    43  		StatusCode: http.StatusOK,
    44  		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
    45  	}, errors.New("failed to run post method")
    46  }
    47  
    48  func getTestAppBadGatewayError(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) {
    49  	return &http.Response{
    50  		StatusCode: http.StatusBadGateway,
    51  		Body:       &fakeResponseBody{body: []byte{0x12, 0x34}},
    52  	}, nil
    53  }
    54  
    55  func getTestHTMLSuccess(_ context.Context, _ *snowflakeRestful, _ *url.URL, _ map[string]string, _ time.Duration) (*http.Response, error) {
    56  	return &http.Response{
    57  		StatusCode: http.StatusOK,
    58  		Body:       &fakeResponseBody{body: []byte("<htm></html>")},
    59  	}, nil
    60  }
    61  
    62  func TestUnitPostAuthSAML(t *testing.T) {
    63  	sr := &snowflakeRestful{
    64  		FuncPost:      postTestError,
    65  		TokenAccessor: getSimpleTokenAccessor(),
    66  	}
    67  	var err error
    68  	_, err = postAuthSAML(context.Background(), sr, make(map[string]string), []byte{}, 0)
    69  	if err == nil {
    70  		t.Fatal("should have failed.")
    71  	}
    72  	sr.FuncPost = postTestAppBadGatewayError
    73  	_, err = postAuthSAML(context.Background(), sr, make(map[string]string), []byte{}, 0)
    74  	if err == nil {
    75  		t.Fatal("should have failed.")
    76  	}
    77  	sr.FuncPost = postTestSuccessButInvalidJSON
    78  	_, err = postAuthSAML(context.Background(), sr, make(map[string]string), []byte{0x12, 0x34}, 0)
    79  	if err == nil {
    80  		t.Fatalf("should have failed to post")
    81  	}
    82  }
    83  
    84  func TestUnitPostAuthOKTA(t *testing.T) {
    85  	sr := &snowflakeRestful{
    86  		FuncPost:      postTestError,
    87  		TokenAccessor: getSimpleTokenAccessor(),
    88  	}
    89  	var err error
    90  	_, err = postAuthOKTA(context.Background(), sr, make(map[string]string), []byte{}, "hahah", 0)
    91  	if err == nil {
    92  		t.Fatal("should have failed.")
    93  	}
    94  	sr.FuncPost = postTestAppBadGatewayError
    95  	_, err = postAuthOKTA(context.Background(), sr, make(map[string]string), []byte{}, "hahah", 0)
    96  	if err == nil {
    97  		t.Fatal("should have failed.")
    98  	}
    99  	sr.FuncPost = postTestSuccessButInvalidJSON
   100  	_, err = postAuthOKTA(context.Background(), sr, make(map[string]string), []byte{0x12, 0x34}, "haha", 0)
   101  	if err == nil {
   102  		t.Fatal("should have failed to run post request after the renewal")
   103  	}
   104  }
   105  
   106  func TestUnitGetSSO(t *testing.T) {
   107  	sr := &snowflakeRestful{
   108  		FuncGet:       getTestError,
   109  		TokenAccessor: getSimpleTokenAccessor(),
   110  	}
   111  	var err error
   112  	_, err = getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "hahah", 0)
   113  	if err == nil {
   114  		t.Fatal("should have failed.")
   115  	}
   116  	sr.FuncGet = getTestAppBadGatewayError
   117  	_, err = getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "hahah", 0)
   118  	if err == nil {
   119  		t.Fatal("should have failed.")
   120  	}
   121  	sr.FuncGet = getTestHTMLSuccess
   122  	_, err = getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "hahah", 0)
   123  	if err != nil {
   124  		t.Fatalf("failed to get HTML content. err: %v", err)
   125  	}
   126  	_, err = getSSO(context.Background(), sr, &url.Values{}, make(map[string]string), "invalid!@url$%^", 0)
   127  	if err == nil {
   128  		t.Fatal("should have failed to parse URL.")
   129  	}
   130  }
   131  
   132  func postAuthSAMLError(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
   133  	return &authResponse{}, errors.New("failed to get SAML response")
   134  }
   135  
   136  func postAuthSAMLAuthFail(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
   137  	return &authResponse{
   138  		Success: false,
   139  		Message: "SAML auth failed",
   140  	}, nil
   141  }
   142  
   143  func postAuthSAMLAuthFailWithCode(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
   144  	return &authResponse{
   145  		Success: false,
   146  		Code:    strconv.Itoa(ErrCodeIdpConnectionError),
   147  		Message: "SAML auth failed",
   148  	}, nil
   149  }
   150  
   151  func postAuthSAMLAuthSuccessButInvalidURL(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
   152  	return &authResponse{
   153  		Success: true,
   154  		Message: "",
   155  		Data: authResponseMain{
   156  			TokenURL: "https://1abc.com/token",
   157  			SSOURL:   "https://2abc.com/sso",
   158  		},
   159  	}, nil
   160  }
   161  
   162  func postAuthSAMLAuthSuccessButInvalidTokenURL(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
   163  	return &authResponse{
   164  		Success: true,
   165  		Message: "",
   166  		Data: authResponseMain{
   167  			TokenURL: "invalid!@url$%^",
   168  			SSOURL:   "https://abc.com/sso",
   169  		},
   170  	}, nil
   171  }
   172  
   173  func postAuthSAMLAuthSuccessButInvalidSSOURL(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
   174  	return &authResponse{
   175  		Success: true,
   176  		Message: "",
   177  		Data: authResponseMain{
   178  			TokenURL: "https://abc.com/token",
   179  			SSOURL:   "invalid!@url$%^",
   180  		},
   181  	}, nil
   182  }
   183  
   184  func postAuthSAMLAuthSuccess(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ time.Duration) (*authResponse, error) {
   185  	return &authResponse{
   186  		Success: true,
   187  		Message: "",
   188  		Data: authResponseMain{
   189  			TokenURL: "https://abc.com/token",
   190  			SSOURL:   "https://abc.com/sso",
   191  		},
   192  	}, nil
   193  }
   194  
   195  func postAuthOKTAError(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ string, _ time.Duration) (*authOKTAResponse, error) {
   196  	return &authOKTAResponse{}, errors.New("failed to get SAML response")
   197  }
   198  
   199  func postAuthOKTASuccess(_ context.Context, _ *snowflakeRestful, _ map[string]string, _ []byte, _ string, _ time.Duration) (*authOKTAResponse, error) {
   200  	return &authOKTAResponse{}, nil
   201  }
   202  
   203  func getSSOError(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) {
   204  	return []byte{}, errors.New("failed to get SSO html")
   205  }
   206  
   207  func getSSOSuccessButInvalidURL(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) {
   208  	return []byte(`<html><form id="1"/></html>`), nil
   209  }
   210  
   211  func getSSOSuccess(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) {
   212  	return []byte(`<html><form id="1" action="https&#x3a;&#x2f;&#x2f;abc.com&#x2f;"></form></html>`), nil
   213  }
   214  
   215  func getSSOSuccessButWrongPrefixURL(_ context.Context, _ *snowflakeRestful, _ *url.Values, _ map[string]string, _ string, _ time.Duration) ([]byte, error) {
   216  	return []byte(`<html><form id="1" action="https&#x3a;&#x2f;&#x2f;1abc.com&#x2f;"></form></html>`), nil
   217  }
   218  
   219  func TestUnitAuthenticateBySAML(t *testing.T) {
   220  	authenticator := &url.URL{
   221  		Scheme: "https",
   222  		Host:   "abc.com",
   223  	}
   224  	application := "testapp"
   225  	account := "testaccount"
   226  	user := "u"
   227  	password := "p"
   228  	sr := &snowflakeRestful{
   229  		Protocol:         "https",
   230  		Host:             "abc.com",
   231  		Port:             443,
   232  		FuncPostAuthSAML: postAuthSAMLError,
   233  		TokenAccessor:    getSimpleTokenAccessor(),
   234  	}
   235  	var err error
   236  	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
   237  	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
   238  	assertEqualE(t, err.Error(), "failed to get SAML response")
   239  
   240  	sr.FuncPostAuthSAML = postAuthSAMLAuthFail
   241  	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
   242  	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
   243  	assertEqualE(t, err.Error(), "strconv.Atoi: parsing \"\": invalid syntax")
   244  
   245  	sr.FuncPostAuthSAML = postAuthSAMLAuthFailWithCode
   246  	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
   247  	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
   248  	driverErr, ok := err.(*SnowflakeError)
   249  	assertTrueF(t, ok, "should be a SnowflakeError")
   250  	assertEqualE(t, driverErr.Number, ErrCodeIdpConnectionError)
   251  
   252  	sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidURL
   253  	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
   254  	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
   255  	driverErr, ok = err.(*SnowflakeError)
   256  	assertTrueF(t, ok, "should be a SnowflakeError")
   257  	assertEqualE(t, driverErr.Number, ErrCodeIdpConnectionError)
   258  
   259  	sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidTokenURL
   260  	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
   261  	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
   262  	assertEqualE(t, err.Error(), "failed to parse token URL. invalid!@url$%^")
   263  
   264  	sr.FuncPostAuthSAML = postAuthSAMLAuthSuccessButInvalidSSOURL
   265  	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
   266  	assertNotNilF(t, err, "should have failed at FuncPostAuthSAML.")
   267  	assertEqualE(t, err.Error(), "failed to parse SSO URL. invalid!@url$%^")
   268  
   269  	sr.FuncPostAuthSAML = postAuthSAMLAuthSuccess
   270  	sr.FuncPostAuthOKTA = postAuthOKTAError
   271  	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
   272  	assertNotNilF(t, err, "should have failed at FuncPostAuthOKTA.")
   273  	assertEqualE(t, err.Error(), "failed to get SAML response")
   274  
   275  	sr.FuncPostAuthOKTA = postAuthOKTASuccess
   276  	sr.FuncGetSSO = getSSOError
   277  	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
   278  	assertNotNilF(t, err, "should have failed at FuncGetSSO.")
   279  	assertEqualE(t, err.Error(), "failed to get SSO html")
   280  
   281  	sr.FuncGetSSO = getSSOSuccessButInvalidURL
   282  	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
   283  	assertNotNilF(t, err, "should have failed at FuncGetSSO.")
   284  	assertHasPrefixE(t, err.Error(), "failed to find action field in HTML response")
   285  
   286  	sr.FuncGetSSO = getSSOSuccess
   287  	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
   288  	assertNilF(t, err, "should have succeeded at FuncGetSSO.")
   289  
   290  	sr.FuncGetSSO = getSSOSuccessButWrongPrefixURL
   291  	_, err = authenticateBySAML(context.Background(), sr, authenticator, application, account, user, password)
   292  	assertNotNilF(t, err, "should have failed at FuncGetSSO.")
   293  	driverErr, ok = err.(*SnowflakeError)
   294  	assertTrueF(t, ok, "should be a SnowflakeError")
   295  	assertEqualE(t, driverErr.Number, ErrCodeSSOURLNotMatch)
   296  }