github.com/snowflakedb/gosnowflake@v1.9.0/util_test.go (about)

     1  // Copyright (c) 2017-2023 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
import (
	"context"
	"database/sql/driver"
	"fmt"
	"math/rand"
	"os"
	"reflect"
	"strconv"
	"sync"
	"testing"
	"time"
)
    16  
    17  type tcIntMinMax struct {
    18  	v1  int
    19  	v2  int
    20  	out int
    21  }
    22  
    23  type tcUUID struct {
    24  	uuid string
    25  }
    26  
    27  type constTypeProvider struct {
    28  	constTime int64
    29  }
    30  
    31  func (ctp *constTypeProvider) currentTime() int64 {
    32  	return ctp.constTime
    33  }
    34  
    35  func constTimeProvider(constTime int64) *constTypeProvider {
    36  	return &constTypeProvider{constTime: constTime}
    37  }
    38  
    39  func TestSimpleTokenAccessor(t *testing.T) {
    40  	accessor := getSimpleTokenAccessor()
    41  	token, masterToken, sessionID := accessor.GetTokens()
    42  	if token != "" {
    43  		t.Errorf("unexpected token %v", token)
    44  	}
    45  	if masterToken != "" {
    46  		t.Errorf("unexpected master token %v", masterToken)
    47  	}
    48  	if sessionID != -1 {
    49  		t.Errorf("unexpected session id %v", sessionID)
    50  	}
    51  
    52  	expectedToken, expectedMasterToken, expectedSessionID := "token123", "master123", int64(123)
    53  	accessor.SetTokens(expectedToken, expectedMasterToken, expectedSessionID)
    54  	token, masterToken, sessionID = accessor.GetTokens()
    55  	if token != expectedToken {
    56  		t.Errorf("unexpected token %v", token)
    57  	}
    58  	if masterToken != expectedMasterToken {
    59  		t.Errorf("unexpected master token %v", masterToken)
    60  	}
    61  	if sessionID != expectedSessionID {
    62  		t.Errorf("unexpected session id %v", sessionID)
    63  	}
    64  }
    65  
    66  func TestSimpleTokenAccessorGetTokensSynchronization(t *testing.T) {
    67  	accessor := getSimpleTokenAccessor()
    68  	var wg sync.WaitGroup
    69  	failed := false
    70  	for i := 0; i < 1000; i++ {
    71  		wg.Add(1)
    72  		go func() {
    73  			// set a random session and token
    74  			session := rand.Int63()
    75  			sessionStr := strconv.FormatInt(session, 10)
    76  			accessor.SetTokens("t"+sessionStr, "m"+sessionStr, session)
    77  
    78  			// read back session and token and verify that invariant still holds
    79  			token, masterToken, session := accessor.GetTokens()
    80  			sessionStr = strconv.FormatInt(session, 10)
    81  			if "t"+sessionStr != token || "m"+sessionStr != masterToken {
    82  				failed = true
    83  			}
    84  			wg.Done()
    85  		}()
    86  	}
    87  	// wait for all competing goroutines to finish setting and getting tokens
    88  	wg.Wait()
    89  	if failed {
    90  		t.Fail()
    91  	}
    92  }
    93  
    94  func TestGetRequestIDFromContext(t *testing.T) {
    95  	expectedRequestID := NewUUID()
    96  	ctx := WithRequestID(context.Background(), expectedRequestID)
    97  	requestID := getOrGenerateRequestIDFromContext(ctx)
    98  	if requestID != expectedRequestID {
    99  		t.Errorf("unexpected request id: %v, expected: %v", requestID, expectedRequestID)
   100  	}
   101  	ctx = WithRequestID(context.Background(), nilUUID)
   102  	requestID = getOrGenerateRequestIDFromContext(ctx)
   103  	if requestID == nilUUID {
   104  		t.Errorf("unexpected request id, should not be nil")
   105  	}
   106  }
   107  
   108  func TestGenerateRequestID(t *testing.T) {
   109  	firstRequestID := getOrGenerateRequestIDFromContext(context.Background())
   110  	otherRequestID := getOrGenerateRequestIDFromContext(context.Background())
   111  	if firstRequestID == otherRequestID {
   112  		t.Errorf("request id should not be the same")
   113  	}
   114  }
   115  
   116  func TestIntMin(t *testing.T) {
   117  	testcases := []tcIntMinMax{
   118  		{1, 3, 1},
   119  		{5, 100, 5},
   120  		{321, 3, 3},
   121  		{123, 123, 123},
   122  	}
   123  	for _, test := range testcases {
   124  		t.Run(fmt.Sprintf("%v_%v_%v", test.v1, test.v2, test.out), func(t *testing.T) {
   125  			a := intMin(test.v1, test.v2)
   126  			if test.out != a {
   127  				t.Errorf("failed int min. v1: %v, v2: %v, expected: %v, got: %v", test.v1, test.v2, test.out, a)
   128  			}
   129  		})
   130  	}
   131  }
   132  func TestIntMax(t *testing.T) {
   133  	testcases := []tcIntMinMax{
   134  		{1, 3, 3},
   135  		{5, 100, 100},
   136  		{321, 3, 321},
   137  		{123, 123, 123},
   138  	}
   139  	for _, test := range testcases {
   140  		t.Run(fmt.Sprintf("%v_%v_%v", test.v1, test.v2, test.out), func(t *testing.T) {
   141  			a := intMax(test.v1, test.v2)
   142  			if test.out != a {
   143  				t.Errorf("failed int max. v1: %v, v2: %v, expected: %v, got: %v", test.v1, test.v2, test.out, a)
   144  			}
   145  		})
   146  	}
   147  }
   148  
   149  type tcDurationMinMax struct {
   150  	v1  time.Duration
   151  	v2  time.Duration
   152  	out time.Duration
   153  }
   154  
   155  func TestDurationMin(t *testing.T) {
   156  	testcases := []tcDurationMinMax{
   157  		{1 * time.Second, 3 * time.Second, 1 * time.Second},
   158  		{5 * time.Second, 100 * time.Second, 5 * time.Second},
   159  		{321 * time.Second, 3 * time.Second, 3 * time.Second},
   160  		{123 * time.Second, 123 * time.Second, 123 * time.Second},
   161  	}
   162  	for _, test := range testcases {
   163  		t.Run(fmt.Sprintf("%v_%v_%v", test.v1, test.v2, test.out), func(t *testing.T) {
   164  			a := durationMin(test.v1, test.v2)
   165  			if test.out != a {
   166  				t.Errorf("failed duratoin max. v1: %v, v2: %v, expected: %v, got: %v", test.v1, test.v2, test.out, a)
   167  			}
   168  		})
   169  	}
   170  }
   171  
   172  func TestDurationMax(t *testing.T) {
   173  	testcases := []tcDurationMinMax{
   174  		{1 * time.Second, 3 * time.Second, 3 * time.Second},
   175  		{5 * time.Second, 100 * time.Second, 100 * time.Second},
   176  		{321 * time.Second, 3 * time.Second, 321 * time.Second},
   177  		{123 * time.Second, 123 * time.Second, 123 * time.Second},
   178  	}
   179  	for _, test := range testcases {
   180  		t.Run(fmt.Sprintf("%v_%v_%v", test.v1, test.v2, test.out), func(t *testing.T) {
   181  			a := durationMax(test.v1, test.v2)
   182  			if test.out != a {
   183  				t.Errorf("failed duratoin max. v1: %v, v2: %v, expected: %v, got: %v", test.v1, test.v2, test.out, a)
   184  			}
   185  		})
   186  	}
   187  }
   188  
   189  type tcNamedValues struct {
   190  	values []driver.Value
   191  	out    []driver.NamedValue
   192  }
   193  
   194  func compareNamedValues(v1 []driver.NamedValue, v2 []driver.NamedValue) bool {
   195  	if v1 == nil && v2 == nil {
   196  		return true
   197  	}
   198  	if v1 == nil || v2 == nil {
   199  		return false
   200  	}
   201  	if len(v1) != len(v2) {
   202  		return false
   203  	}
   204  	for i := range v1 {
   205  		if v1[i] != v2[i] {
   206  			return false
   207  		}
   208  	}
   209  	return true
   210  }
   211  
   212  func TestToNamedValues(t *testing.T) {
   213  	testcases := []tcNamedValues{
   214  		{
   215  			values: []driver.Value{},
   216  			out:    []driver.NamedValue{},
   217  		},
   218  		{
   219  			values: []driver.Value{1},
   220  			out:    []driver.NamedValue{{Name: "", Ordinal: 1, Value: 1}},
   221  		},
   222  		{
   223  			values: []driver.Value{1, "test1", 9.876, nil},
   224  			out: []driver.NamedValue{
   225  				{Name: "", Ordinal: 1, Value: 1},
   226  				{Name: "", Ordinal: 2, Value: "test1"},
   227  				{Name: "", Ordinal: 3, Value: 9.876},
   228  				{Name: "", Ordinal: 4, Value: nil}},
   229  		},
   230  	}
   231  	for _, test := range testcases {
   232  		t.Run("", func(t *testing.T) {
   233  			a := toNamedValues(test.values)
   234  
   235  			if !compareNamedValues(test.out, a) {
   236  				t.Errorf("failed int max. v1: %v, v2: %v, expected: %v, got: %v", test.values, test.out, test.out, a)
   237  			}
   238  		})
   239  	}
   240  }
   241  
   242  type tcIntArrayMin struct {
   243  	in  []int
   244  	out int
   245  }
   246  
   247  func TestGetMin(t *testing.T) {
   248  	testcases := []tcIntArrayMin{
   249  		{[]int{1, 2, 3, 4, 5}, 1},
   250  		{[]int{10, 25, 15, 5, 20}, 5},
   251  		{[]int{15, 12, 9, 6, 3}, 3},
   252  		{[]int{123, 123, 123, 123, 123}, 123},
   253  		{[]int{}, -1},
   254  	}
   255  	for _, test := range testcases {
   256  		t.Run(fmt.Sprintf("%v", test.out), func(t *testing.T) {
   257  			a := getMin(test.in)
   258  			if test.out != a {
   259  				t.Errorf("failed get min. in: %v, expected: %v, got: %v", test.in, test.out, a)
   260  			}
   261  		})
   262  	}
   263  }
   264  
   265  type tcURLList struct {
   266  	in  string
   267  	out bool
   268  }
   269  
   270  func TestValidURL(t *testing.T) {
   271  	testcases := []tcURLList{
   272  		{"https://ssoTestURL.okta.com", true},
   273  		{"https://ssoTestURL.okta.com:8080", true},
   274  		{"https://ssoTestURL.okta.com/testpathvalue", true},
   275  		{"-a calculator", false},
   276  		{"This is a random test", false},
   277  		{"file://TestForFile", false},
   278  	}
   279  	for _, test := range testcases {
   280  		t.Run(test.in, func(t *testing.T) {
   281  			result := isValidURL(test.in)
   282  			if test.out != result {
   283  				t.Errorf("Failed to validate URL, input :%v, expected: %v, got: %v", test.in, test.out, result)
   284  			}
   285  		})
   286  	}
   287  }
   288  
   289  type tcEncodeList struct {
   290  	in  string
   291  	out string
   292  }
   293  
   294  func TestEncodeURL(t *testing.T) {
   295  	testcases := []tcEncodeList{
   296  		{"Hello @World", "Hello+%40World"},
   297  		{"Test//String", "Test%2F%2FString"},
   298  	}
   299  
   300  	for _, test := range testcases {
   301  		t.Run(test.in, func(t *testing.T) {
   302  			result := urlEncode(test.in)
   303  			if test.out != result {
   304  				t.Errorf("Failed to encode string, input %v, expected: %v, got: %v", test.in, test.out, result)
   305  			}
   306  		})
   307  	}
   308  }
   309  
   310  func TestParseUUID(t *testing.T) {
   311  	testcases := []tcUUID{
   312  		{"6ba7b812-9dad-11d1-80b4-00c04fd430c8"},
   313  		{"00302010-0504-0706-0809-0a0b0c0d0e0f"},
   314  	}
   315  
   316  	for _, test := range testcases {
   317  		t.Run(test.uuid, func(t *testing.T) {
   318  			requestID := ParseUUID(test.uuid)
   319  			if requestID.String() != test.uuid {
   320  				t.Fatalf("failed to parse uuid")
   321  			}
   322  		})
   323  	}
   324  }
   325  
   326  type tcEscapeCsv struct {
   327  	in  string
   328  	out string
   329  }
   330  
   331  func TestEscapeForCSV(t *testing.T) {
   332  	testcases := []tcEscapeCsv{
   333  		{"", "\"\""},
   334  		{"\n", "\"\n\""},
   335  		{"test\\", "\"test\\\""},
   336  	}
   337  
   338  	for _, test := range testcases {
   339  		t.Run(test.out, func(t *testing.T) {
   340  			result := escapeForCSV(test.in)
   341  			if test.out != result {
   342  				t.Errorf("Failed to escape string, input %v, expected: %v, got: %v", test.in, test.out, result)
   343  			}
   344  		})
   345  	}
   346  }
   347  
   348  func TestGetFromEnv(t *testing.T) {
   349  	os.Setenv("SF_TEST", "test")
   350  	defer os.Unsetenv("SF_TEST")
   351  	result, err := GetFromEnv("SF_TEST", true)
   352  
   353  	if err != nil {
   354  		t.Error("failed to read SF_TEST environment variable")
   355  	}
   356  	if result != "test" {
   357  		t.Errorf("incorrect value read for SF_TEST. Expected: test, read %v", result)
   358  	}
   359  }
   360  
   361  func TestGetFromEnvFailOnMissing(t *testing.T) {
   362  	_, err := GetFromEnv("SF_TEST_MISSING", true)
   363  	if err == nil {
   364  		t.Error("should report error when there is missing env parameter")
   365  	}
   366  }
   367  
   368  type tcContains[T comparable] struct {
   369  	arr      []T
   370  	e        T
   371  	expected bool
   372  }
   373  
   374  func TestContains(t *testing.T) {
   375  	performContainsTestcase(tcContains[int]{[]int{1, 2, 3, 5}, 4, false}, t)
   376  	performContainsTestcase(tcContains[string]{[]string{"a", "b", "C", "F"}, "C", true}, t)
   377  	performContainsTestcase(tcContains[int]{[]int{1, 2, 3, 5}, 2, true}, t)
   378  	performContainsTestcase(tcContains[string]{[]string{"a", "b", "C", "F"}, "f", false}, t)
   379  }
   380  
   381  func performContainsTestcase[S comparable](tc tcContains[S], t *testing.T) {
   382  	result := contains(tc.arr, tc.e)
   383  	if result != tc.expected {
   384  		t.Errorf("contains failed; arr: %v, e: %v, should be %v but was %v", tc.arr, tc.e, tc.expected, result)
   385  	}
   386  }
   387  
   388  func skipOnJenkins(t *testing.T, message string) {
   389  	if os.Getenv("JENKINS_HOME") != "" {
   390  		t.Skip("Skipping test on Jenkins: " + message)
   391  	}
   392  }