github.com/matrixorigin/matrixone@v1.2.0/pkg/proxy/handler_test.go (about) 1 // Copyright 2021 - 2023 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package proxy 16 17 import ( 18 "context" 19 "crypto/rand" 20 "crypto/rsa" 21 "crypto/tls" 22 "crypto/x509" 23 "crypto/x509/pkix" 24 "database/sql" 25 "encoding/pem" 26 "fmt" 27 "math/big" 28 "os" 29 "testing" 30 "time" 31 32 "github.com/go-sql-driver/mysql" 33 "github.com/lni/goutils/leaktest" 34 "github.com/matrixorigin/matrixone/pkg/clusterservice" 35 "github.com/matrixorigin/matrixone/pkg/common/log" 36 "github.com/matrixorigin/matrixone/pkg/common/runtime" 37 "github.com/matrixorigin/matrixone/pkg/common/stopper" 38 "github.com/matrixorigin/matrixone/pkg/pb/metadata" 39 "github.com/stretchr/testify/require" 40 ) 41 42 type testProxyHandler struct { 43 ctx context.Context 44 st *stopper.Stopper 45 logger *log.MOLogger 46 hc *mockHAKeeperClient 47 mc clusterservice.MOCluster 48 re *rebalancer 49 ru Router 50 closeFn func() 51 counterSet *counterSet 52 } 53 54 func newTestProxyHandler(t *testing.T) *testProxyHandler { 55 rt := runtime.DefaultRuntime() 56 runtime.SetupProcessLevelRuntime(rt) 57 ctx, cancel := context.WithCancel(context.TODO()) 58 hc := &mockHAKeeperClient{} 59 mc := clusterservice.NewMOCluster(hc, 3*time.Second) 60 rt.SetGlobalVariables(runtime.ClusterService, mc) 61 logger := rt.Logger() 62 st := stopper.NewStopper("test-proxy", stopper.WithLogger(rt.Logger().RawLogger())) 63 re := testRebalancer(t, st, logger, mc) 64 return &testProxyHandler{ 65 ctx: ctx, 66 st: st, 67 logger: logger, 68 hc: hc, 69 mc: mc, 70 re: re, 71 ru: newRouter(mc, re, false), 72 closeFn: func() { 73 mc.Close() 74 st.Stop() 75 cancel() 76 }, 77 counterSet: newCounterSet(), 78 } 79 } 80 81 func certGen(basePath string) (*tlsConfig, error) { 82 max := new(big.Int).Lsh(big.NewInt(1), 128) 83 serialNumber, _ := rand.Int(rand.Reader, max) 84 subject := pkix.Name{ 85 Country: []string{"CN"}, 86 Province: []string{"SH"}, 87 Organization: []string{"MO"}, 88 OrganizationalUnit: []string{"Dev"}, 89 } 90 91 // set up CA certificate 92 ca := &x509.Certificate{ 93 SerialNumber: serialNumber, 94 Subject: subject, 95 NotBefore: time.Now(), 96 NotAfter: time.Now().Add(365 * 24 * time.Hour), 97 IsCA: true, 98 ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, 99 } 100 101 // create our private and public key 102 caPrivKey, err := rsa.GenerateKey(rand.Reader, 2048) 103 if err != nil { 104 return nil, err 105 } 106 107 // create the CA 108 caBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, &caPrivKey.PublicKey, caPrivKey) 109 if err != nil { 110 return nil, err 111 } 112 113 // pem encode 114 caFile := basePath + "/ca.pem" 115 caOut, _ := os.Create(caFile) 116 if err := pem.Encode(caOut, &pem.Block{ 117 Type: "CERTIFICATE", 118 Bytes: caBytes, 119 }); err != nil { 120 return nil, err 121 } 122 defer func() { 123 _ = caOut.Close() 124 }() 125 126 // set up server certificate 127 cert := &x509.Certificate{ 128 SerialNumber: serialNumber, 129 Subject: subject, 130 NotBefore: time.Now(), 131 NotAfter: time.Now().Add(365 * 24 * time.Hour), 132 ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, 133 } 134 135 certPrivKey, err := rsa.GenerateKey(rand.Reader, 2048) 136 if err != nil { 137 return nil, err 138 } 139 140 certBytes, err := x509.CreateCertificate(rand.Reader, cert, ca, &certPrivKey.PublicKey, caPrivKey) 141 if err != nil { 142 return nil, err 143 } 144 145 certFile := basePath + "/server-cert.pem" 146 certOut, _ := os.Create(certFile) 147 if err := pem.Encode(certOut, &pem.Block{ 148 Type: "CERTIFICATE", 149 Bytes: certBytes, 150 }); err != nil { 151 return nil, err 152 } 153 defer func() { 154 _ = certOut.Close() 155 }() 156 157 keyFile := basePath + "/server-key.pem" 158 keyOut, _ := os.Create(keyFile) 159 if err := pem.Encode(keyOut, &pem.Block{ 160 Type: "RSA PRIVATE KEY", 161 Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey), 162 }); err != nil { 163 return nil, err 164 } 165 defer func() { 166 _ = keyOut.Close() 167 }() 168 169 return &tlsConfig{ 170 caFile: caFile, 171 certFile: certFile, 172 keyFile: keyFile, 173 }, nil 174 } 175 176 func TestHandler_Handle(t *testing.T) { 177 defer leaktest.AfterTest(t)() 178 179 temp := os.TempDir() 180 ctx, cancel := context.WithCancel(context.Background()) 181 defer cancel() 182 rt := runtime.DefaultRuntime() 183 runtime.SetupProcessLevelRuntime(rt) 184 listenAddr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) 185 require.NoError(t, os.RemoveAll(listenAddr)) 186 cfg := Config{ 187 ListenAddress: "unix://" + listenAddr, 188 RebalanceDisabled: true, 189 } 190 hc := &mockHAKeeperClient{} 191 mc := clusterservice.NewMOCluster(hc, 3*time.Second) 192 defer mc.Close() 193 rt.SetGlobalVariables(runtime.ClusterService, mc) 194 addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) 195 require.NoError(t, os.RemoveAll(addr)) 196 cn1 := testMakeCNServer("cn11", addr, 0, "", labelInfo{}) 197 hc.updateCN(cn1.uuid, cn1.addr, map[string]metadata.LabelList{}) 198 // start backend server. 199 stopFn := startTestCNServer(t, ctx, addr, nil) 200 defer func() { 201 require.NoError(t, stopFn()) 202 }() 203 mc.ForceRefresh(true) 204 205 // start proxy. 206 s, err := NewServer(ctx, cfg, WithRuntime(runtime.DefaultRuntime()), 207 WithHAKeeperClient(hc)) 208 defer func() { 209 err := s.Close() 210 require.NoError(t, err) 211 }() 212 require.NoError(t, err) 213 require.NotNil(t, s) 214 err = s.Start() 215 require.NoError(t, err) 216 217 db, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", listenAddr)) 218 // connect to server. 219 require.NoError(t, err) 220 require.NotNil(t, db) 221 defer func() { 222 _ = db.Close() 223 timeout := time.NewTimer(time.Second * 15) 224 tick := time.NewTicker(time.Millisecond * 100) 225 var connTotal int64 226 tt := false 227 for { 228 select { 229 case <-tick.C: 230 connTotal = s.counterSet.connTotal.Load() 231 case <-timeout.C: 232 tt = true 233 } 234 if connTotal == 0 || tt { 235 break 236 } 237 } 238 tick.Stop() 239 timeout.Stop() 240 require.Equal(t, int64(0), connTotal) 241 }() 242 _, err = db.Exec("anystmt") 243 require.NoError(t, err) 244 245 require.Equal(t, int64(1), s.counterSet.connAccepted.Load()) 246 require.Equal(t, int64(1), s.counterSet.connTotal.Load()) 247 } 248 249 func TestHandler_HandleErr(t *testing.T) { 250 defer leaktest.AfterTest(t)() 251 252 temp := os.TempDir() 253 ctx, cancel := context.WithCancel(context.Background()) 254 defer cancel() 255 rt := runtime.DefaultRuntime() 256 runtime.SetupProcessLevelRuntime(rt) 257 listenAddr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) 258 require.NoError(t, os.RemoveAll(listenAddr)) 259 cfg := Config{ 260 ListenAddress: "unix://" + listenAddr, 261 RebalanceDisabled: true, 262 } 263 hc := &mockHAKeeperClient{} 264 mc := clusterservice.NewMOCluster(hc, 3*time.Second) 265 defer mc.Close() 266 rt.SetGlobalVariables(runtime.ClusterService, mc) 267 addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) 268 require.NoError(t, os.RemoveAll(addr)) 269 270 // start proxy. 271 s, err := NewServer(ctx, cfg, WithRuntime(runtime.DefaultRuntime()), 272 WithHAKeeperClient(hc)) 273 defer func() { 274 err := s.Close() 275 require.NoError(t, err) 276 }() 277 require.NoError(t, err) 278 require.NotNil(t, s) 279 err = s.Start() 280 require.NoError(t, err) 281 282 db, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", listenAddr)) 283 // connect to server. 284 require.NoError(t, err) 285 require.NotNil(t, db) 286 defer func() { 287 _ = db.Close() 288 timeout := time.NewTimer(time.Second * 15) 289 tick := time.NewTicker(time.Millisecond * 100) 290 var connTotal int64 291 tt := false 292 for { 293 select { 294 case <-tick.C: 295 connTotal = s.counterSet.connTotal.Load() 296 case <-timeout.C: 297 tt = true 298 } 299 if connTotal == 0 || tt { 300 break 301 } 302 } 303 tick.Stop() 304 timeout.Stop() 305 require.Equal(t, int64(0), connTotal) 306 }() 307 _, err = db.Exec("anystmt") 308 require.Error(t, err) 309 310 require.Equal(t, int64(1), s.counterSet.connAccepted.Load()) 311 } 312 313 func TestHandler_HandleWithSSL(t *testing.T) { 314 defer leaktest.AfterTest(t)() 315 316 temp := os.TempDir() 317 ctx, cancel := context.WithCancel(context.Background()) 318 defer cancel() 319 rt := runtime.DefaultRuntime() 320 runtime.SetupProcessLevelRuntime(rt) 321 listenAddr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) 322 require.NoError(t, os.RemoveAll(listenAddr)) 323 cfg := Config{ 324 ListenAddress: "unix://" + listenAddr, 325 RebalanceDisabled: true, 326 } 327 hc := &mockHAKeeperClient{} 328 mc := clusterservice.NewMOCluster(hc, 3*time.Second) 329 defer mc.Close() 330 rt.SetGlobalVariables(runtime.ClusterService, mc) 331 addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) 332 require.NoError(t, os.RemoveAll(addr)) 333 cn1 := testMakeCNServer("cn11", addr, 0, "", labelInfo{}) 334 hc.updateCN(cn1.uuid, cn1.addr, map[string]metadata.LabelList{}) 335 336 tlsC, err := certGen(temp) 337 require.NoError(t, err) 338 tlsC.enabled = true 339 // start backend server. 340 stopFn := startTestCNServer(t, ctx, addr, tlsC) 341 defer func() { 342 require.NoError(t, stopFn()) 343 }() 344 mc.ForceRefresh(true) 345 346 // start proxy. 347 s, err := NewServer(ctx, cfg, WithRuntime(runtime.DefaultRuntime()), 348 WithHAKeeperClient(hc), 349 WithTLSEnabled(), 350 WithTLSCAFile(tlsC.caFile), 351 WithTLSCertFile(tlsC.certFile), 352 WithTLSKeyFile(tlsC.keyFile)) 353 defer func() { 354 err := s.Close() 355 require.NoError(t, err) 356 }() 357 require.NoError(t, err) 358 require.NotNil(t, s) 359 err = s.Start() 360 require.NoError(t, err) 361 362 rootCertPool := x509.NewCertPool() 363 364 pem1, err := os.ReadFile(tlsC.caFile) 365 require.NoError(t, err) 366 367 ok := rootCertPool.AppendCertsFromPEM(pem1) 368 require.True(t, ok) 369 370 err = mysql.RegisterTLSConfig("custom", &tls.Config{ 371 RootCAs: rootCertPool, 372 InsecureSkipVerify: true, 373 }) 374 require.NoError(t, err) 375 376 db, err := sql.Open("mysql", 377 fmt.Sprintf("dump:111@unix(%s)/db1?tls=custom", listenAddr)) 378 // connect to server. 379 require.NoError(t, err) 380 require.NotNil(t, db) 381 defer func() { 382 _ = db.Close() 383 }() 384 _, _ = db.Exec("any stmt") 385 _, err = db.Exec("any stmt") 386 require.NoError(t, err) 387 require.Equal(t, int64(1), s.counterSet.connAccepted.Load()) 388 require.Equal(t, int64(1), s.counterSet.connTotal.Load()) 389 } 390 391 func testWithServer(t *testing.T, fn func(*testing.T, string, *Server)) { 392 defer leaktest.AfterTest(t)() 393 394 temp := os.TempDir() 395 ctx, cancel := context.WithCancel(context.Background()) 396 defer cancel() 397 rt := runtime.DefaultRuntime() 398 runtime.SetupProcessLevelRuntime(rt) 399 listenAddr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) 400 require.NoError(t, os.RemoveAll(listenAddr)) 401 cfg := Config{ 402 ListenAddress: "unix://" + listenAddr, 403 RebalanceDisabled: true, 404 } 405 hc := &mockHAKeeperClient{} 406 mc := clusterservice.NewMOCluster(hc, 3*time.Second) 407 defer mc.Close() 408 rt.SetGlobalVariables(runtime.ClusterService, mc) 409 addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) 410 require.NoError(t, os.RemoveAll(addr)) 411 cn1 := testMakeCNServer("cn11", addr, 0, "", labelInfo{}) 412 hc.updateCN(cn1.uuid, cn1.addr, map[string]metadata.LabelList{}) 413 // start backend server. 414 stopFn := startTestCNServer(t, ctx, addr, nil) 415 defer func() { 416 require.NoError(t, stopFn()) 417 }() 418 mc.ForceRefresh(true) 419 420 // start proxy. 421 s, err := NewServer(ctx, cfg, WithRuntime(runtime.DefaultRuntime()), 422 WithHAKeeperClient(hc)) 423 defer func() { 424 err := s.Close() 425 require.NoError(t, err) 426 }() 427 require.NoError(t, err) 428 require.NotNil(t, s) 429 err = s.Start() 430 require.NoError(t, err) 431 432 fn(t, listenAddr, s) 433 } 434 435 func TestHandler_HandleEventKillQuery(t *testing.T) { 436 testWithServer(t, func(t *testing.T, addr string, s *Server) { 437 db1, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", addr)) 438 // connect to server. 439 require.NoError(t, err) 440 require.NotNil(t, db1) 441 defer func() { 442 _ = db1.Close() 443 }() 444 res, err := db1.Exec("select 1") 445 require.NoError(t, err) 446 connID, _ := res.LastInsertId() // fake connection id 447 448 db2, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", addr)) 449 // connect to server. 450 require.NoError(t, err) 451 require.NotNil(t, db2) 452 defer func() { 453 _ = db2.Close() 454 }() 455 456 _, err = db2.Exec(fmt.Sprintf("kill query %d", connID)) 457 require.NoError(t, err) 458 459 require.Equal(t, int64(2), s.counterSet.connAccepted.Load()) 460 }) 461 } 462 463 func TestHandler_HandleEventSetVar(t *testing.T) { 464 testWithServer(t, func(t *testing.T, addr string, s *Server) { 465 db1, err := sql.Open("mysql", fmt.Sprintf("dump:111@unix(%s)/db1", addr)) 466 // connect to server. 467 require.NoError(t, err) 468 require.NotNil(t, db1) 469 defer func() { 470 _ = db1.Close() 471 }() 472 _, err = db1.Exec("set session cn_label='acc1'") 473 require.NoError(t, err) 474 475 res, err := db1.Query("show session variables") 476 require.NoError(t, err) 477 defer res.Close() 478 var varName, varValue string 479 for res.Next() { 480 err := res.Scan(&varName, &varValue) 481 require.NoError(t, err) 482 require.Equal(t, "cn_label", varName) 483 require.Equal(t, "acc1", varValue) 484 } 485 err = res.Err() 486 require.NoError(t, err) 487 488 require.Equal(t, int64(1), s.counterSet.connAccepted.Load()) 489 }) 490 } 491 492 func TestHandler_HandleTxn(t *testing.T) { 493 testWithServer(t, func(t *testing.T, addr string, s *Server) { 494 db1, err := sql.Open("mysql", fmt.Sprintf("a1#root:111@unix(%s)/db1", addr)) 495 // connect to server. 496 require.NoError(t, err) 497 require.NotNil(t, db1) 498 defer func() { 499 _ = db1.Close() 500 }() 501 _, err = db1.Exec("select 1") 502 require.NoError(t, err) 503 }) 504 }