vitess.io/vitess@v0.16.2/go/mysql/auth_server_static_flaky_test.go (about)

/*
Copyright 2019 The Vitess Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
    16  
    17  package mysql
    18  
    19  import (
    20  	"fmt"
    21  	"net"
    22  	"os"
    23  	"syscall"
    24  	"testing"
    25  	"time"
    26  
    27  	"github.com/stretchr/testify/require"
    28  )
    29  
    30  // getEntries is a test-only method for AuthServerStatic.
    31  func (a *AuthServerStatic) getEntries() map[string][]*AuthServerStaticEntry {
    32  	a.mu.Lock()
    33  	defer a.mu.Unlock()
    34  	return a.entries
    35  }
    36  
    37  func TestJsonConfigParser(t *testing.T) {
    38  	// works with legacy format
    39  	config := make(map[string][]*AuthServerStaticEntry)
    40  	jsonConfig := "{\"mysql_user\":{\"Password\":\"123\", \"UserData\":\"dummy\"}, \"mysql_user_2\": {\"Password\": \"123\", \"UserData\": \"mysql_user_2\"}}"
    41  	err := ParseConfig([]byte(jsonConfig), &config)
    42  	require.NoError(t, err, "should not get an error, but got: %v", err)
    43  	require.Equal(t, 1, len(config["mysql_user"]), "mysql_user config size should be equal to 1")
    44  	require.Equal(t, 1, len(config["mysql_user_2"]), "mysql_user config size should be equal to 1")
    45  
    46  	// works with new format
    47  	jsonConfig = `{"mysql_user":[
    48  		{"Password":"123", "UserData":"dummy", "SourceHost": "localhost"},
    49  		{"Password": "123", "UserData": "mysql_user_all"},
    50  		{"Password": "456", "UserData": "mysql_user_with_groups", "Groups": ["user_group"]}
    51  	]}`
    52  	err = ParseConfig([]byte(jsonConfig), &config)
    53  	require.NoError(t, err, "should not get an error, but got: %v", err)
    54  	require.Equal(t, 3, len(config["mysql_user"]), "mysql_user config size should be equal to 3")
    55  	require.Equal(t, "localhost", config["mysql_user"][0].SourceHost, "SourceHost should be equal to localhost")
    56  
    57  	if len(config["mysql_user"][2].Groups) != 1 || config["mysql_user"][2].Groups[0] != "user_group" {
    58  		t.Fatalf("Groups should be equal to [\"user_group\"]")
    59  	}
    60  
    61  	jsonConfig = `{
    62  		"mysql_user": [{"Password": "123", "UserData": "mysql_user_all", "InvalidKey": "oops"}]
    63  	}`
    64  	err = ParseConfig([]byte(jsonConfig), &config)
    65  	require.Error(t, err, "Invalid config should have errored, but didn't")
    66  
    67  }
    68  
    69  func TestValidateHashGetter(t *testing.T) {
    70  	jsonConfig := `{"mysql_user": [{"Password": "password", "UserData": "user.name", "Groups": ["user_group"]}]}`
    71  
    72  	auth := NewAuthServerStatic("", jsonConfig, 0)
    73  	defer auth.close()
    74  	ip := net.ParseIP("127.0.0.1")
    75  	addr := &net.IPAddr{IP: ip, Zone: ""}
    76  
    77  	salt, err := newSalt()
    78  	require.NoError(t, err, "error generating salt: %v", err)
    79  
    80  	scrambled := ScrambleMysqlNativePassword(salt, []byte("password"))
    81  	getter, err := auth.UserEntryWithHash(nil, salt, "mysql_user", scrambled, addr)
    82  	require.NoError(t, err, "error validating password: %v", err)
    83  
    84  	callerID := getter.Get()
    85  	require.Equal(t, "user.name", callerID.Username, "getter username incorrect, expected \"user.name\", got %v", callerID.Username)
    86  
    87  	if len(callerID.Groups) != 1 || callerID.Groups[0] != "user_group" {
    88  		t.Fatalf("getter groups incorrect, expected [\"user_group\"], got %v", callerID.Groups)
    89  	}
    90  }
    91  
    92  func TestHostMatcher(t *testing.T) {
    93  	ip := net.ParseIP("192.168.0.1")
    94  	addr := &net.TCPAddr{IP: ip, Port: 9999}
    95  	match := MatchSourceHost(net.Addr(addr), "")
    96  	require.True(t, match, "Should match any address when target is empty")
    97  
    98  	match = MatchSourceHost(net.Addr(addr), "localhost")
    99  	require.False(t, match, "Should not match address when target is localhost")
   100  
   101  	socket := &net.UnixAddr{Name: "unixSocket", Net: "1"}
   102  	match = MatchSourceHost(net.Addr(socket), "localhost")
   103  	require.True(t, match, "Should match socket when target is localhost")
   104  
   105  }
   106  
   107  func TestStaticConfigHUP(t *testing.T) {
   108  	tmpFile, err := os.CreateTemp("", "mysql_auth_server_static_file.json")
   109  	require.NoError(t, err, "couldn't create temp file: %v", err)
   110  
   111  	defer os.Remove(tmpFile.Name())
   112  
   113  	oldStr := "str5"
   114  	jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", oldStr, oldStr)
   115  	if err := os.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil {
   116  		t.Fatalf("couldn't write temp file: %v", err)
   117  	}
   118  
   119  	aStatic := NewAuthServerStatic(tmpFile.Name(), "", 0)
   120  	defer aStatic.close()
   121  	require.Equal(t, oldStr, aStatic.getEntries()[oldStr][0].Password, "%s's Password should still be '%s'", oldStr, oldStr)
   122  
   123  	hupTest(t, aStatic, tmpFile, oldStr, "str2")
   124  	hupTest(t, aStatic, tmpFile, "str2", "str3") // still handling the signal
   125  
   126  	mu.Lock()
   127  	defer mu.Unlock()
   128  	// delete registered Auth server
   129  	for auth := range authServers {
   130  		delete(authServers, auth)
   131  	}
   132  }
   133  
   134  func TestStaticConfigHUPWithRotation(t *testing.T) {
   135  	tmpFile, err := os.CreateTemp("", "mysql_auth_server_static_file.json")
   136  	require.NoError(t, err, "couldn't create temp file: %v", err)
   137  
   138  	defer os.Remove(tmpFile.Name())
   139  
   140  	oldStr := "str1"
   141  	jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", oldStr, oldStr)
   142  	if err := os.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil {
   143  		t.Fatalf("couldn't write temp file: %v", err)
   144  	}
   145  
   146  	aStatic := NewAuthServerStatic(tmpFile.Name(), "", 10*time.Millisecond)
   147  	defer aStatic.close()
   148  	require.Equal(t, oldStr, aStatic.getEntries()[oldStr][0].Password, "%s's Password should still be '%s'", oldStr, oldStr)
   149  
   150  	hupTestWithRotation(t, aStatic, tmpFile, oldStr, "str4")
   151  	hupTestWithRotation(t, aStatic, tmpFile, "str4", "str5")
   152  }
   153  
   154  func hupTest(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.File, oldStr, newStr string) {
   155  	jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", newStr, newStr)
   156  	if err := os.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil {
   157  		t.Fatalf("couldn't overwrite temp file: %v", err)
   158  	}
   159  	require.Equal(t, oldStr, aStatic.getEntries()[oldStr][0].Password, "%s's Password should still be '%s'", oldStr, oldStr)
   160  
   161  	syscall.Kill(syscall.Getpid(), syscall.SIGHUP)
   162  	time.Sleep(100 * time.Millisecond)
   163  	// wait for signal handler
   164  	require.Nil(t, aStatic.getEntries()[oldStr], "Should not have old %s after config reload", oldStr)
   165  	require.Equal(t, newStr, aStatic.getEntries()[newStr][0].Password, "%s's Password should be '%s'", newStr, newStr)
   166  
   167  }
   168  
   169  func hupTestWithRotation(t *testing.T, aStatic *AuthServerStatic, tmpFile *os.File, oldStr, newStr string) {
   170  	jsonConfig := fmt.Sprintf("{\"%s\":[{\"Password\":\"%s\"}]}", newStr, newStr)
   171  	if err := os.WriteFile(tmpFile.Name(), []byte(jsonConfig), 0600); err != nil {
   172  		t.Fatalf("couldn't overwrite temp file: %v", err)
   173  	}
   174  
   175  	time.Sleep(20 * time.Millisecond)
   176  	// wait for signal handler
   177  	require.Nil(t, aStatic.getEntries()[oldStr], "Should not have old %s after config reload", oldStr)
   178  	require.Equal(t, newStr, aStatic.getEntries()[newStr][0].Password, "%s's Password should be '%s'", newStr, newStr)
   179  
   180  }
   181  
   182  func TestStaticPasswords(t *testing.T) {
   183  	jsonConfig := `
   184  {
   185  	"user01": [{ "Password": "user01" }],
   186  	"user02": [{
   187  		"MysqlNativePassword": "*B3AD996B12F211BEA47A7C666CC136FB26DC96AF"
   188  	}],
   189  	"user03": [{
   190  		"MysqlNativePassword": "*211E0153B172BAED4352D5E4628BD76731AF83E7",
   191  		"Password": "invalid"
   192  	}],
   193  	"user04": [
   194  		{ "MysqlNativePassword": "*668425423DB5193AF921380129F465A6425216D0" },
   195  		{ "Password": "password2" }
   196  	]
   197  }`
   198  
   199  	tests := []struct {
   200  		user     string
   201  		password string
   202  		success  bool
   203  	}{
   204  		{"user01", "user01", true},
   205  		{"user01", "password", false},
   206  		{"user01", "", false},
   207  		{"user02", "user02", true},
   208  		{"user02", "password", false},
   209  		{"user02", "", false},
   210  		{"user03", "user03", true},
   211  		{"user03", "password", false},
   212  		{"user03", "invalid", false},
   213  		{"user03", "", false},
   214  		{"user04", "password1", true},
   215  		{"user04", "password2", true},
   216  		{"user04", "", false},
   217  		{"userXX", "", false},
   218  		{"userXX", "", false},
   219  		{"", "", false},
   220  		{"", "password", false},
   221  	}
   222  
   223  	auth := NewAuthServerStatic("", jsonConfig, 0)
   224  	defer auth.close()
   225  	ip := net.ParseIP("127.0.0.1")
   226  	addr := &net.IPAddr{IP: ip, Zone: ""}
   227  
   228  	for _, c := range tests {
   229  		t.Run(fmt.Sprintf("%s-%s", c.user, c.password), func(t *testing.T) {
   230  			salt, err := newSalt()
   231  			require.NoError(t, err, "error generating salt: %v", err)
   232  
   233  			scrambled := ScrambleMysqlNativePassword(salt, []byte(c.password))
   234  			_, err = auth.UserEntryWithHash(nil, salt, c.user, scrambled, addr)
   235  
   236  			if c.success {
   237  				require.NoError(t, err, "authentication should have succeeded: %v", err)
   238  
   239  			} else {
   240  				require.Error(t, err, "authentication should have failed")
   241  
   242  			}
   243  		})
   244  	}
   245  }