github.com/pingcap/tiflow@v0.0.0-20240520035814-5bf52d54e205/dm/config/security_test.go (about)

     1  // Copyright 2021 PingCAP, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package config
    15  
    16  import (
    17  	"bytes"
    18  	"fmt"
    19  	"os"
    20  	"path"
    21  	"reflect"
    22  	"testing"
    23  
    24  	"github.com/pingcap/tiflow/dm/config/security"
    25  	"github.com/stretchr/testify/require"
    26  	"github.com/stretchr/testify/suite"
    27  )
    28  
    29  const (
    30  	caFile        = "ca.pem"
    31  	caFileContent = `
    32  -----BEGIN CERTIFICATE-----
    33  test no content
    34  -----END CERTIFICATE-----
    35  `
    36  	certFile        = "cert.pem"
    37  	certFileContent = `
    38  -----BEGIN CERTIFICATE-----
    39  test no content
    40  -----END CERTIFICATE-----
    41  `
    42  	keyFile        = "key.pem"
    43  	keyFileContent = `
    44  -----BEGIN RSA PRIVATE KEY-----
    45  test no content
    46  -----END RSA PRIVATE KEY-----
    47  `
    48  )
    49  
    50  var (
    51  	caFilePath   string
    52  	certFilePath string
    53  	keyFilePath  string
    54  )
    55  
    56  func createTestFixture(t *testing.T) {
    57  	t.Helper()
    58  
    59  	dir := t.TempDir()
    60  
    61  	caFilePath = path.Join(dir, caFile)
    62  	err := os.WriteFile(caFilePath, []byte(caFileContent), 0o644)
    63  	require.NoError(t, err)
    64  
    65  	certFilePath = path.Join(dir, certFile)
    66  	err = os.WriteFile(certFilePath, []byte(certFileContent), 0o644)
    67  	require.NoError(t, err)
    68  
    69  	keyFilePath = path.Join(dir, keyFile)
    70  	err = os.WriteFile(keyFilePath, []byte(keyFileContent), 0o644)
    71  	require.NoError(t, err)
    72  }
    73  
    74  func TestPessimistSuite(t *testing.T) {
    75  	suite.Run(t, new(testTLSConfig))
    76  }
    77  
    78  type testTLSConfig struct {
    79  	suite.Suite
    80  
    81  	noContent []byte
    82  }
    83  
    84  func (c *testTLSConfig) SetupSuite() {
    85  	createTestFixture(c.T())
    86  	c.noContent = []byte("test no content")
    87  }
    88  
    89  func (c *testTLSConfig) TestLoadAndClearContent() {
    90  	s := &security.Security{
    91  		SSLCA:   caFilePath,
    92  		SSLCert: certFilePath,
    93  		SSLKey:  keyFilePath,
    94  	}
    95  	err := s.LoadTLSContent()
    96  	c.Require().NoError(err)
    97  	c.Require().Greater(len(s.SSLCABytes), 0)
    98  	c.Require().Greater(len(s.SSLCertBytes), 0)
    99  	c.Require().Greater(len(s.SSLKeyBytes), 0)
   100  
   101  	c.Require().True(bytes.Contains(s.SSLCABytes, c.noContent))
   102  	c.Require().True(bytes.Contains(s.SSLCertBytes, c.noContent))
   103  	c.Require().True(bytes.Contains(s.SSLKeyBytes, c.noContent))
   104  
   105  	s.ClearSSLBytesData()
   106  	c.Require().Len(s.SSLCABytes, 0)
   107  	c.Require().Len(s.SSLCertBytes, 0)
   108  	c.Require().Len(s.SSLKeyBytes, 0)
   109  
   110  	s.SSLCABase64 = "MTIz"
   111  	err = s.LoadTLSContent()
   112  	c.Require().NoError(err)
   113  	c.Require().Greater(len(s.SSLCABytes), 0)
   114  }
   115  
   116  func (c *testTLSConfig) TestTLSTaskConfig() {
   117  	taskRowStr := fmt.Sprintf(`---
   118  name: test
   119  task-mode: all
   120  target-database:
   121      host: "127.0.0.1"
   122      port: 3307
   123      user: "root"
   124      password: "123456"
   125      security:
   126        ssl-ca: %s
   127        ssl-cert: %s
   128        ssl-key: %s
   129  block-allow-list:
   130    instance:
   131      do-dbs: ["dm_benchmark"]
   132  mysql-instances:
   133    - source-id: "mysql-replica-01-tls"
   134      block-allow-list: "instance"
   135  `, caFilePath, certFilePath, keyFilePath)
   136  	task1 := NewTaskConfig()
   137  	err := task1.RawDecode(taskRowStr)
   138  	c.Require().NoError(err)
   139  	c.Require().NoError(task1.TargetDB.Security.LoadTLSContent())
   140  	// test load tls content
   141  	c.Require().True(bytes.Contains(task1.TargetDB.Security.SSLCABytes, c.noContent))
   142  	c.Require().True(bytes.Contains(task1.TargetDB.Security.SSLCertBytes, c.noContent))
   143  	c.Require().True(bytes.Contains(task1.TargetDB.Security.SSLKeyBytes, c.noContent))
   144  
   145  	// test after to string, taskStr can be `Decode` normally
   146  	taskStr := task1.String()
   147  	task2 := NewTaskConfig()
   148  	err = task2.FromYaml(taskStr)
   149  	c.Require().NoError(err)
   150  	c.Require().True(bytes.Contains(task2.TargetDB.Security.SSLCABytes, c.noContent))
   151  	c.Require().True(bytes.Contains(task2.TargetDB.Security.SSLCertBytes, c.noContent))
   152  	c.Require().True(bytes.Contains(task2.TargetDB.Security.SSLKeyBytes, c.noContent))
   153  	c.Require().NoError(task2.adjust())
   154  }
   155  
   156  func (c *testTLSConfig) TestClone() {
   157  	s := &security.Security{
   158  		SSLCA:         "a",
   159  		SSLCert:       "b",
   160  		SSLKey:        "c",
   161  		CertAllowedCN: []string{"d"},
   162  		SSLCABytes:    nil,
   163  		SSLKeyBytes:   []byte("e"),
   164  		SSLCertBytes:  []byte("f"),
   165  	}
   166  	// When add new fields, also update this value
   167  	// TODO: check it
   168  	c.Require().Equal(10, reflect.TypeOf(*s).NumField())
   169  	clone := s.Clone()
   170  	c.Require().Equal(s, clone)
   171  	clone.CertAllowedCN[0] = "g"
   172  	c.Require().NotEqual(s, clone)
   173  }