vitess.io/vitess@v0.16.2/go/mysql/auth_server_static_flaky_test.go (about) 1 /* 2 copyright 2019 The Vitess Authors. 3 4 licensed under the apache license, version 2.0 (the "license"); 5 you may not use this file except in compliance with the license. 6 you may obtain a copy of the license at 7 8 http://www.apache.org/licenses/license-2.0 9 10 unless required by applicable law or agreed to in writing, software 11 distributed under the license is distributed on an "as is" basis, 12 without warranties or conditions of any kind, either express or implied. 13 see the license for the specific language governing permissions and 14 limitations under the license. 15 */ 16 17 package mysql 18 19 import ( 20 "fmt" 21 "net" 22 "os" 23 "syscall" 24 "testing" 25 "time" 26 27 "github.com/stretchr/testify/require" 28 ) 29 30 // getEntries is a test-only method for AuthServerStatic. 31 func (a *AuthServerStatic) getEntries() map[string][]*AuthServerStaticEntry { 32 a.mu.Lock() 33 defer a.mu.Unlock() 34 return a.entries 35 } 36 37 func TestJsonConfigParser(t *testing.T) { 38 // works with legacy format 39 config := make(map[string][]*AuthServerStaticEntry) 40 jsonConfig := "{\"mysql_user\":{\"Password\":\"123\", \"UserData\":\"dummy\"}, \"mysql_user_2\": {\"Password\": \"123\", \"UserData\": \"mysql_user_2\"}}" 41 err := ParseConfig([]byte(jsonConfig), &config) 42 require.NoError(t, err, "should not get an error, but got: %v", err) 43 require.Equal(t, 1, len(config["mysql_user"]), "mysql_user config size should be equal to 1") 44 require.Equal(t, 1, len(config["mysql_user_2"]), "mysql_user config size should be equal to 1") 45 46 // works with new format 47 jsonConfig = `{"mysql_user":[ 48 {"Password":"123", "UserData":"dummy", "SourceHost": "localhost"}, 49 {"Password": "123", "UserData": "mysql_user_all"}, 50 {"Password": "456", "UserData": "mysql_user_with_groups", "Groups": ["user_group"]} 51 ]}` 52 err = ParseConfig([]byte(jsonConfig), &config) 53 require.NoError(t, err, "should not get an error, but got: %v", err) 54 require.Equal(t, 3, len(config["mysql_user"]), "mysql_user config size should be equal to 3") 55 require.Equal(t, "localhost", config["mysql_user"][0].SourceHost, "SourceHost should be equal to localhost") 56 57 if len(config["mysql_user"][2].Groups) != 1 || config["mysql_user"][2].Groups[0] != "user_group" { 58 t.Fatalf("Groups should be equal to [\"user_group\"]") 59 } 60 61 jsonConfig = `{ 62 "mysql_user": [{"Password": "123", "UserData": "mysql_user_all", "InvalidKey": "oops"}] 63 }` 64 err = ParseConfig([]byte(jsonConfig), &config) 65 require.Error(t, err, "Invalid config should have errored, but didn't") 66 67 } 68 69 func TestValidateHashGetter(t *testing.T) { 70 jsonConfig := `{"mysql_user": [{"Password": "password", "UserData": "user.name", "Groups": ["user_group"]}]}` 71 72 auth := NewAuthServerStatic("", jsonConfig, 0) 73 defer auth.close() 74 ip := net.ParseIP("127.0.0.1") 75 addr := &net.IPAddr{IP: ip, Zone: ""} 76 77 salt, err := newSalt() 78 require.NoError(t, err, "error generating salt: %v", err) 79 80 scrambled := ScrambleMysqlNativePassword(salt, []byte("password")) 81 getter, err := auth.UserEntryWithHash(nil, salt, "mysql_user", scrambled, addr) 82 require.NoError(t, err, "error validating password: %v", err) 83 84 callerID := getter.Get() 85 require.Equal(t, "user.name", callerID.Username, "getter username incorrect, expected \"user.name\", got %v", callerID.Username) 86 87 if len(callerID.Groups) != 1 || callerID.Groups[0] != "user_group" { 88 t.Fatalf("getter groups incorrect, expected [\"user_group\"], got %v", callerID.Groups) 89 } 90 } 91 92 func TestHostMatcher(t *testing.T) { 93 ip := net.ParseIP("192.168.0.1") 94 addr := &net.TCPAddr{IP: ip, Port: 9999} 95 match := MatchSourceHost(net.Addr(addr), "") 96 require.True(t, match, "Should match any address when target is empty") 97 98 match = MatchSourceHost(net.Addr(addr), "localhost") 99 require.False(t, match, "Should not match address when target is localhost") 100 101 socket := &net.UnixAddr{Name: "unixSocket", Net: "1"} 102 match = MatchSourceHost(net.Addr(socket), "localhost") 103 require.True(t, match, "Should match socket when target is localhost") 104 105 } 106 107 func TestStaticConfigHUP(t *testing.T) { 108 tmpFile, err := os.CreateTemp("", "mysql_auth_server_static_file.json") 109 require.NoError(t, err, "couldn't create temp file: %v", err) 110 111 defer os.Remove(tmpFile.Name()) 112 113 oldStr := "str5" 114 jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", oldStr, oldStr) 115 if err := os.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil { 116 t.Fatalf("couldn't write temp file: %v", err) 117 } 118 119 aStatic := NewAuthServerStatic(tmpFile.Name(), "", 0) 120 defer aStatic.close() 121 require.Equal(t, oldStr, aStatic.getEntries()[oldStr][0].Password, "%s's Password should still be '%s'", oldStr, oldStr) 122 123 hupTest(t, aStatic, tmpFile, oldStr, "str2") 124 hupTest(t, aStatic, tmpFile, "str2", "str3") // still handling the signal 125 126 mu.Lock() 127 defer mu.Unlock() 128 // delete registered Auth server 129 for auth := range authServers { 130 delete(authServers, auth) 131 } 132 } 133 134 func TestStaticConfigHUPWithRotation(t *testing.T) { 135 tmpFile, err := os.CreateTemp("", "mysql_auth_server_static_file.json") 136 require.NoError(t, err, "couldn't create temp file: %v", err) 137 138 defer os.Remove(tmpFile.Name()) 139 140 oldStr := "str1" 141 jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", oldStr, oldStr) 142 if err := os.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil { 143 t.Fatalf("couldn't write temp file: %v", err) 144 } 145 146 aStatic := NewAuthServerStatic(tmpFile.Name(), "", 10*time.Millisecond) 147 defer aStatic.close() 148 require.Equal(t, oldStr, aStatic.getEntries()[oldStr][0].Password, "%s's Password should still be '%s'", oldStr, oldStr) 149 150 hupTestWithRotation(t, aStatic, tmpFile, oldStr, "str4") 151 hupTestWithRotation(t, aStatic, tmpFile, "str4", "str5") 152 } 153 154 func hupTest(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.File, oldStr, newStr string) { 155 jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", newStr, newStr) 156 if err := os.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil { 157 t.Fatalf("couldn't overwrite temp file: %v", err) 158 } 159 require.Equal(t, oldStr, aStatic.getEntries()[oldStr][0].Password, "%s's Password should still be '%s'", oldStr, oldStr) 160 161 syscall.Kill(syscall.Getpid(), syscall.SIGHUP) 162 time.Sleep(100 * time.Millisecond) 163 // wait for signal handler 164 require.Nil(t, aStatic.getEntries()[oldStr], "Should not have old %s after config reload", oldStr) 165 require.Equal(t, newStr, aStatic.getEntries()[newStr][0].Password, "%s's Password should be '%s'", newStr, newStr) 166 167 } 168 169 func hupTestWithRotation(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.File, oldStr, newStr string) { 170 jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", newStr, newStr) 171 if err := os.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil { 172 t.Fatalf("couldn't overwrite temp file: %v", err) 173 } 174 175 time.Sleep(20 * time.Millisecond) 176 // wait for signal handler 177 require.Nil(t, aStatic.getEntries()[oldStr], "Should not have old %s after config reload", oldStr) 178 require.Equal(t, newStr, aStatic.getEntries()[newStr][0].Password, "%s's Password should be '%s'", newStr, newStr) 179 180 } 181 182 func TestStaticPasswords(t *testing.T) { 183 jsonConfig := ` 184 { 185 "user01": [{ "Password": "user01" }], 186 "user02": [{ 187 "MysqlNativePassword": "*B3AD996B12F211BEA47A7C666CC136FB26DC96AF" 188 }], 189 "user03": [{ 190 "MysqlNativePassword": "*211E0153B172BAED4352D5E4628BD76731AF83E7", 191 "Password": "invalid" 192 }], 193 "user04": [ 194 { "MysqlNativePassword": "*668425423DB5193AF921380129F465A6425216D0" }, 195 { "Password": "password2" } 196 ] 197 }` 198 199 tests := []struct { 200 user string 201 password string 202 success bool 203 }{ 204 {"user01", "user01", true}, 205 {"user01", "password", false}, 206 {"user01", "", false}, 207 {"user02", "user02", true}, 208 {"user02", "password", false}, 209 {"user02", "", false}, 210 {"user03", "user03", true}, 211 {"user03", "password", false}, 212 {"user03", "invalid", false}, 213 {"user03", "", false}, 214 {"user04", "password1", true}, 215 {"user04", "password2", true}, 216 {"user04", "", false}, 217 {"userXX", "", false}, 218 {"userXX", "", false}, 219 {"", "", false}, 220 {"", "password", false}, 221 } 222 223 auth := NewAuthServerStatic("", jsonConfig, 0) 224 defer auth.close() 225 ip := net.ParseIP("127.0.0.1") 226 addr := &net.IPAddr{IP: ip, Zone: ""} 227 228 for _, c := range tests { 229 t.Run(fmt.Sprintf("%s-%s", c.user, c.password), func(t *testing.T) { 230 salt, err := newSalt() 231 require.NoError(t, err, "error generating salt: %v", err) 232 233 scrambled := ScrambleMysqlNativePassword(salt, []byte(c.password)) 234 _, err = auth.UserEntryWithHash(nil, salt, c.user, scrambled, addr) 235 236 if c.success { 237 require.NoError(t, err, "authentication should have succeeded: %v", err) 238 239 } else { 240 require.Error(t, err, "authentication should have failed") 241 242 } 243 }) 244 } 245 }