github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/config/subtask_test.go (about)

     1  // Copyright 2019 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package config
    15  
    16  import (
    17  	"context"
    18  	"crypto/rand"
    19  	"reflect"
    20  	"testing"
    21  
    22  	"github.com/DATA-DOG/go-sqlmock"
    23  	"github.com/pingcap/tidb/pkg/util/filter"
    24  	"github.com/pingcap/tiflow/dm/config/dbconfig"
    25  	"github.com/pingcap/tiflow/dm/config/security"
    26  	"github.com/pingcap/tiflow/dm/pkg/encrypt"
    27  	"github.com/pingcap/tiflow/dm/pkg/terror"
    28  	"github.com/pingcap/tiflow/dm/pkg/utils"
    29  	"github.com/stretchr/testify/require"
    30  )
    31  
    32  func TestSubTask(t *testing.T) {
    33  	key := make([]byte, 32)
    34  	_, err := rand.Read(key)
    35  	require.NoError(t, err)
    36  
    37  	t.Cleanup(func() {
    38  		encrypt.InitCipher(nil)
    39  	})
    40  	encrypt.InitCipher(key)
    41  	encryptedPass, err := utils.Encrypt("1234")
    42  	require.NoError(t, err)
    43  	require.NotEqual(t, "1234", encryptedPass)
    44  	cfg := &SubTaskConfig{
    45  		Name:            "test-task",
    46  		IsSharding:      true,
    47  		ShardMode:       "optimistic",
    48  		SourceID:        "mysql-instance-01",
    49  		OnlineDDL:       false,
    50  		OnlineDDLScheme: PT,
    51  		From: dbconfig.DBConfig{
    52  			Host:     "127.0.0.1",
    53  			Port:     3306,
    54  			User:     "root",
    55  			Password: encryptedPass,
    56  		},
    57  		To: dbconfig.DBConfig{
    58  			Host:     "127.0.0.1",
    59  			Port:     4306,
    60  			User:     "root",
    61  			Password: "",
    62  		},
    63  	}
    64  	cfg.From.Adjust()
    65  	cfg.To.Adjust()
    66  
    67  	clone1, err := cfg.Clone()
    68  	require.NoError(t, err)
    69  	require.Equal(t, cfg, clone1)
    70  
    71  	clone1.From.Password = "1234"
    72  	clone2, err := cfg.DecryptedClone()
    73  	require.NoError(t, err)
    74  	require.Equal(t, clone1, clone2)
    75  
    76  	cfg.From.Password = "xxx"
    77  	_, err = cfg.DecryptedClone()
    78  	require.NoError(t, err)
    79  	err = cfg.Adjust(true)
    80  	require.NoError(t, err)
    81  	require.True(t, cfg.OnlineDDL)
    82  	err = cfg.Adjust(false)
    83  	require.NoError(t, err)
    84  
    85  	cfg.From.Password = ""
    86  	clone3, err := cfg.DecryptedClone()
    87  	require.NoError(t, err)
    88  	require.Equal(t, cfg, clone3)
    89  
    90  	err = cfg.Adjust(true)
    91  	require.NoError(t, err)
    92  
    93  	cfg.ValidatorCfg = ValidatorConfig{Mode: ValidationFast}
    94  	err = cfg.Adjust(true)
    95  	require.NoError(t, err)
    96  
    97  	cfg.ValidatorCfg = ValidatorConfig{Mode: "invalid-mode"}
    98  	err = cfg.Adjust(true)
    99  	require.True(t, terror.ErrConfigValidationMode.Equal(err))
   100  }
   101  
   102  func TestSubTaskAdjustFail(t *testing.T) {
   103  	newSubTaskConfig := func() *SubTaskConfig {
   104  		return &SubTaskConfig{
   105  			Name:      "test-task",
   106  			SourceID:  "mysql-instance-01",
   107  			OnlineDDL: true,
   108  			From: dbconfig.DBConfig{
   109  				Host:     "127.0.0.1",
   110  				Port:     3306,
   111  				User:     "root",
   112  				Password: "Up8156jArvIPymkVC+5LxkAT6rek",
   113  			},
   114  			To: dbconfig.DBConfig{
   115  				Host:     "127.0.0.1",
   116  				Port:     4306,
   117  				User:     "root",
   118  				Password: "",
   119  			},
   120  		}
   121  	}
   122  	testCases := []struct {
   123  		genFunc func() *SubTaskConfig
   124  		errMsg  string
   125  	}{
   126  		{
   127  			func() *SubTaskConfig {
   128  				cfg := newSubTaskConfig()
   129  				cfg.Name = ""
   130  				return cfg
   131  			},
   132  			"Message: task name should not be empty",
   133  		},
   134  		{
   135  			func() *SubTaskConfig {
   136  				cfg := newSubTaskConfig()
   137  				cfg.SourceID = ""
   138  				return cfg
   139  			},
   140  			"Message: empty source-id not valid",
   141  		},
   142  		{
   143  			func() *SubTaskConfig {
   144  				cfg := newSubTaskConfig()
   145  				cfg.SourceID = "source-id-length-more-than-thirty-two"
   146  				return cfg
   147  			},
   148  			"Message: too long source-id not valid",
   149  		},
   150  		{
   151  			func() *SubTaskConfig {
   152  				cfg := newSubTaskConfig()
   153  				cfg.ShardMode = "invalid-shard-mode"
   154  				return cfg
   155  			},
   156  			"Message: shard mode invalid-shard-mode not supported",
   157  		},
   158  		{
   159  			func() *SubTaskConfig {
   160  				cfg := newSubTaskConfig()
   161  				cfg.OnlineDDLScheme = "rtc"
   162  				return cfg
   163  			},
   164  			"Message: online scheme rtc not supported",
   165  		},
   166  	}
   167  
   168  	for _, tc := range testCases {
   169  		cfg := tc.genFunc()
   170  		err := cfg.Adjust(true)
   171  		require.ErrorContains(t, err, tc.errMsg)
   172  	}
   173  }
   174  
   175  func TestSubTaskBlockAllowList(t *testing.T) {
   176  	filterRules1 := &filter.Rules{
   177  		DoDBs: []string{"s1"},
   178  	}
   179  
   180  	filterRules2 := &filter.Rules{
   181  		DoDBs: []string{"s2"},
   182  	}
   183  
   184  	cfg := &SubTaskConfig{
   185  		Name:     "test",
   186  		SourceID: "source-1",
   187  		BWList:   filterRules1,
   188  	}
   189  
   190  	// BAList is nil, will set BAList = BWList
   191  	err := cfg.Adjust(false)
   192  	require.NoError(t, err)
   193  	require.Equal(t, filterRules1, cfg.BAList)
   194  
   195  	// BAList is not nil, will not update it
   196  	cfg.BAList = filterRules2
   197  	err = cfg.Adjust(false)
   198  	require.NoError(t, err)
   199  	require.Equal(t, filterRules2, cfg.BAList)
   200  }
   201  
   202  func TestSubTaskAdjustLoaderS3Dir(t *testing.T) {
   203  	cfg := &SubTaskConfig{
   204  		Name:     "test",
   205  		SourceID: "source-1",
   206  		Mode:     ModeAll,
   207  	}
   208  
   209  	// default loader
   210  	cfg.LoaderConfig = DefaultLoaderConfig()
   211  	err := cfg.Adjust(false)
   212  	require.NoError(t, err)
   213  	require.Equal(t, defaultDir+"."+cfg.Name, cfg.LoaderConfig.Dir)
   214  
   215  	// file
   216  	cfg.LoaderConfig = LoaderConfig{
   217  		PoolSize:   defaultPoolSize,
   218  		Dir:        "file:///tmp/storage",
   219  		ImportMode: LoadModeSQL,
   220  	}
   221  	err = cfg.Adjust(false)
   222  	require.NoError(t, err)
   223  	require.Equal(t, "file:///tmp/storage"+"."+cfg.Name, cfg.LoaderConfig.Dir)
   224  
   225  	cfg.LoaderConfig = LoaderConfig{
   226  		PoolSize:   defaultPoolSize,
   227  		Dir:        "./dump_data",
   228  		ImportMode: LoadModeSQL,
   229  	}
   230  	err = cfg.Adjust(false)
   231  	require.NoError(t, err)
   232  	require.Equal(t, "./dump_data"+"."+cfg.Name, cfg.LoaderConfig.Dir)
   233  
   234  	// s3
   235  	cfg.LoaderConfig = LoaderConfig{
   236  		PoolSize:   defaultPoolSize,
   237  		Dir:        "s3://bucket2/prefix",
   238  		ImportMode: LoadModeSQL,
   239  	}
   240  	err = cfg.Adjust(false)
   241  	require.NoError(t, err)
   242  	require.Equal(t, "s3://bucket2/prefix"+"/"+cfg.Name+"."+cfg.SourceID, cfg.LoaderConfig.Dir)
   243  
   244  	cfg.LoaderConfig = LoaderConfig{
   245  		PoolSize:   defaultPoolSize,
   246  		Dir:        "s3://bucket3/prefix/path?endpoint=https://127.0.0.1:9000&force_path_style=0&SSE=aws:kms&sse-kms-key-id=TestKey&xyz=abc",
   247  		ImportMode: LoadModeSQL,
   248  	}
   249  	err = cfg.Adjust(false)
   250  	require.NoError(t, err)
   251  	require.Equal(t, "s3://bucket3/prefix/path/"+cfg.Name+"."+cfg.SourceID+"?endpoint=https://127.0.0.1:9000&force_path_style=0&SSE=aws:kms&sse-kms-key-id=TestKey&xyz=abc", cfg.LoaderConfig.Dir)
   252  
   253  	// invaild dir
   254  	cfg.LoaderConfig = LoaderConfig{
   255  		PoolSize:   defaultPoolSize,
   256  		Dir:        "1invalid:",
   257  		ImportMode: LoadModeSQL,
   258  	}
   259  	err = cfg.Adjust(false)
   260  	require.ErrorContains(t, err, "Message: loader's dir 1invalid: is invalid")
   261  
   262  	// use loader and not s3
   263  	cfg.LoaderConfig = LoaderConfig{
   264  		PoolSize:   defaultPoolSize,
   265  		Dir:        "file:///tmp/storage",
   266  		ImportMode: LoadModeSQL,
   267  	}
   268  	err = cfg.Adjust(false)
   269  	require.NoError(t, err)
   270  	require.Equal(t, "file:///tmp/storage."+cfg.Name, cfg.LoaderConfig.Dir)
   271  
   272  	cfg.LoaderConfig = LoaderConfig{
   273  		PoolSize:   defaultPoolSize,
   274  		Dir:        "./dumpdir",
   275  		ImportMode: LoadModeSQL,
   276  	}
   277  	err = cfg.Adjust(false)
   278  	require.NoError(t, err)
   279  	require.Equal(t, "./dumpdir."+cfg.Name, cfg.LoaderConfig.Dir)
   280  
   281  	// use loader and s3
   282  	cfg.LoaderConfig = LoaderConfig{
   283  		PoolSize:   defaultPoolSize,
   284  		Dir:        "s3://bucket2/prefix",
   285  		ImportMode: LoadModeLoader,
   286  	}
   287  	err = cfg.Adjust(false)
   288  	require.ErrorContains(t, err, "Message: loader's dir s3://bucket2/prefix is s3 dir, but s3 is not supported")
   289  
   290  	// not all or full mode
   291  	cfg.Mode = ModeIncrement
   292  	cfg.LoaderConfig = LoaderConfig{
   293  		PoolSize:   defaultPoolSize,
   294  		Dir:        "1invalid:",
   295  		ImportMode: LoadModeSQL,
   296  	}
   297  	err = cfg.Adjust(false)
   298  	require.NoError(t, err)
   299  	require.Equal(t, "1invalid:", cfg.LoaderConfig.Dir)
   300  }
   301  
   302  func TestDBConfigClone(t *testing.T) {
   303  	a := &dbconfig.DBConfig{
   304  		Host:     "127.0.0.1",
   305  		Port:     4306,
   306  		User:     "root",
   307  		Password: "123",
   308  		Session:  map[string]string{"1": "1"},
   309  		RawDBCfg: dbconfig.DefaultRawDBConfig(),
   310  	}
   311  
   312  	// When add new fields, also update this value
   313  	require.Equal(t, 9, reflect.Indirect(reflect.ValueOf(a)).NumField())
   314  
   315  	b := a.Clone()
   316  	require.Equal(t, a, b)
   317  	require.NotSame(t, a.RawDBCfg, b.RawDBCfg)
   318  
   319  	a.RawDBCfg.MaxIdleConns = 123
   320  	require.NotEqual(t, a, b)
   321  
   322  	packet := 1
   323  	a.MaxAllowedPacket = &packet
   324  	b = a.Clone()
   325  	require.Equal(t, a, b)
   326  	require.NotSame(t, a.MaxAllowedPacket, b.MaxAllowedPacket)
   327  
   328  	a.Session["2"] = "2"
   329  	require.NotEqual(t, a, b)
   330  
   331  	a.RawDBCfg = nil
   332  	a.Security = &security.Security{}
   333  	b = a.Clone()
   334  	require.Equal(t, a, b)
   335  	require.NotSame(t, a.Security, b.Security)
   336  }
   337  
   338  func TestFetchTZSetting(t *testing.T) {
   339  	db, mock, err := sqlmock.New()
   340  	require.NoError(t, err)
   341  
   342  	mock.ExpectQuery("SELECT cast\\(TIMEDIFF\\(NOW\\(6\\), UTC_TIMESTAMP\\(6\\)\\) as time\\);").
   343  		WillReturnRows(mock.NewRows([]string{""}).AddRow("01:00:00"))
   344  	tz, err := FetchTimeZoneSetting(context.Background(), db)
   345  	require.NoError(t, err)
   346  	require.Equal(t, "+01:00", tz)
   347  }