go.temporal.io/server@v1.23.0/common/rpc/encryption/tls_config_test.go (about)

     1  // The MIT License
     2  //
     3  // Copyright (c) 2020 Temporal Technologies Inc.  All rights reserved.
     4  //
     5  // Copyright (c) 2020 Uber Technologies, Inc.
     6  //
     7  // Permission is hereby granted, free of charge, to any person obtaining a copy
     8  // of this software and associated documentation files (the "Software"), to deal
     9  // in the Software without restriction, including without limitation the rights
    10  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
    11  // copies of the Software, and to permit persons to whom the Software is
    12  // furnished to do so, subject to the following conditions:
    13  //
    14  // The above copyright notice and this permission notice shall be included in
    15  // all copies or substantial portions of the Software.
    16  //
    17  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    18  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    19  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    20  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    21  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    22  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    23  // THE SOFTWARE.
    24  
    25  package encryption
    26  
    27  import (
    28  	"testing"
    29  
    30  	"github.com/stretchr/testify/require"
    31  	"github.com/stretchr/testify/suite"
    32  
    33  	"go.temporal.io/server/common/config"
    34  )
    35  
    36  type (
    37  	tlsConfigTest struct {
    38  		suite.Suite
    39  		*require.Assertions
    40  	}
    41  )
    42  
    43  func TestTLSConfigSuite(t *testing.T) {
    44  	s := new(tlsConfigTest)
    45  	suite.Run(t, s)
    46  }
    47  
    48  func (s *tlsConfigTest) SetupTest() {
    49  	s.Assertions = require.New(s.T())
    50  }
    51  
    52  func (s *tlsConfigTest) TestIsEnabled() {
    53  
    54  	emptyCfg := config.GroupTLS{}
    55  	s.False(emptyCfg.IsServerEnabled())
    56  	s.False(emptyCfg.IsClientEnabled())
    57  	cfg := config.GroupTLS{Server: config.ServerTLS{KeyFile: "foo"}}
    58  	s.True(cfg.IsServerEnabled())
    59  	s.False(cfg.IsClientEnabled())
    60  	cfg = config.GroupTLS{Server: config.ServerTLS{KeyData: "foo"}}
    61  	s.True(cfg.IsServerEnabled())
    62  	s.False(cfg.IsClientEnabled())
    63  	cfg = config.GroupTLS{Client: config.ClientTLS{RootCAFiles: []string{"bar"}}}
    64  	s.False(cfg.IsServerEnabled())
    65  	s.True(cfg.IsClientEnabled())
    66  	cfg = config.GroupTLS{Client: config.ClientTLS{RootCAData: []string{"bar"}}}
    67  	s.False(cfg.IsServerEnabled())
    68  	s.True(cfg.IsClientEnabled())
    69  	cfg = config.GroupTLS{Client: config.ClientTLS{ForceTLS: true}}
    70  	s.False(cfg.IsServerEnabled())
    71  	s.True(cfg.IsClientEnabled())
    72  	cfg = config.GroupTLS{Client: config.ClientTLS{ForceTLS: false}}
    73  	s.False(cfg.IsServerEnabled())
    74  	s.False(cfg.IsClientEnabled())
    75  
    76  }
    77  
    78  func (s *tlsConfigTest) TestIsSystemWorker() {
    79  
    80  	cfg := &config.RootTLS{}
    81  	s.False(isSystemWorker(cfg))
    82  	cfg = &config.RootTLS{SystemWorker: config.WorkerTLS{CertFile: "foo"}}
    83  	s.True(isSystemWorker(cfg))
    84  	cfg = &config.RootTLS{SystemWorker: config.WorkerTLS{CertData: "foo"}}
    85  	s.True(isSystemWorker(cfg))
    86  	cfg = &config.RootTLS{SystemWorker: config.WorkerTLS{Client: config.ClientTLS{RootCAData: []string{"bar"}}}}
    87  	s.True(isSystemWorker(cfg))
    88  	cfg = &config.RootTLS{SystemWorker: config.WorkerTLS{Client: config.ClientTLS{RootCAFiles: []string{"bar"}}}}
    89  	s.True(isSystemWorker(cfg))
    90  	cfg = &config.RootTLS{SystemWorker: config.WorkerTLS{Client: config.ClientTLS{ForceTLS: true}}}
    91  	s.True(isSystemWorker(cfg))
    92  	cfg = &config.RootTLS{SystemWorker: config.WorkerTLS{Client: config.ClientTLS{ForceTLS: false}}}
    93  	s.False(isSystemWorker(cfg))
    94  }
    95  
    96  func (s *tlsConfigTest) TestCertFileAndData() {
    97  	s.testGroupTLS(s.testCertFileAndData)
    98  }
    99  
   100  func (s *tlsConfigTest) TestKeyFileAndData() {
   101  	s.testGroupTLS(s.testKeyFileAndData)
   102  }
   103  
   104  func (s *tlsConfigTest) TestClientCAData() {
   105  	s.testGroupTLS(s.testClientCAData)
   106  }
   107  
   108  func (s *tlsConfigTest) TestClientCAFiles() {
   109  	s.testGroupTLS(s.testClientCAFiles)
   110  }
   111  
   112  func (s *tlsConfigTest) TestRootCAData() {
   113  	s.testGroupTLS(s.testRootCAData)
   114  }
   115  
   116  func (s *tlsConfigTest) TestRootCAFiles() {
   117  	s.testGroupTLS(s.testRootCAFiles)
   118  }
   119  
   120  func (s *tlsConfigTest) testGroupTLS(f func(*config.RootTLS, *config.GroupTLS)) {
   121  
   122  	cfg := &config.RootTLS{Internode: config.GroupTLS{}}
   123  	f(cfg, &cfg.Internode)
   124  	cfg = &config.RootTLS{Frontend: config.GroupTLS{}}
   125  	f(cfg, &cfg.Frontend)
   126  }
   127  
   128  func (s *tlsConfigTest) testCertFileAndData(cfg *config.RootTLS, group *config.GroupTLS) {
   129  
   130  	group.Server = config.ServerTLS{}
   131  	s.Nil(validateRootTLS(cfg))
   132  	group.Server = config.ServerTLS{CertFile: "foo"}
   133  	s.Nil(validateRootTLS(cfg))
   134  	group.Server = config.ServerTLS{CertData: "bar"}
   135  	s.Nil(validateRootTLS(cfg))
   136  	group.Server = config.ServerTLS{CertFile: "foo", CertData: "bar"}
   137  	s.Error(validateRootTLS(cfg))
   138  }
   139  
   140  func (s *tlsConfigTest) testKeyFileAndData(cfg *config.RootTLS, group *config.GroupTLS) {
   141  
   142  	group.Server = config.ServerTLS{}
   143  	s.Nil(validateRootTLS(cfg))
   144  	group.Server = config.ServerTLS{KeyFile: "foo"}
   145  	s.Nil(validateRootTLS(cfg))
   146  	group.Server = config.ServerTLS{KeyData: "bar"}
   147  	s.Nil(validateRootTLS(cfg))
   148  	group.Server = config.ServerTLS{KeyFile: "foo", KeyData: "bar"}
   149  	s.Error(validateRootTLS(cfg))
   150  }
   151  
   152  func (s *tlsConfigTest) testClientCAData(cfg *config.RootTLS, group *config.GroupTLS) {
   153  
   154  	group.Server = config.ServerTLS{}
   155  	s.Nil(validateRootTLS(cfg))
   156  	group.Server = config.ServerTLS{ClientCAData: []string{}}
   157  	s.Nil(validateRootTLS(cfg))
   158  	group.Server = config.ServerTLS{ClientCAData: []string{"foo"}}
   159  	s.Nil(validateRootTLS(cfg))
   160  	group.Server = config.ServerTLS{ClientCAData: []string{"foo", "bar"}}
   161  	s.Nil(validateRootTLS(cfg))
   162  	group.Server = config.ServerTLS{ClientCAData: []string{"foo", " "}}
   163  	s.Error(validateRootTLS(cfg))
   164  	group.Server = config.ServerTLS{ClientCAData: []string{""}}
   165  	s.Error(validateRootTLS(cfg))
   166  }
   167  
   168  func (s *tlsConfigTest) testClientCAFiles(cfg *config.RootTLS, group *config.GroupTLS) {
   169  
   170  	group.Server = config.ServerTLS{}
   171  	s.Nil(validateRootTLS(cfg))
   172  	group.Server = config.ServerTLS{ClientCAFiles: []string{}}
   173  	s.Nil(validateRootTLS(cfg))
   174  	group.Server = config.ServerTLS{ClientCAFiles: []string{"foo"}}
   175  	s.Nil(validateRootTLS(cfg))
   176  	group.Server = config.ServerTLS{ClientCAFiles: []string{"foo", "bar"}}
   177  	s.Nil(validateRootTLS(cfg))
   178  	group.Server = config.ServerTLS{ClientCAFiles: []string{"foo", " "}}
   179  	s.Error(validateRootTLS(cfg))
   180  	group.Server = config.ServerTLS{ClientCAFiles: []string{""}}
   181  	s.Error(validateRootTLS(cfg))
   182  }
   183  
   184  func (s *tlsConfigTest) testRootCAData(cfg *config.RootTLS, group *config.GroupTLS) {
   185  
   186  	group.Client = config.ClientTLS{}
   187  	s.Nil(validateRootTLS(cfg))
   188  	group.Client = config.ClientTLS{RootCAData: []string{}}
   189  	s.Nil(validateRootTLS(cfg))
   190  	group.Client = config.ClientTLS{RootCAData: []string{"foo"}}
   191  	s.Nil(validateRootTLS(cfg))
   192  	group.Client = config.ClientTLS{RootCAData: []string{"foo", "bar"}}
   193  	s.Nil(validateRootTLS(cfg))
   194  	group.Client = config.ClientTLS{RootCAData: []string{"foo", " "}}
   195  	s.Error(validateRootTLS(cfg))
   196  	group.Client = config.ClientTLS{RootCAData: []string{""}}
   197  	s.Error(validateRootTLS(cfg))
   198  }
   199  
   200  func (s *tlsConfigTest) testRootCAFiles(cfg *config.RootTLS, group *config.GroupTLS) {
   201  
   202  	group.Client = config.ClientTLS{}
   203  	s.Nil(validateRootTLS(cfg))
   204  	group.Client = config.ClientTLS{RootCAFiles: []string{}}
   205  	s.Nil(validateRootTLS(cfg))
   206  	group.Client = config.ClientTLS{RootCAFiles: []string{"foo"}}
   207  	s.Nil(validateRootTLS(cfg))
   208  	group.Client = config.ClientTLS{RootCAFiles: []string{"foo", "bar"}}
   209  	s.Nil(validateRootTLS(cfg))
   210  	group.Client = config.ClientTLS{RootCAFiles: []string{"foo", " "}}
   211  	s.Error(validateRootTLS(cfg))
   212  	group.Client = config.ClientTLS{RootCAFiles: []string{""}}
   213  	s.Error(validateRootTLS(cfg))
   214  }
   215  
   216  func (s *tlsConfigTest) TestSystemWorkerTLSConfig() {
   217  	cfg := &config.RootTLS{}
   218  	cfg.SystemWorker = config.WorkerTLS{}
   219  	s.Nil(validateRootTLS(cfg))
   220  	cfg.SystemWorker = config.WorkerTLS{CertFile: "foo"}
   221  	s.Nil(validateRootTLS(cfg))
   222  	cfg.SystemWorker = config.WorkerTLS{CertData: "bar"}
   223  	s.Nil(validateRootTLS(cfg))
   224  	cfg.SystemWorker = config.WorkerTLS{CertFile: "foo", CertData: "bar"}
   225  	s.Error(validateRootTLS(cfg))
   226  	cfg.SystemWorker = config.WorkerTLS{KeyFile: "foo"}
   227  	s.Nil(validateRootTLS(cfg))
   228  	cfg.SystemWorker = config.WorkerTLS{KeyData: "bar"}
   229  	s.Nil(validateRootTLS(cfg))
   230  	cfg.SystemWorker = config.WorkerTLS{KeyFile: "foo", KeyData: "bar"}
   231  	s.Error(validateRootTLS(cfg))
   232  
   233  	cfg.SystemWorker = config.WorkerTLS{Client: config.ClientTLS{}}
   234  	client := &cfg.SystemWorker.Client
   235  	client.RootCAData = []string{}
   236  	s.Nil(validateRootTLS(cfg))
   237  	client.RootCAData = []string{"foo"}
   238  	s.Nil(validateRootTLS(cfg))
   239  	client.RootCAData = []string{"foo", "bar"}
   240  	s.Nil(validateRootTLS(cfg))
   241  	client.RootCAData = []string{"foo", " "}
   242  	s.Error(validateRootTLS(cfg))
   243  	client.RootCAData = []string{""}
   244  	s.Error(validateRootTLS(cfg))
   245  }