github.com/hashicorp/vault/sdk@v0.13.0/database/helper/connutil/sql_test.go (about) 1 // Copyright (c) HashiCorp, Inc. 2 // SPDX-License-Identifier: MPL-2.0 3 4 package connutil 5 6 import ( 7 "context" 8 "net/url" 9 "strings" 10 "testing" 11 12 "github.com/stretchr/testify/assert" 13 ) 14 15 func TestSQLPasswordChars(t *testing.T) { 16 testCases := []struct { 17 Username string 18 Password string 19 }{ 20 {"postgres", "password{0}"}, 21 {"postgres", "pass:word"}, 22 {"postgres", "pass/word"}, 23 {"postgres", "p@ssword"}, 24 {"postgres", "pass\"word\""}, 25 } 26 for _, tc := range testCases { 27 t.Logf("username %q password %q", tc.Username, tc.Password) 28 29 sql := &SQLConnectionProducer{} 30 ctx := context.Background() 31 conf := map[string]interface{}{ 32 "connection_url": "postgres://{{username}}:{{password}}@localhost:5432/mydb", 33 "username": tc.Username, 34 "password": tc.Password, 35 "disable_escaping": false, 36 } 37 _, err := sql.Init(ctx, conf, false) 38 if err != nil { 39 t.Errorf("Init error on %q %q: %+v", tc.Username, tc.Password, err) 40 } else { 41 // This jumps down a few layers... 42 // Connection() uses sql.Open uses lib/pq uses net/url.Parse 43 u, err := url.Parse(sql.ConnectionURL) 44 if err != nil { 45 t.Errorf("URL parse error on %q %q: %+v", tc.Username, tc.Password, err) 46 } else { 47 username := u.User.Username() 48 password, pPresent := u.User.Password() 49 if username != tc.Username { 50 t.Errorf("Parsed username %q != original username %q", username, tc.Username) 51 } 52 if !pPresent { 53 t.Errorf("Password %q not present", tc.Password) 54 } else if password != tc.Password { 55 t.Errorf("Parsed password %q != original password %q", password, tc.Password) 56 } 57 } 58 } 59 } 60 } 61 62 func TestSQLDisableEscaping(t *testing.T) { 63 testCases := []struct { 64 Username string 65 Password string 66 DisableEscaping bool 67 }{ 68 {"mssql{0}", "password{0}", true}, 69 {"mssql{0}", "password{0}", false}, 70 {"ms\"sql\"", "pass\"word\"", true}, 71 {"ms\"sql\"", "pass\"word\"", false}, 72 {"ms'sq;l", "pass'wor;d", true}, 73 {"ms'sq;l", "pass'wor;d", false}, 74 } 75 for _, tc := range testCases { 76 t.Logf("username %q password %q disable_escaling %t", tc.Username, tc.Password, tc.DisableEscaping) 77 78 sql := &SQLConnectionProducer{} 79 ctx := context.Background() 80 conf := map[string]interface{}{ 81 "connection_url": "server=localhost;port=1433;user id={{username}};password={{password}};database=mydb;", 82 "username": tc.Username, 83 "password": tc.Password, 84 "disable_escaping": tc.DisableEscaping, 85 } 86 _, err := sql.Init(ctx, conf, false) 87 if err != nil { 88 t.Errorf("Init error on %q %q: %+v", tc.Username, tc.Password, err) 89 } else { 90 if tc.DisableEscaping { 91 if !strings.Contains(sql.ConnectionURL, tc.Username) || !strings.Contains(sql.ConnectionURL, tc.Password) { 92 t.Errorf("Raw username and/or password missing from ConnectionURL") 93 } 94 } else { 95 if strings.Contains(sql.ConnectionURL, tc.Username) || strings.Contains(sql.ConnectionURL, tc.Password) { 96 t.Errorf("Raw username and/or password was present in ConnectionURL") 97 } 98 } 99 } 100 } 101 } 102 103 func TestSQLDisallowTemplates(t *testing.T) { 104 testCases := []struct { 105 Username string 106 Password string 107 }{ 108 {"{{username}}", "pass"}, 109 {"{{password}}", "pass"}, 110 {"user", "{{username}}"}, 111 {"user", "{{password}}"}, 112 {"{{username}}", "{{password}}"}, 113 {"abc{username}xyz", "123{password}789"}, 114 {"abc{{username}}xyz", "123{{password}}789"}, 115 {"abc{{{username}}}xyz", "123{{{password}}}789"}, 116 } 117 for _, disableEscaping := range []bool{true, false} { 118 for _, tc := range testCases { 119 t.Logf("username %q password %q disable_escaping %t", tc.Username, tc.Password, disableEscaping) 120 121 sql := &SQLConnectionProducer{} 122 ctx := context.Background() 123 conf := map[string]interface{}{ 124 "connection_url": "server=localhost;port=1433;user id={{username}};password={{password}};database=mydb;", 125 "username": tc.Username, 126 "password": tc.Password, 127 "disable_escaping": disableEscaping, 128 } 129 _, err := sql.Init(ctx, conf, false) 130 if disableEscaping { 131 if err != nil { 132 if !assert.EqualError(t, err, "username and/or password cannot contain the template variables") { 133 t.Errorf("Init error on %q %q: %+v", tc.Username, tc.Password, err) 134 } 135 } else { 136 assert.Equal(t, sql.ConnectionURL, "server=localhost;port=1433;user id=abc{username}xyz;password=123{password}789;database=mydb;") 137 } 138 } 139 } 140 } 141 }