github.com/hashicorp/vault/sdk@v0.13.0/database/helper/connutil/sql_test.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package connutil
     5  
     6  import (
     7  	"context"
     8  	"net/url"
     9  	"strings"
    10  	"testing"
    11  
    12  	"github.com/stretchr/testify/assert"
    13  )
    14  
    15  func TestSQLPasswordChars(t *testing.T) {
    16  	testCases := []struct {
    17  		Username string
    18  		Password string
    19  	}{
    20  		{"postgres", "password{0}"},
    21  		{"postgres", "pass:word"},
    22  		{"postgres", "pass/word"},
    23  		{"postgres", "p@ssword"},
    24  		{"postgres", "pass\"word\""},
    25  	}
    26  	for _, tc := range testCases {
    27  		t.Logf("username %q password %q", tc.Username, tc.Password)
    28  
    29  		sql := &SQLConnectionProducer{}
    30  		ctx := context.Background()
    31  		conf := map[string]interface{}{
    32  			"connection_url":   "postgres://{{username}}:{{password}}@localhost:5432/mydb",
    33  			"username":         tc.Username,
    34  			"password":         tc.Password,
    35  			"disable_escaping": false,
    36  		}
    37  		_, err := sql.Init(ctx, conf, false)
    38  		if err != nil {
    39  			t.Errorf("Init error on %q %q: %+v", tc.Username, tc.Password, err)
    40  		} else {
    41  			// This jumps down a few layers...
    42  			// Connection() uses sql.Open uses lib/pq uses net/url.Parse
    43  			u, err := url.Parse(sql.ConnectionURL)
    44  			if err != nil {
    45  				t.Errorf("URL parse error on %q %q: %+v", tc.Username, tc.Password, err)
    46  			} else {
    47  				username := u.User.Username()
    48  				password, pPresent := u.User.Password()
    49  				if username != tc.Username {
    50  					t.Errorf("Parsed username %q != original username %q", username, tc.Username)
    51  				}
    52  				if !pPresent {
    53  					t.Errorf("Password %q not present", tc.Password)
    54  				} else if password != tc.Password {
    55  					t.Errorf("Parsed password %q != original password %q", password, tc.Password)
    56  				}
    57  			}
    58  		}
    59  	}
    60  }
    61  
    62  func TestSQLDisableEscaping(t *testing.T) {
    63  	testCases := []struct {
    64  		Username        string
    65  		Password        string
    66  		DisableEscaping bool
    67  	}{
    68  		{"mssql{0}", "password{0}", true},
    69  		{"mssql{0}", "password{0}", false},
    70  		{"ms\"sql\"", "pass\"word\"", true},
    71  		{"ms\"sql\"", "pass\"word\"", false},
    72  		{"ms'sq;l", "pass'wor;d", true},
    73  		{"ms'sq;l", "pass'wor;d", false},
    74  	}
    75  	for _, tc := range testCases {
    76  		t.Logf("username %q password %q disable_escaling %t", tc.Username, tc.Password, tc.DisableEscaping)
    77  
    78  		sql := &SQLConnectionProducer{}
    79  		ctx := context.Background()
    80  		conf := map[string]interface{}{
    81  			"connection_url":   "server=localhost;port=1433;user id={{username}};password={{password}};database=mydb;",
    82  			"username":         tc.Username,
    83  			"password":         tc.Password,
    84  			"disable_escaping": tc.DisableEscaping,
    85  		}
    86  		_, err := sql.Init(ctx, conf, false)
    87  		if err != nil {
    88  			t.Errorf("Init error on %q %q: %+v", tc.Username, tc.Password, err)
    89  		} else {
    90  			if tc.DisableEscaping {
    91  				if !strings.Contains(sql.ConnectionURL, tc.Username) || !strings.Contains(sql.ConnectionURL, tc.Password) {
    92  					t.Errorf("Raw username and/or password missing from ConnectionURL")
    93  				}
    94  			} else {
    95  				if strings.Contains(sql.ConnectionURL, tc.Username) || strings.Contains(sql.ConnectionURL, tc.Password) {
    96  					t.Errorf("Raw username and/or password was present in ConnectionURL")
    97  				}
    98  			}
    99  		}
   100  	}
   101  }
   102  
   103  func TestSQLDisallowTemplates(t *testing.T) {
   104  	testCases := []struct {
   105  		Username string
   106  		Password string
   107  	}{
   108  		{"{{username}}", "pass"},
   109  		{"{{password}}", "pass"},
   110  		{"user", "{{username}}"},
   111  		{"user", "{{password}}"},
   112  		{"{{username}}", "{{password}}"},
   113  		{"abc{username}xyz", "123{password}789"},
   114  		{"abc{{username}}xyz", "123{{password}}789"},
   115  		{"abc{{{username}}}xyz", "123{{{password}}}789"},
   116  	}
   117  	for _, disableEscaping := range []bool{true, false} {
   118  		for _, tc := range testCases {
   119  			t.Logf("username %q password %q disable_escaping %t", tc.Username, tc.Password, disableEscaping)
   120  
   121  			sql := &SQLConnectionProducer{}
   122  			ctx := context.Background()
   123  			conf := map[string]interface{}{
   124  				"connection_url":   "server=localhost;port=1433;user id={{username}};password={{password}};database=mydb;",
   125  				"username":         tc.Username,
   126  				"password":         tc.Password,
   127  				"disable_escaping": disableEscaping,
   128  			}
   129  			_, err := sql.Init(ctx, conf, false)
   130  			if disableEscaping {
   131  				if err != nil {
   132  					if !assert.EqualError(t, err, "username and/or password cannot contain the template variables") {
   133  						t.Errorf("Init error on %q %q: %+v", tc.Username, tc.Password, err)
   134  					}
   135  				} else {
   136  					assert.Equal(t, sql.ConnectionURL, "server=localhost;port=1433;user id=abc{username}xyz;password=123{password}789;database=mydb;")
   137  				}
   138  			}
   139  		}
   140  	}
   141  }