vitess.io/vitess@v0.16.2/go/vt/dbconfigs/dbconfigs_test.go (about) 1 /* 2 Copyright 2019 The Vitess Authors. 3 4 Licensed under the Apache License, Version 2.0 (the "License"); 5 you may not use this file except in compliance with the License. 6 You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10 Unless required by applicable law or agreed to in writing, software 11 distributed under the License is distributed on an "AS IS" BASIS, 12 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 See the License for the specific language governing permissions and 14 limitations under the License. 15 */ 16 17 package dbconfigs 18 19 import ( 20 "fmt" 21 "os" 22 "syscall" 23 "testing" 24 "time" 25 26 "github.com/stretchr/testify/assert" 27 "github.com/stretchr/testify/require" 28 29 "vitess.io/vitess/go/mysql" 30 "vitess.io/vitess/go/yaml2" 31 ) 32 33 func TestInit(t *testing.T) { 34 dbConfigs := DBConfigs{ 35 appParams: mysql.ConnParams{UnixSocket: "socket"}, 36 dbaParams: mysql.ConnParams{Host: "host"}, 37 Charset: "utf8", 38 } 39 dbConfigs.InitWithSocket("default") 40 assert.Equal(t, mysql.ConnParams{UnixSocket: "socket", Charset: "utf8"}, dbConfigs.appParams) 41 assert.Equal(t, mysql.ConnParams{Host: "host", Charset: "utf8"}, dbConfigs.dbaParams) 42 assert.Equal(t, mysql.ConnParams{UnixSocket: "default", Charset: "utf8"}, dbConfigs.appdebugParams) 43 44 dbConfigs = DBConfigs{ 45 Host: "a", 46 Port: 1, 47 Socket: "b", 48 Charset: "utf8mb4", 49 Flags: 2, 50 Flavor: "flavor", 51 SslCa: "d", 52 SslCaPath: "e", 53 SslCert: "f", 54 SslKey: "g", 55 ConnectTimeoutMilliseconds: 250, 56 App: UserConfig{ 57 User: "app", 58 Password: "apppass", 59 }, 60 Appdebug: UserConfig{ 61 UseSSL: true, 62 }, 63 Dba: UserConfig{ 64 User: "dba", 65 Password: "dbapass", 66 UseSSL: true, 67 }, 68 appParams: mysql.ConnParams{ 69 UnixSocket: "socket", 70 }, 71 dbaParams: mysql.ConnParams{ 72 Host: "host", 73 }, 74 } 75 dbConfigs.InitWithSocket("default") 76 77 want := mysql.ConnParams{ 78 Host: "a", 79 Port: 1, 80 Uname: "app", 81 Pass: "apppass", 82 UnixSocket: "b", 83 Charset: "utf8mb4", 84 Flags: 2, 85 Flavor: "flavor", 86 ConnectTimeoutMs: 250, 87 } 88 assert.Equal(t, want, dbConfigs.appParams) 89 90 want = mysql.ConnParams{ 91 Host: "a", 92 Port: 1, 93 UnixSocket: "b", 94 Charset: "utf8mb4", 95 Flags: 2, 96 Flavor: "flavor", 97 SslCa: "d", 98 SslCaPath: "e", 99 SslCert: "f", 100 SslKey: "g", 101 ConnectTimeoutMs: 250, 102 } 103 assert.Equal(t, want, dbConfigs.appdebugParams) 104 want = mysql.ConnParams{ 105 Host: "a", 106 Port: 1, 107 Uname: "dba", 108 Pass: "dbapass", 109 UnixSocket: "b", 110 Charset: "utf8mb4", 111 Flags: 2, 112 Flavor: "flavor", 113 SslCa: "d", 114 SslCaPath: "e", 115 SslCert: "f", 116 SslKey: "g", 117 ConnectTimeoutMs: 250, 118 } 119 assert.Equal(t, want, dbConfigs.dbaParams) 120 121 // Test that baseConfig does not override Charset and Flag if they're 122 // not specified. 123 dbConfigs = DBConfigs{ 124 Host: "a", 125 Port: 1, 126 Socket: "b", 127 SslCa: "d", 128 SslCaPath: "e", 129 SslCert: "f", 130 SslKey: "g", 131 Charset: "utf8", 132 App: UserConfig{ 133 User: "app", 134 Password: "apppass", 135 }, 136 Appdebug: UserConfig{ 137 UseSSL: true, 138 }, 139 Dba: UserConfig{ 140 User: "dba", 141 Password: "dbapass", 142 UseSSL: true, 143 }, 144 appParams: mysql.ConnParams{ 145 UnixSocket: "socket", 146 Charset: "utf8mb4", 147 }, 148 dbaParams: mysql.ConnParams{ 149 Host: "host", 150 Flags: 2, 151 }, 152 } 153 dbConfigs.InitWithSocket("default") 154 want = mysql.ConnParams{ 155 Host: "a", 156 Port: 1, 157 Uname: "app", 158 Pass: "apppass", 159 UnixSocket: "b", 160 Charset: "utf8mb4", 161 } 162 assert.Equal(t, want, dbConfigs.appParams) 163 want = mysql.ConnParams{ 164 Host: "a", 165 Port: 1, 166 UnixSocket: "b", 167 SslCa: "d", 168 SslCaPath: "e", 169 SslCert: "f", 170 SslKey: "g", 171 Charset: "utf8", 172 } 173 assert.Equal(t, want, dbConfigs.appdebugParams) 174 want = mysql.ConnParams{ 175 Host: "a", 176 Port: 1, 177 Uname: "dba", 178 Pass: "dbapass", 179 UnixSocket: "b", 180 Flags: 2, 181 SslCa: "d", 182 SslCaPath: "e", 183 SslCert: "f", 184 SslKey: "g", 185 Charset: "utf8", 186 } 187 assert.Equal(t, want, dbConfigs.dbaParams) 188 } 189 190 func TestUseTCP(t *testing.T) { 191 dbConfigs := DBConfigs{ 192 Host: "a", 193 Port: 1, 194 Socket: "b", 195 App: UserConfig{ 196 User: "app", 197 UseTCP: true, 198 }, 199 Dba: UserConfig{ 200 User: "dba", 201 }, 202 Charset: "utf8", 203 } 204 dbConfigs.InitWithSocket("default") 205 206 want := mysql.ConnParams{ 207 Host: "a", 208 Port: 1, 209 Uname: "app", 210 Charset: "utf8", 211 } 212 assert.Equal(t, want, dbConfigs.appParams) 213 214 want = mysql.ConnParams{ 215 Host: "a", 216 Port: 1, 217 Uname: "dba", 218 UnixSocket: "b", 219 Charset: "utf8", 220 } 221 assert.Equal(t, want, dbConfigs.dbaParams) 222 } 223 224 func TestAccessors(t *testing.T) { 225 dbc := &DBConfigs{ 226 appParams: mysql.ConnParams{}, 227 appdebugParams: mysql.ConnParams{}, 228 allprivsParams: mysql.ConnParams{}, 229 dbaParams: mysql.ConnParams{}, 230 filteredParams: mysql.ConnParams{}, 231 replParams: mysql.ConnParams{}, 232 DBName: "db", 233 Charset: "utf8", 234 } 235 if got, want := dbc.AppWithDB().connParams.DbName, "db"; got != want { 236 t.Errorf("dbc.AppWithDB().DbName: %v, want %v", got, want) 237 } 238 if got, want := dbc.AllPrivsConnector().connParams.DbName, ""; got != want { 239 t.Errorf("dbc.AllPrivsWithDB().DbName: %v, want %v", got, want) 240 } 241 if got, want := dbc.AllPrivsWithDB().connParams.DbName, "db"; got != want { 242 t.Errorf("dbc.AllPrivsWithDB().DbName: %v, want %v", got, want) 243 } 244 if got, want := dbc.AppDebugWithDB().connParams.DbName, "db"; got != want { 245 t.Errorf("dbc.AppDebugWithDB().DbName: %v, want %v", got, want) 246 } 247 if got, want := dbc.DbaConnector().connParams.DbName, ""; got != want { 248 t.Errorf("dbc.Dba().DbName: %v, want %v", got, want) 249 } 250 if got, want := dbc.DbaWithDB().connParams.DbName, "db"; got != want { 251 t.Errorf("dbc.DbaWithDB().DbName: %v, want %v", got, want) 252 } 253 if got, want := dbc.FilteredWithDB().connParams.DbName, "db"; got != want { 254 t.Errorf("dbc.FilteredWithDB().DbName: %v, want %v", got, want) 255 } 256 if got, want := dbc.ReplConnector().connParams.DbName, ""; got != want { 257 t.Errorf("dbc.Repl().DbName: %v, want %v", got, want) 258 } 259 } 260 261 func TestCredentialsFileHUP(t *testing.T) { 262 tmpFile, err := os.CreateTemp("", "credentials.json") 263 if err != nil { 264 t.Fatalf("couldn't create temp file: %v", err) 265 } 266 defer os.Remove(tmpFile.Name()) 267 dbCredentialsFile = tmpFile.Name() 268 dbCredentialsServer = "file" 269 oldStr := "str1" 270 jsonConfig := fmt.Sprintf("{\"%s\": [\"%s\"]}", oldStr, oldStr) 271 if err := os.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil { 272 t.Fatalf("couldn't write temp file: %v", err) 273 } 274 cs := GetCredentialsServer() 275 _, pass, _ := cs.GetUserAndPassword(oldStr) 276 if pass != oldStr { 277 t.Fatalf("%s's Password should still be '%s'", oldStr, oldStr) 278 } 279 hupTest(t, tmpFile, oldStr, "str2") 280 hupTest(t, tmpFile, "str2", "str3") // still handling the signal 281 } 282 283 func hupTest(t *testing.T, tmpFile *os.File, oldStr, newStr string) { 284 cs := GetCredentialsServer() 285 jsonConfig := fmt.Sprintf("{\"%s\": [\"%s\"]}", newStr, newStr) 286 if err := os.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil { 287 t.Fatalf("couldn't overwrite temp file: %v", err) 288 } 289 _, pass, _ := cs.GetUserAndPassword(oldStr) 290 if pass != oldStr { 291 t.Fatalf("%s's Password should still be '%s'", oldStr, oldStr) 292 } 293 _ = syscall.Kill(syscall.Getpid(), syscall.SIGHUP) 294 time.Sleep(100 * time.Millisecond) // wait for signal handler 295 _, _, err := cs.GetUserAndPassword(oldStr) 296 if err != ErrUnknownUser { 297 t.Fatalf("Should not have old %s after config reload", oldStr) 298 } 299 _, pass, _ = cs.GetUserAndPassword(newStr) 300 if pass != newStr { 301 t.Fatalf("%s's Password should be '%s'", newStr, newStr) 302 } 303 } 304 305 func TestYaml(t *testing.T) { 306 db := DBConfigs{ 307 Socket: "a", 308 Port: 1, 309 Flags: 20, 310 App: UserConfig{ 311 User: "vt_app", 312 UseSSL: true, 313 }, 314 Dba: UserConfig{ 315 User: "vt_dba", 316 }, 317 } 318 gotBytes, err := yaml2.Marshal(&db) 319 require.NoError(t, err) 320 wantBytes := `allprivs: 321 password: '****' 322 app: 323 password: '****' 324 useSsl: true 325 user: vt_app 326 appdebug: 327 password: '****' 328 dba: 329 password: '****' 330 user: vt_dba 331 filtered: 332 password: '****' 333 flags: 20 334 port: 1 335 repl: 336 password: '****' 337 socket: a 338 ` 339 assert.Equal(t, wantBytes, string(gotBytes)) 340 341 inBytes := []byte(`socket: a 342 port: 1 343 flags: 20 344 app: 345 user: vt_app 346 useSsl: true 347 useTCP: false 348 dba: 349 user: vt_dba 350 `) 351 gotdb := DBConfigs{ 352 Port: 1, 353 Flags: 20, 354 App: UserConfig{ 355 UseTCP: true, 356 }, 357 Dba: UserConfig{ 358 User: "aaa", 359 }, 360 } 361 err = yaml2.Unmarshal(inBytes, &gotdb) 362 require.NoError(t, err) 363 assert.Equal(t, &db, &gotdb) 364 }