github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/config/security_test.go (about) 1 // Copyright 2021 PingCAP, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package config 15 16 import ( 17 "bytes" 18 "fmt" 19 "os" 20 "path" 21 "reflect" 22 "testing" 23 24 "github.com/pingcap/tiflow/dm/config/security" 25 "github.com/stretchr/testify/require" 26 "github.com/stretchr/testify/suite" 27 ) 28 29 const ( 30 caFile = "ca.pem" 31 caFileContent = ` 32 -----BEGIN CERTIFICATE----- 33 test no content 34 -----END CERTIFICATE----- 35 ` 36 certFile = "cert.pem" 37 certFileContent = ` 38 -----BEGIN CERTIFICATE----- 39 test no content 40 -----END CERTIFICATE----- 41 ` 42 keyFile = "key.pem" 43 keyFileContent = ` 44 -----BEGIN RSA PRIVATE KEY----- 45 test no content 46 -----END RSA PRIVATE KEY----- 47 ` 48 ) 49 50 var ( 51 caFilePath string 52 certFilePath string 53 keyFilePath string 54 ) 55 56 func createTestFixture(t *testing.T) { 57 t.Helper() 58 59 dir := t.TempDir() 60 61 caFilePath = path.Join(dir, caFile) 62 err := os.WriteFile(caFilePath, []byte(caFileContent), 0o644) 63 require.NoError(t, err) 64 65 certFilePath = path.Join(dir, certFile) 66 err = os.WriteFile(certFilePath, []byte(certFileContent), 0o644) 67 require.NoError(t, err) 68 69 keyFilePath = path.Join(dir, keyFile) 70 err = os.WriteFile(keyFilePath, []byte(keyFileContent), 0o644) 71 require.NoError(t, err) 72 } 73 74 func TestPessimistSuite(t *testing.T) { 75 suite.Run(t, new(testTLSConfig)) 76 } 77 78 type testTLSConfig struct { 79 suite.Suite 80 81 noContent []byte 82 } 83 84 func (c *testTLSConfig) SetupSuite() { 85 createTestFixture(c.T()) 86 c.noContent = []byte("test no content") 87 } 88 89 func (c *testTLSConfig) TestLoadAndClearContent() { 90 s := &security.Security{ 91 SSLCA: caFilePath, 92 SSLCert: certFilePath, 93 SSLKey: keyFilePath, 94 } 95 err := s.LoadTLSContent() 96 c.Require().NoError(err) 97 c.Require().Greater(len(s.SSLCABytes), 0) 98 c.Require().Greater(len(s.SSLCertBytes), 0) 99 c.Require().Greater(len(s.SSLKeyBytes), 0) 100 101 c.Require().True(bytes.Contains(s.SSLCABytes, c.noContent)) 102 c.Require().True(bytes.Contains(s.SSLCertBytes, c.noContent)) 103 c.Require().True(bytes.Contains(s.SSLKeyBytes, c.noContent)) 104 105 s.ClearSSLBytesData() 106 c.Require().Len(s.SSLCABytes, 0) 107 c.Require().Len(s.SSLCertBytes, 0) 108 c.Require().Len(s.SSLKeyBytes, 0) 109 110 s.SSLCABase64 = "MTIz" 111 err = s.LoadTLSContent() 112 c.Require().NoError(err) 113 c.Require().Greater(len(s.SSLCABytes), 0) 114 } 115 116 func (c *testTLSConfig) TestTLSTaskConfig() { 117 taskRowStr := fmt.Sprintf(`--- 118 name: test 119 task-mode: all 120 target-database: 121 host: "127.0.0.1" 122 port: 3307 123 user: "root" 124 password: "123456" 125 security: 126 ssl-ca: %s 127 ssl-cert: %s 128 ssl-key: %s 129 block-allow-list: 130 instance: 131 do-dbs: ["dm_benchmark"] 132 mysql-instances: 133 - source-id: "mysql-replica-01-tls" 134 block-allow-list: "instance" 135 `, caFilePath, certFilePath, keyFilePath) 136 task1 := NewTaskConfig() 137 err := task1.RawDecode(taskRowStr) 138 c.Require().NoError(err) 139 c.Require().NoError(task1.TargetDB.Security.LoadTLSContent()) 140 // test load tls content 141 c.Require().True(bytes.Contains(task1.TargetDB.Security.SSLCABytes, c.noContent)) 142 c.Require().True(bytes.Contains(task1.TargetDB.Security.SSLCertBytes, c.noContent)) 143 c.Require().True(bytes.Contains(task1.TargetDB.Security.SSLKeyBytes, c.noContent)) 144 145 // test after to string, taskStr can be `Decode` normally 146 taskStr := task1.String() 147 task2 := NewTaskConfig() 148 err = task2.FromYaml(taskStr) 149 c.Require().NoError(err) 150 c.Require().True(bytes.Contains(task2.TargetDB.Security.SSLCABytes, c.noContent)) 151 c.Require().True(bytes.Contains(task2.TargetDB.Security.SSLCertBytes, c.noContent)) 152 c.Require().True(bytes.Contains(task2.TargetDB.Security.SSLKeyBytes, c.noContent)) 153 c.Require().NoError(task2.adjust()) 154 } 155 156 func (c *testTLSConfig) TestClone() { 157 s := &security.Security{ 158 SSLCA: "a", 159 SSLCert: "b", 160 SSLKey: "c", 161 CertAllowedCN: []string{"d"}, 162 SSLCABytes: nil, 163 SSLKeyBytes: []byte("e"), 164 SSLCertBytes: []byte("f"), 165 } 166 // When add new fields, also update this value 167 // TODO: check it 168 c.Require().Equal(10, reflect.TypeOf(*s).NumField()) 169 clone := s.Clone() 170 c.Require().Equal(s, clone) 171 clone.CertAllowedCN[0] = "g" 172 c.Require().NotEqual(s, clone) 173 }