vitess.io/vitess@v0.16.2/go/vt/dbconfigs/dbconfigs_test.go (about)

     1  /*
     2  Copyright 2019 The Vitess Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package dbconfigs
    18  
    19  import (
    20  	"fmt"
    21  	"os"
    22  	"syscall"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/stretchr/testify/assert"
    27  	"github.com/stretchr/testify/require"
    28  
    29  	"vitess.io/vitess/go/mysql"
    30  	"vitess.io/vitess/go/yaml2"
    31  )
    32  
    33  func TestInit(t *testing.T) {
    34  	dbConfigs := DBConfigs{
    35  		appParams: mysql.ConnParams{UnixSocket: "socket"},
    36  		dbaParams: mysql.ConnParams{Host: "host"},
    37  		Charset:   "utf8",
    38  	}
    39  	dbConfigs.InitWithSocket("default")
    40  	assert.Equal(t, mysql.ConnParams{UnixSocket: "socket", Charset: "utf8"}, dbConfigs.appParams)
    41  	assert.Equal(t, mysql.ConnParams{Host: "host", Charset: "utf8"}, dbConfigs.dbaParams)
    42  	assert.Equal(t, mysql.ConnParams{UnixSocket: "default", Charset: "utf8"}, dbConfigs.appdebugParams)
    43  
    44  	dbConfigs = DBConfigs{
    45  		Host:                       "a",
    46  		Port:                       1,
    47  		Socket:                     "b",
    48  		Charset:                    "utf8mb4",
    49  		Flags:                      2,
    50  		Flavor:                     "flavor",
    51  		SslCa:                      "d",
    52  		SslCaPath:                  "e",
    53  		SslCert:                    "f",
    54  		SslKey:                     "g",
    55  		ConnectTimeoutMilliseconds: 250,
    56  		App: UserConfig{
    57  			User:     "app",
    58  			Password: "apppass",
    59  		},
    60  		Appdebug: UserConfig{
    61  			UseSSL: true,
    62  		},
    63  		Dba: UserConfig{
    64  			User:     "dba",
    65  			Password: "dbapass",
    66  			UseSSL:   true,
    67  		},
    68  		appParams: mysql.ConnParams{
    69  			UnixSocket: "socket",
    70  		},
    71  		dbaParams: mysql.ConnParams{
    72  			Host: "host",
    73  		},
    74  	}
    75  	dbConfigs.InitWithSocket("default")
    76  
    77  	want := mysql.ConnParams{
    78  		Host:             "a",
    79  		Port:             1,
    80  		Uname:            "app",
    81  		Pass:             "apppass",
    82  		UnixSocket:       "b",
    83  		Charset:          "utf8mb4",
    84  		Flags:            2,
    85  		Flavor:           "flavor",
    86  		ConnectTimeoutMs: 250,
    87  	}
    88  	assert.Equal(t, want, dbConfigs.appParams)
    89  
    90  	want = mysql.ConnParams{
    91  		Host:             "a",
    92  		Port:             1,
    93  		UnixSocket:       "b",
    94  		Charset:          "utf8mb4",
    95  		Flags:            2,
    96  		Flavor:           "flavor",
    97  		SslCa:            "d",
    98  		SslCaPath:        "e",
    99  		SslCert:          "f",
   100  		SslKey:           "g",
   101  		ConnectTimeoutMs: 250,
   102  	}
   103  	assert.Equal(t, want, dbConfigs.appdebugParams)
   104  	want = mysql.ConnParams{
   105  		Host:             "a",
   106  		Port:             1,
   107  		Uname:            "dba",
   108  		Pass:             "dbapass",
   109  		UnixSocket:       "b",
   110  		Charset:          "utf8mb4",
   111  		Flags:            2,
   112  		Flavor:           "flavor",
   113  		SslCa:            "d",
   114  		SslCaPath:        "e",
   115  		SslCert:          "f",
   116  		SslKey:           "g",
   117  		ConnectTimeoutMs: 250,
   118  	}
   119  	assert.Equal(t, want, dbConfigs.dbaParams)
   120  
   121  	// Test that baseConfig does not override Charset and Flag if they're
   122  	// not specified.
   123  	dbConfigs = DBConfigs{
   124  		Host:      "a",
   125  		Port:      1,
   126  		Socket:    "b",
   127  		SslCa:     "d",
   128  		SslCaPath: "e",
   129  		SslCert:   "f",
   130  		SslKey:    "g",
   131  		Charset:   "utf8",
   132  		App: UserConfig{
   133  			User:     "app",
   134  			Password: "apppass",
   135  		},
   136  		Appdebug: UserConfig{
   137  			UseSSL: true,
   138  		},
   139  		Dba: UserConfig{
   140  			User:     "dba",
   141  			Password: "dbapass",
   142  			UseSSL:   true,
   143  		},
   144  		appParams: mysql.ConnParams{
   145  			UnixSocket: "socket",
   146  			Charset:    "utf8mb4",
   147  		},
   148  		dbaParams: mysql.ConnParams{
   149  			Host:  "host",
   150  			Flags: 2,
   151  		},
   152  	}
   153  	dbConfigs.InitWithSocket("default")
   154  	want = mysql.ConnParams{
   155  		Host:       "a",
   156  		Port:       1,
   157  		Uname:      "app",
   158  		Pass:       "apppass",
   159  		UnixSocket: "b",
   160  		Charset:    "utf8mb4",
   161  	}
   162  	assert.Equal(t, want, dbConfigs.appParams)
   163  	want = mysql.ConnParams{
   164  		Host:       "a",
   165  		Port:       1,
   166  		UnixSocket: "b",
   167  		SslCa:      "d",
   168  		SslCaPath:  "e",
   169  		SslCert:    "f",
   170  		SslKey:     "g",
   171  		Charset:    "utf8",
   172  	}
   173  	assert.Equal(t, want, dbConfigs.appdebugParams)
   174  	want = mysql.ConnParams{
   175  		Host:       "a",
   176  		Port:       1,
   177  		Uname:      "dba",
   178  		Pass:       "dbapass",
   179  		UnixSocket: "b",
   180  		Flags:      2,
   181  		SslCa:      "d",
   182  		SslCaPath:  "e",
   183  		SslCert:    "f",
   184  		SslKey:     "g",
   185  		Charset:    "utf8",
   186  	}
   187  	assert.Equal(t, want, dbConfigs.dbaParams)
   188  }
   189  
   190  func TestUseTCP(t *testing.T) {
   191  	dbConfigs := DBConfigs{
   192  		Host:   "a",
   193  		Port:   1,
   194  		Socket: "b",
   195  		App: UserConfig{
   196  			User:   "app",
   197  			UseTCP: true,
   198  		},
   199  		Dba: UserConfig{
   200  			User: "dba",
   201  		},
   202  		Charset: "utf8",
   203  	}
   204  	dbConfigs.InitWithSocket("default")
   205  
   206  	want := mysql.ConnParams{
   207  		Host:    "a",
   208  		Port:    1,
   209  		Uname:   "app",
   210  		Charset: "utf8",
   211  	}
   212  	assert.Equal(t, want, dbConfigs.appParams)
   213  
   214  	want = mysql.ConnParams{
   215  		Host:       "a",
   216  		Port:       1,
   217  		Uname:      "dba",
   218  		UnixSocket: "b",
   219  		Charset:    "utf8",
   220  	}
   221  	assert.Equal(t, want, dbConfigs.dbaParams)
   222  }
   223  
   224  func TestAccessors(t *testing.T) {
   225  	dbc := &DBConfigs{
   226  		appParams:      mysql.ConnParams{},
   227  		appdebugParams: mysql.ConnParams{},
   228  		allprivsParams: mysql.ConnParams{},
   229  		dbaParams:      mysql.ConnParams{},
   230  		filteredParams: mysql.ConnParams{},
   231  		replParams:     mysql.ConnParams{},
   232  		DBName:         "db",
   233  		Charset:        "utf8",
   234  	}
   235  	if got, want := dbc.AppWithDB().connParams.DbName, "db"; got != want {
   236  		t.Errorf("dbc.AppWithDB().DbName: %v, want %v", got, want)
   237  	}
   238  	if got, want := dbc.AllPrivsConnector().connParams.DbName, ""; got != want {
   239  		t.Errorf("dbc.AllPrivsWithDB().DbName: %v, want %v", got, want)
   240  	}
   241  	if got, want := dbc.AllPrivsWithDB().connParams.DbName, "db"; got != want {
   242  		t.Errorf("dbc.AllPrivsWithDB().DbName: %v, want %v", got, want)
   243  	}
   244  	if got, want := dbc.AppDebugWithDB().connParams.DbName, "db"; got != want {
   245  		t.Errorf("dbc.AppDebugWithDB().DbName: %v, want %v", got, want)
   246  	}
   247  	if got, want := dbc.DbaConnector().connParams.DbName, ""; got != want {
   248  		t.Errorf("dbc.Dba().DbName: %v, want %v", got, want)
   249  	}
   250  	if got, want := dbc.DbaWithDB().connParams.DbName, "db"; got != want {
   251  		t.Errorf("dbc.DbaWithDB().DbName: %v, want %v", got, want)
   252  	}
   253  	if got, want := dbc.FilteredWithDB().connParams.DbName, "db"; got != want {
   254  		t.Errorf("dbc.FilteredWithDB().DbName: %v, want %v", got, want)
   255  	}
   256  	if got, want := dbc.ReplConnector().connParams.DbName, ""; got != want {
   257  		t.Errorf("dbc.Repl().DbName: %v, want %v", got, want)
   258  	}
   259  }
   260  
   261  func TestCredentialsFileHUP(t *testing.T) {
   262  	tmpFile, err := os.CreateTemp("", "credentials.json")
   263  	if err != nil {
   264  		t.Fatalf("couldn't create temp file: %v", err)
   265  	}
   266  	defer os.Remove(tmpFile.Name())
   267  	dbCredentialsFile = tmpFile.Name()
   268  	dbCredentialsServer = "file"
   269  	oldStr := "str1"
   270  	jsonConfig := fmt.Sprintf("{\"%s\": [\"%s\"]}", oldStr, oldStr)
   271  	if err := os.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil {
   272  		t.Fatalf("couldn't write temp file: %v", err)
   273  	}
   274  	cs := GetCredentialsServer()
   275  	_, pass, _ := cs.GetUserAndPassword(oldStr)
   276  	if pass != oldStr {
   277  		t.Fatalf("%s's Password should still be '%s'", oldStr, oldStr)
   278  	}
   279  	hupTest(t, tmpFile, oldStr, "str2")
   280  	hupTest(t, tmpFile, "str2", "str3") // still handling the signal
   281  }
   282  
   283  func hupTest(t *testing.T, tmpFile *os.File, oldStr, newStr string) {
   284  	cs := GetCredentialsServer()
   285  	jsonConfig := fmt.Sprintf("{\"%s\": [\"%s\"]}", newStr, newStr)
   286  	if err := os.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil {
   287  		t.Fatalf("couldn't overwrite temp file: %v", err)
   288  	}
   289  	_, pass, _ := cs.GetUserAndPassword(oldStr)
   290  	if pass != oldStr {
   291  		t.Fatalf("%s's Password should still be '%s'", oldStr, oldStr)
   292  	}
   293  	_ = syscall.Kill(syscall.Getpid(), syscall.SIGHUP)
   294  	time.Sleep(100 * time.Millisecond) // wait for signal handler
   295  	_, _, err := cs.GetUserAndPassword(oldStr)
   296  	if err != ErrUnknownUser {
   297  		t.Fatalf("Should not have old %s after config reload", oldStr)
   298  	}
   299  	_, pass, _ = cs.GetUserAndPassword(newStr)
   300  	if pass != newStr {
   301  		t.Fatalf("%s's Password should be '%s'", newStr, newStr)
   302  	}
   303  }
   304  
   305  func TestYaml(t *testing.T) {
   306  	db := DBConfigs{
   307  		Socket: "a",
   308  		Port:   1,
   309  		Flags:  20,
   310  		App: UserConfig{
   311  			User:   "vt_app",
   312  			UseSSL: true,
   313  		},
   314  		Dba: UserConfig{
   315  			User: "vt_dba",
   316  		},
   317  	}
   318  	gotBytes, err := yaml2.Marshal(&db)
   319  	require.NoError(t, err)
   320  	wantBytes := `allprivs:
   321    password: '****'
   322  app:
   323    password: '****'
   324    useSsl: true
   325    user: vt_app
   326  appdebug:
   327    password: '****'
   328  dba:
   329    password: '****'
   330    user: vt_dba
   331  filtered:
   332    password: '****'
   333  flags: 20
   334  port: 1
   335  repl:
   336    password: '****'
   337  socket: a
   338  `
   339  	assert.Equal(t, wantBytes, string(gotBytes))
   340  
   341  	inBytes := []byte(`socket: a
   342  port: 1
   343  flags: 20
   344  app:
   345    user: vt_app
   346    useSsl: true
   347    useTCP: false
   348  dba:
   349    user: vt_dba
   350  `)
   351  	gotdb := DBConfigs{
   352  		Port:  1,
   353  		Flags: 20,
   354  		App: UserConfig{
   355  			UseTCP: true,
   356  		},
   357  		Dba: UserConfig{
   358  			User: "aaa",
   359  		},
   360  	}
   361  	err = yaml2.Unmarshal(inBytes, &gotdb)
   362  	require.NoError(t, err)
   363  	assert.Equal(t, &db, &gotdb)
   364  }