github.com/snowflakedb/gosnowflake@v1.9.0/priv_key_test.go (about) 1 // Copyright (c) 2017-2022 Snowflake Computing Inc. All rights reserved. 2 3 package gosnowflake 4 5 // For compile concern, should any newly added variables or functions here must also be added with same 6 // name or signature but with default or empty content in the priv_key_test.go(See addParseDSNTest) 7 8 import ( 9 "bytes" 10 "context" 11 "crypto/rand" 12 "crypto/rsa" 13 "crypto/x509" 14 "database/sql" 15 "encoding/base64" 16 "encoding/pem" 17 "fmt" 18 "os" 19 "testing" 20 ) 21 22 // helper function to generate PKCS8 encoded base64 string of a private key 23 func generatePKCS8StringSupress(key *rsa.PrivateKey) string { 24 // Error would only be thrown when the private key type is not supported 25 // We would be safe as long as we are using rsa.PrivateKey 26 tmpBytes, _ := x509.MarshalPKCS8PrivateKey(key) 27 privKeyPKCS8 := base64.URLEncoding.EncodeToString(tmpBytes) 28 return privKeyPKCS8 29 } 30 31 // helper function to generate PKCS1 encoded base64 string of a private key 32 func generatePKCS1String(key *rsa.PrivateKey) string { 33 tmpBytes := x509.MarshalPKCS1PrivateKey(key) 34 privKeyPKCS1 := base64.URLEncoding.EncodeToString(tmpBytes) 35 return privKeyPKCS1 36 } 37 38 // helper function to set up private key for testing 39 func setupPrivateKey() { 40 env := func(key, defaultValue string) string { 41 if value := os.Getenv(key); value != "" { 42 return value 43 } 44 return defaultValue 45 } 46 privKeyPath := env("SNOWFLAKE_TEST_PRIVATE_KEY", "") 47 if privKeyPath == "" { 48 customPrivateKey = false 49 testPrivKey, _ = rsa.GenerateKey(rand.Reader, 2048) 50 } else { 51 // path to the DER file 52 customPrivateKey = true 53 data, _ := os.ReadFile(privKeyPath) 54 block, _ := pem.Decode(data) 55 if block == nil || block.Type != "PRIVATE KEY" { 56 panic(fmt.Sprintf("%v is not a public key in PEM format.", privKeyPath)) 57 } 58 privKey, _ := x509.ParsePKCS8PrivateKey(block.Bytes) 59 testPrivKey = privKey.(*rsa.PrivateKey) 60 } 61 } 62 63 // Helper function to add encoded private key to dsn 64 func appendPrivateKeyString(dsn *string, key *rsa.PrivateKey) string { 65 var b bytes.Buffer 66 b.WriteString(*dsn) 67 b.WriteString(fmt.Sprintf("&authenticator=%v", AuthTypeJwt.String())) 68 b.WriteString(fmt.Sprintf("&privateKey=%s", generatePKCS8StringSupress(key))) 69 return b.String() 70 } 71 72 // Integration test for the JWT authentication function 73 func TestJWTAuthentication(t *testing.T) { 74 // For private key generated on the fly, we want to load the public key to the server first 75 if !customPrivateKey { 76 conn := openConn(t) 77 defer conn.Close() 78 // Load server's public key to database 79 pubKeyByte, err := x509.MarshalPKIXPublicKey(testPrivKey.Public()) 80 if err != nil { 81 t.Fatalf("error marshaling public key: %s", err.Error()) 82 } 83 if _, err = conn.ExecContext(context.Background(), "USE ROLE ACCOUNTADMIN"); err != nil { 84 t.Fatalf("error changin role: %s", err.Error()) 85 } 86 encodedKey := base64.StdEncoding.EncodeToString(pubKeyByte) 87 if _, err = conn.ExecContext(context.Background(), fmt.Sprintf("ALTER USER %v set rsa_public_key='%v'", username, encodedKey)); err != nil { 88 t.Fatalf("error setting server's public key: %s", err.Error()) 89 } 90 } 91 92 // Test that a valid private key can pass 93 jwtDSN := appendPrivateKeyString(&dsn, testPrivKey) 94 db, err := sql.Open("snowflake", jwtDSN) 95 if err != nil { 96 t.Fatalf("error creating a connection object: %s", err.Error()) 97 } 98 if _, err = db.Exec("SELECT 1"); err != nil { 99 t.Fatalf("error executing: %s", err.Error()) 100 } 101 db.Close() 102 103 // Test that an invalid private key cannot pass 104 invalidPrivateKey, err := rsa.GenerateKey(rand.Reader, 2048) 105 if err != nil { 106 t.Error(err) 107 } 108 jwtDSN = appendPrivateKeyString(&dsn, invalidPrivateKey) 109 db, err = sql.Open("snowflake", jwtDSN) 110 if err != nil { 111 t.Error(err) 112 } 113 if _, err = db.Exec("SELECT 1"); err == nil { 114 t.Fatalf("An invalid jwt token can pass") 115 } 116 117 db.Close() 118 } 119 120 func TestJWTTokenTimeout(t *testing.T) { 121 resetHTTPMocks(t) 122 123 dsn := "user:pass@localhost:12345/db/schema?account=jwtAuthTokenTimeout&protocol=http&jwtClientTimeout=1" 124 dsn = appendPrivateKeyString(&dsn, testPrivKey) 125 db, err := sql.Open("snowflake", dsn) 126 if err != nil { 127 t.Fatalf(err.Error()) 128 } 129 defer db.Close() 130 ctx := context.Background() 131 conn, err := db.Conn(ctx) 132 if err != nil { 133 t.Fatalf(err.Error()) 134 } 135 defer conn.Close() 136 137 invocations := getMocksInvocations(t) 138 if invocations != 3 { 139 t.Errorf("Unexpected number of invocations, expected 3, got %v", invocations) 140 } 141 }