github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/config/source_config_test.go (about)

     1  // Copyright 2019 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package config
    15  
    16  import (
    17  	"context"
    18  	"crypto/rand"
    19  	"errors"
    20  	"fmt"
    21  	"os"
    22  	"path"
    23  	"reflect"
    24  	"strings"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/DATA-DOG/go-sqlmock"
    29  	"github.com/go-mysql-org/go-mysql/mysql"
    30  	"github.com/pingcap/tiflow/dm/config/dbconfig"
    31  	"github.com/pingcap/tiflow/dm/pkg/conn"
    32  	tcontext "github.com/pingcap/tiflow/dm/pkg/context"
    33  	"github.com/pingcap/tiflow/dm/pkg/encrypt"
    34  	"github.com/pingcap/tiflow/dm/pkg/utils"
    35  	bf "github.com/pingcap/tiflow/pkg/binlog-filter"
    36  	"github.com/stretchr/testify/require"
    37  	"gopkg.in/yaml.v2"
    38  )
    39  
    40  func TestConfigFunctions(t *testing.T) {
    41  	cfg, err := SourceCfgFromYaml(SampleSourceConfig)
    42  	require.NoError(t, err)
    43  	cfg.RelayDir = "./xx"
    44  	require.Equal(t, uint32(101), cfg.ServerID)
    45  
    46  	// test clone
    47  	clone1 := cfg.Clone()
    48  	require.Equal(t, cfg, clone1)
    49  	clone1.ServerID = 100
    50  	require.Equal(t, uint32(101), cfg.ServerID)
    51  
    52  	// test format
    53  	require.Contains(t, cfg.String(), `server-id":101`)
    54  	tomlStr, err := clone1.Toml()
    55  	require.NoError(t, err)
    56  	require.Contains(t, tomlStr, `server-id = 100`)
    57  	yamlStr, err := clone1.Yaml()
    58  	require.NoError(t, err)
    59  	require.Contains(t, yamlStr, `server-id: 100`)
    60  	originCfgStr, err := cfg.Toml()
    61  	require.NoError(t, err)
    62  	require.Contains(t, originCfgStr, `server-id = 101`)
    63  	originCfgYamlStr, err := cfg.Yaml()
    64  	require.NoError(t, err)
    65  	require.Contains(t, originCfgYamlStr, `server-id: 101`)
    66  
    67  	// test update config file and reload
    68  	require.NoError(t, cfg.FromToml(tomlStr))
    69  	require.Equal(t, uint32(100), cfg.ServerID)
    70  	cfg1, err := SourceCfgFromYaml(yamlStr)
    71  	require.NoError(t, err)
    72  	require.Equal(t, uint32(100), cfg1.ServerID)
    73  	cfg.Filters = []*bf.BinlogEventRule{}
    74  	cfg.Tracer = map[string]interface{}{}
    75  
    76  	var cfg2 SourceConfig
    77  	require.NoError(t, cfg2.FromToml(originCfgStr))
    78  	require.Equal(t, uint32(101), cfg2.ServerID)
    79  
    80  	cfg3, err := SourceCfgFromYaml(originCfgYamlStr)
    81  	require.NoError(t, err)
    82  	require.Equal(t, uint32(101), cfg3.ServerID)
    83  
    84  	cfg.From.Password = "xxx"
    85  	cfg.GetDecryptedClone()
    86  
    87  	cfg.From.Password = ""
    88  	clone3 := cfg.GetDecryptedClone()
    89  	require.Equal(t, cfg, clone3)
    90  
    91  	// test toml and parse again
    92  	clone4 := cfg.Clone()
    93  	clone4.Checker.CheckEnable = true
    94  	clone4.Checker.BackoffRollback = Duration{time.Minute * 5}
    95  	clone4.Checker.BackoffMax = Duration{time.Minute * 5}
    96  	clone4toml, err := clone4.Toml()
    97  	require.NoError(t, err)
    98  	require.Contains(t, clone4toml, `backoff-rollback = "5m`)
    99  	require.Contains(t, clone4toml, `backoff-max = "5m`)
   100  
   101  	var clone5 SourceConfig
   102  	require.NoError(t, clone5.FromToml(clone4toml))
   103  	require.Equal(t, *clone4, clone5)
   104  	clone4yaml, err := clone4.Yaml()
   105  	require.NoError(t, err)
   106  	require.Contains(t, clone4yaml, `backoff-rollback: 5m`)
   107  	require.Contains(t, clone4yaml, `backoff-max: 5m`)
   108  
   109  	clone6, err := SourceCfgFromYaml(clone4yaml)
   110  	require.NoError(t, err)
   111  	clone6.From.Session = nil
   112  	require.Equal(t, clone4, clone6)
   113  
   114  	// test invalid config
   115  	dir2 := t.TempDir()
   116  	configFile := path.Join(dir2, "dm-worker-invalid.toml")
   117  	configContent := []byte(`
   118  source-id: haha
   119  aaa: xxx
   120  `)
   121  	err = os.WriteFile(configFile, configContent, 0o644)
   122  	require.NoError(t, err)
   123  	_, err = LoadFromFile(configFile)
   124  	require.ErrorContains(t, err, "field aaa not found in type config.SourceConfig")
   125  }
   126  
   127  func TestConfigVerify(t *testing.T) {
   128  	key := make([]byte, 32)
   129  	_, err := rand.Read(key)
   130  	require.NoError(t, err)
   131  
   132  	t.Cleanup(func() {
   133  		encrypt.InitCipher(nil)
   134  	})
   135  	encrypt.InitCipher(key)
   136  	encryptedPass, err := utils.Encrypt("this is password")
   137  	require.NoError(t, err)
   138  
   139  	newConfig := func() *SourceConfig {
   140  		cfg, err := SourceCfgFromYaml(SampleSourceConfig)
   141  		require.NoError(t, err)
   142  		cfg.RelayDir = "./xx"
   143  		return cfg
   144  	}
   145  	testCases := []struct {
   146  		genFunc        func() *SourceConfig
   147  		expectPassword string
   148  		errorFormat    string
   149  	}{
   150  		{
   151  			func() *SourceConfig {
   152  				return newConfig()
   153  			},
   154  			"123456",
   155  			"",
   156  		},
   157  		{
   158  			func() *SourceConfig {
   159  				cfg := newConfig()
   160  				cfg.SourceID = ""
   161  				return cfg
   162  			},
   163  			"123456",
   164  			".*dm-worker should bind a non-empty source ID which represents a MySQL/MariaDB instance or a replica group.*",
   165  		},
   166  		{
   167  			func() *SourceConfig {
   168  				cfg := newConfig()
   169  				cfg.SourceID = "source-id-length-more-than-thirty-two"
   170  				return cfg
   171  			},
   172  			"123456",
   173  			fmt.Sprintf(".*the length of source ID .* is more than max allowed value %d.*", MaxSourceIDLength),
   174  		},
   175  		{
   176  			func() *SourceConfig {
   177  				cfg := newConfig()
   178  				cfg.EnableRelay = true
   179  				cfg.RelayBinLogName = "mysql-binlog"
   180  				return cfg
   181  			},
   182  			"123456",
   183  			".*not valid.*",
   184  		},
   185  		{
   186  			// after support `start-relay`, we always check Relay related config
   187  			func() *SourceConfig {
   188  				cfg := newConfig()
   189  				cfg.RelayBinLogName = "mysql-binlog"
   190  				return cfg
   191  			},
   192  			"123456",
   193  			".*not valid.*",
   194  		},
   195  		{
   196  			func() *SourceConfig {
   197  				cfg := newConfig()
   198  				cfg.EnableRelay = true
   199  				cfg.RelayBinlogGTID = "9afe121c-40c2-11e9-9ec7-0242ac110002:1-rtc"
   200  				return cfg
   201  			},
   202  			"123456",
   203  			".*relay-binlog-gtid 9afe121c-40c2-11e9-9ec7-0242ac110002:1-rtc:.*",
   204  		},
   205  		{
   206  			func() *SourceConfig {
   207  				cfg := newConfig()
   208  				cfg.From.Password = "not-encrypt"
   209  				return cfg
   210  			},
   211  			"not-encrypt",
   212  			"",
   213  		},
   214  		{
   215  			func() *SourceConfig {
   216  				cfg := newConfig()
   217  				cfg.From.Password = "" // password empty
   218  				return cfg
   219  			},
   220  			"",
   221  			"",
   222  		},
   223  		{
   224  			func() *SourceConfig {
   225  				cfg := newConfig()
   226  				cfg.From.Password = "aaaaaa" // plaintext password
   227  				return cfg
   228  			},
   229  			"aaaaaa",
   230  			"",
   231  		},
   232  		{
   233  			func() *SourceConfig {
   234  				cfg := newConfig()
   235  				cfg.From.Password = encryptedPass
   236  				return cfg
   237  			},
   238  			"this is password",
   239  			"",
   240  		},
   241  	}
   242  
   243  	runCasesFn := func() {
   244  		for _, tc := range testCases {
   245  			cfg := tc.genFunc()
   246  			oldPass := cfg.From.Password
   247  			err := cfg.Verify()
   248  			if tc.errorFormat != "" {
   249  				require.Error(t, err)
   250  				lines := strings.Split(err.Error(), "\n")
   251  				require.Regexp(t, tc.errorFormat, lines[0])
   252  			} else {
   253  				require.NoError(t, err)
   254  			}
   255  			newCfg := cfg.GetDecryptedClone()
   256  			if encrypt.IsInitialized() {
   257  				require.Equal(t, tc.expectPassword, newCfg.From.Password)
   258  			} else {
   259  				require.Equal(t, oldPass, newCfg.From.Password)
   260  			}
   261  		}
   262  	}
   263  
   264  	require.True(t, encrypt.IsInitialized())
   265  	runCasesFn()
   266  	encrypt.InitCipher(nil)
   267  	require.False(t, encrypt.IsInitialized())
   268  	runCasesFn()
   269  }
   270  
   271  func TestSourceConfigForDowngrade(t *testing.T) {
   272  	cfg, err := SourceCfgFromYaml(SampleSourceConfig)
   273  	require.NoError(t, err)
   274  
   275  	// make sure all new field were added
   276  	cfgForDowngrade := NewSourceConfigForDowngrade(cfg)
   277  	cfgReflect := reflect.Indirect(reflect.ValueOf(cfg))
   278  	cfgForDowngradeReflect := reflect.Indirect(reflect.ValueOf(cfgForDowngrade))
   279  	// auto-fix-gtid, meta-dir are not written when downgrade
   280  	require.Equal(t, cfgForDowngradeReflect.NumField()+2, cfgReflect.NumField())
   281  
   282  	// make sure all field were copied
   283  	cfgForClone := &SourceConfigForDowngrade{}
   284  	Clone(cfgForClone, cfg)
   285  	require.Equal(t, cfgForClone, cfgForDowngrade)
   286  }
   287  
   288  func subtestFlavor(t *testing.T, cfg *SourceConfig, sqlInfo, expectedFlavor, expectedError string) {
   289  	t.Helper()
   290  
   291  	cfg.Flavor = ""
   292  	db, mock, err := sqlmock.New()
   293  	require.NoError(t, err)
   294  	mock.ExpectQuery("SHOW GLOBAL VARIABLES LIKE 'version';").
   295  		WillReturnRows(sqlmock.NewRows([]string{"Variable_name", "Value"}).
   296  			AddRow("version", sqlInfo))
   297  	mock.ExpectClose()
   298  
   299  	err = cfg.AdjustFlavor(context.Background(), conn.NewBaseDBForTest(db))
   300  	if expectedError == "" {
   301  		require.NoError(t, err)
   302  		require.Equal(t, expectedFlavor, cfg.Flavor)
   303  	} else {
   304  		require.ErrorContains(t, err, expectedError)
   305  	}
   306  }
   307  
   308  func TestAdjustFlavor(t *testing.T) {
   309  	cfg, err := SourceCfgFromYaml(SampleSourceConfig)
   310  	require.NoError(t, err)
   311  	cfg.RelayDir = "./xx"
   312  
   313  	cfg.Flavor = "mariadb"
   314  	err = cfg.AdjustFlavor(context.Background(), nil)
   315  	require.NoError(t, err)
   316  	require.Equal(t, mysql.MariaDBFlavor, cfg.Flavor)
   317  	cfg.Flavor = "MongoDB"
   318  	err = cfg.AdjustFlavor(context.Background(), nil)
   319  	require.ErrorContains(t, err, "flavor MongoDB not supported")
   320  
   321  	subtestFlavor(t, cfg, "10.4.8-MariaDB-1:10.4.8+maria~bionic", mysql.MariaDBFlavor, "")
   322  	subtestFlavor(t, cfg, "5.7.26-log", mysql.MySQLFlavor, "")
   323  }
   324  
   325  func TestAdjustServerID(t *testing.T) {
   326  	originGetAllServerIDFunc := getAllServerIDFunc
   327  	defer func() {
   328  		getAllServerIDFunc = originGetAllServerIDFunc
   329  	}()
   330  	getAllServerIDFunc = getMockServerIDs
   331  
   332  	cfg, err := SourceCfgFromYaml(SampleSourceConfig)
   333  	require.NoError(t, err)
   334  	cfg.RelayDir = "./xx"
   335  
   336  	require.NoError(t, cfg.AdjustServerID(context.Background(), nil))
   337  	require.Equal(t, uint32(101), cfg.ServerID)
   338  
   339  	cfg.ServerID = 0
   340  	require.NoError(t, cfg.AdjustServerID(context.Background(), nil))
   341  	require.NotEqual(t, 0, cfg.ServerID)
   342  }
   343  
   344  func TestAdjustServerIDFallback(t *testing.T) {
   345  	db, mock, err := sqlmock.New()
   346  	require.NoError(t, err)
   347  	mock.ExpectQuery("SHOW SLAVE HOSTS").
   348  		WillReturnError(errors.New("mysql error 1227: Access denied; you need (at least one of) the REPLICATION SLAVE privilege(s) for this operation"))
   349  	mock.ExpectClose()
   350  
   351  	cfg, err := SourceCfgFromYaml(SampleSourceConfig)
   352  	require.NoError(t, err)
   353  	cfg.ServerID = 0
   354  
   355  	err = cfg.AdjustServerID(context.Background(), conn.NewBaseDBForTest(db))
   356  	require.NoError(t, err)
   357  	require.NotEqual(t, 0, cfg.ServerID)
   358  }
   359  
   360  func getMockServerIDs(ctx *tcontext.Context, db *conn.BaseDB) (map[uint32]struct{}, error) {
   361  	return map[uint32]struct{}{
   362  		1: {},
   363  		2: {},
   364  	}, nil
   365  }
   366  
   367  func TestAdjustCaseSensitive(t *testing.T) {
   368  	cfg, err := SourceCfgFromYaml(SampleSourceConfig)
   369  	require.NoError(t, err)
   370  
   371  	db, mock, err := sqlmock.New()
   372  	require.NoError(t, err)
   373  
   374  	mock.ExpectQuery("SELECT @@lower_case_table_names;").
   375  		WillReturnRows(sqlmock.NewRows([]string{"@@lower_case_table_names"}).AddRow(conn.LCTableNamesMixed))
   376  	require.NoError(t, cfg.AdjustCaseSensitive(context.Background(), conn.NewBaseDBForTest(db)))
   377  	require.False(t, cfg.CaseSensitive)
   378  
   379  	mock.ExpectQuery("SELECT @@lower_case_table_names;").
   380  		WillReturnRows(sqlmock.NewRows([]string{"@@lower_case_table_names"}).AddRow(conn.LCTableNamesSensitive))
   381  	require.NoError(t, cfg.AdjustCaseSensitive(context.Background(), conn.NewBaseDBForTest(db)))
   382  	require.True(t, cfg.CaseSensitive)
   383  
   384  	require.NoError(t, mock.ExpectationsWereMet())
   385  }
   386  
   387  func TestEmbedSampleFile(t *testing.T) {
   388  	data, err := os.ReadFile("./source.yaml")
   389  	require.NoError(t, err)
   390  	require.Equal(t, SampleSourceConfig, string(data))
   391  }
   392  
   393  func TestSourceYamlForDowngrade(t *testing.T) {
   394  	originCfg := SourceConfig{
   395  		SourceID: "mysql-3306",
   396  		From: dbconfig.DBConfig{
   397  			Password: "123456",
   398  		},
   399  	}
   400  	// when secret key is empty, the password should be kept
   401  	content, err := originCfg.YamlForDowngrade()
   402  	require.NoError(t, err)
   403  	newCfg := &SourceConfig{}
   404  	require.NoError(t, yaml.UnmarshalStrict([]byte(content), newCfg))
   405  	require.Equal(t, originCfg.From.Password, newCfg.From.Password)
   406  
   407  	// when secret key is not empty, the password should be encrypted
   408  	key := make([]byte, 32)
   409  	_, err = rand.Read(key)
   410  	require.NoError(t, err)
   411  	t.Cleanup(func() {
   412  		encrypt.InitCipher(nil)
   413  	})
   414  	encrypt.InitCipher(key)
   415  	content, err = originCfg.YamlForDowngrade()
   416  	require.NoError(t, err)
   417  	newCfg = &SourceConfig{}
   418  	require.NoError(t, yaml.UnmarshalStrict([]byte(content), newCfg))
   419  	require.NotEqual(t, originCfg.From.Password, newCfg.From.Password)
   420  	decryptedPass, err := utils.Decrypt(newCfg.From.Password)
   421  	require.NoError(t, err)
   422  	require.Equal(t, originCfg.From.Password, decryptedPass)
   423  }