github.com/snowflakedb/gosnowflake@v1.9.0/priv_key_test.go (about)

     1  // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved.
     2  
     3  package gosnowflake
     4  
// To keep the package compiling, any variable or function added here must also be
// added, with the same name and signature but with default or empty content, in
// priv_key_test.go (see addParseDSNTest).
     7  
     8  import (
     9  	"bytes"
    10  	"context"
    11  	"crypto/rand"
    12  	"crypto/rsa"
    13  	"crypto/x509"
    14  	"database/sql"
    15  	"encoding/base64"
    16  	"encoding/pem"
    17  	"fmt"
    18  	"os"
    19  	"testing"
    20  )
    21  
    22  // helper function to generate PKCS8 encoded base64 string of a private key
    23  func generatePKCS8StringSupress(key *rsa.PrivateKey) string {
    24  	// Error would only be thrown when the private key type is not supported
    25  	// We would be safe as long as we are using rsa.PrivateKey
    26  	tmpBytes, _ := x509.MarshalPKCS8PrivateKey(key)
    27  	privKeyPKCS8 := base64.URLEncoding.EncodeToString(tmpBytes)
    28  	return privKeyPKCS8
    29  }
    30  
    31  // helper function to generate PKCS1 encoded base64 string of a private key
    32  func generatePKCS1String(key *rsa.PrivateKey) string {
    33  	tmpBytes := x509.MarshalPKCS1PrivateKey(key)
    34  	privKeyPKCS1 := base64.URLEncoding.EncodeToString(tmpBytes)
    35  	return privKeyPKCS1
    36  }
    37  
    38  // helper function to set up private key for testing
    39  func setupPrivateKey() {
    40  	env := func(key, defaultValue string) string {
    41  		if value := os.Getenv(key); value != "" {
    42  			return value
    43  		}
    44  		return defaultValue
    45  	}
    46  	privKeyPath := env("SNOWFLAKE_TEST_PRIVATE_KEY", "")
    47  	if privKeyPath == "" {
    48  		customPrivateKey = false
    49  		testPrivKey, _ = rsa.GenerateKey(rand.Reader, 2048)
    50  	} else {
    51  		// path to the DER file
    52  		customPrivateKey = true
    53  		data, _ := os.ReadFile(privKeyPath)
    54  		block, _ := pem.Decode(data)
    55  		if block == nil || block.Type != "PRIVATE KEY" {
    56  			panic(fmt.Sprintf("%v is not a public key in PEM format.", privKeyPath))
    57  		}
    58  		privKey, _ := x509.ParsePKCS8PrivateKey(block.Bytes)
    59  		testPrivKey = privKey.(*rsa.PrivateKey)
    60  	}
    61  }
    62  
    63  // Helper function to add encoded private key to dsn
    64  func appendPrivateKeyString(dsn *string, key *rsa.PrivateKey) string {
    65  	var b bytes.Buffer
    66  	b.WriteString(*dsn)
    67  	b.WriteString(fmt.Sprintf("&authenticator=%v", AuthTypeJwt.String()))
    68  	b.WriteString(fmt.Sprintf("&privateKey=%s", generatePKCS8StringSupress(key)))
    69  	return b.String()
    70  }
    71  
    72  // Integration test for the JWT authentication function
    73  func TestJWTAuthentication(t *testing.T) {
    74  	// For private key generated on the fly, we want to load the public key to the server first
    75  	if !customPrivateKey {
    76  		conn := openConn(t)
    77  		defer conn.Close()
    78  		// Load server's public key to database
    79  		pubKeyByte, err := x509.MarshalPKIXPublicKey(testPrivKey.Public())
    80  		if err != nil {
    81  			t.Fatalf("error marshaling public key: %s", err.Error())
    82  		}
    83  		if _, err = conn.ExecContext(context.Background(), "USE ROLE ACCOUNTADMIN"); err != nil {
    84  			t.Fatalf("error changin role: %s", err.Error())
    85  		}
    86  		encodedKey := base64.StdEncoding.EncodeToString(pubKeyByte)
    87  		if _, err = conn.ExecContext(context.Background(), fmt.Sprintf("ALTER USER %v set rsa_public_key='%v'", username, encodedKey)); err != nil {
    88  			t.Fatalf("error setting server's public key: %s", err.Error())
    89  		}
    90  	}
    91  
    92  	// Test that a valid private key can pass
    93  	jwtDSN := appendPrivateKeyString(&dsn, testPrivKey)
    94  	db, err := sql.Open("snowflake", jwtDSN)
    95  	if err != nil {
    96  		t.Fatalf("error creating a connection object: %s", err.Error())
    97  	}
    98  	if _, err = db.Exec("SELECT 1"); err != nil {
    99  		t.Fatalf("error executing: %s", err.Error())
   100  	}
   101  	db.Close()
   102  
   103  	// Test that an invalid private key cannot pass
   104  	invalidPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048)
   105  	if err != nil {
   106  		t.Error(err)
   107  	}
   108  	jwtDSN = appendPrivateKeyString(&dsn, invalidPrivateKey)
   109  	db, err = sql.Open("snowflake", jwtDSN)
   110  	if err != nil {
   111  		t.Error(err)
   112  	}
   113  	if _, err = db.Exec("SELECT 1"); err == nil {
   114  		t.Fatalf("An invalid jwt token can pass")
   115  	}
   116  
   117  	db.Close()
   118  }
   119  
   120  func TestJWTTokenTimeout(t *testing.T) {
   121  	resetHTTPMocks(t)
   122  
   123  	dsn := "user:pass@localhost:12345/db/schema?account=jwtAuthTokenTimeout&protocol=http&jwtClientTimeout=1"
   124  	dsn = appendPrivateKeyString(&dsn, testPrivKey)
   125  	db, err := sql.Open("snowflake", dsn)
   126  	if err != nil {
   127  		t.Fatalf(err.Error())
   128  	}
   129  	defer db.Close()
   130  	ctx := context.Background()
   131  	conn, err := db.Conn(ctx)
   132  	if err != nil {
   133  		t.Fatalf(err.Error())
   134  	}
   135  	defer conn.Close()
   136  
   137  	invocations := getMocksInvocations(t)
   138  	if invocations != 3 {
   139  		t.Errorf("Unexpected number of invocations, expected 3, got %v", invocations)
   140  	}
   141  }