github.com/matrixorigin/matrixone@v1.2.0/pkg/proxy/server_conn_test.go (about) 1 // Copyright 2021 - 2023 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package proxy 16 17 import ( 18 "bufio" 19 "context" 20 "crypto/tls" 21 "fmt" 22 "net" 23 "os" 24 "strings" 25 "sync" 26 "sync/atomic" 27 "testing" 28 "time" 29 30 "github.com/fagongzi/goetty/v2" 31 "github.com/lni/goutils/leaktest" 32 "github.com/stretchr/testify/require" 33 34 "github.com/matrixorigin/matrixone/pkg/config" 35 "github.com/matrixorigin/matrixone/pkg/container/types" 36 "github.com/matrixorigin/matrixone/pkg/frontend" 37 "github.com/matrixorigin/matrixone/pkg/pb/proxy" 38 "github.com/matrixorigin/matrixone/pkg/sql/plan" 39 ) 40 41 var testSlat = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0} 42 var testPacket = &frontend.Packet{ 43 Length: 1, 44 SequenceID: 0, 45 Payload: []byte{1}, 46 } 47 48 func testMakeCNServer( 49 uuid string, addr string, connID uint32, hash LabelHash, reqLabel labelInfo, 50 ) *CNServer { 51 if strings.Contains(addr, "sock") { 52 addr = "unix://" + addr 53 } 54 return &CNServer{ 55 connID: connID, 56 addr: addr, 57 uuid: uuid, 58 salt: testSlat, 59 hash: hash, 60 reqLabel: reqLabel, 61 } 62 } 63 64 type mockServerConn struct { 65 conn net.Conn 66 } 67 68 var _ ServerConn = (*mockServerConn)(nil) 69 70 func newMockServerConn(conn net.Conn) *mockServerConn { 71 m := &mockServerConn{ 72 conn: conn, 73 } 74 return m 75 } 76 77 func (s *mockServerConn) ConnID() uint32 { return 0 } 78 func (s *mockServerConn) RawConn() net.Conn { return s.conn } 79 func (s *mockServerConn) HandleHandshake(_ *frontend.Packet, _ time.Duration) (*frontend.Packet, error) { 80 return nil, nil 81 } 82 func (s *mockServerConn) ExecStmt(stmt internalStmt, resp chan<- []byte) (bool, error) { 83 sendResp(makeOKPacket(8), resp) 84 return true, nil 85 } 86 func (s *mockServerConn) Close() error { 87 if s.conn != nil { 88 _ = s.conn.Close() 89 } 90 return nil 91 } 92 93 var baseConnID atomic.Uint32 94 95 type tlsConfig struct { 96 enabled bool 97 caFile string 98 certFile string 99 keyFile string 100 } 101 102 type testCNServer struct { 103 sync.Mutex 104 ctx context.Context 105 scheme string 106 addr string 107 listener net.Listener 108 started bool 109 quit chan interface{} 110 111 globalVars map[string]string 112 tlsCfg tlsConfig 113 tlsConfig *tls.Config 114 115 beforeHandle func() 116 } 117 118 type testHandler struct { 119 mysqlProto *frontend.MysqlProtocolImpl 120 connID uint32 121 conn goetty.IOSession 122 sessionVars map[string]string 123 labels map[string]string 124 server *testCNServer 125 status uint16 126 } 127 128 type option func(s *testCNServer) 129 130 func withBeforeHandle(f func()) option { 131 return func(s *testCNServer) { 132 s.beforeHandle = f 133 } 134 } 135 136 func startTestCNServer(t *testing.T, ctx context.Context, addr string, cfg *tlsConfig, opts ...option) func() error { 137 b := &testCNServer{ 138 ctx: ctx, 139 scheme: "tcp", 140 addr: addr, 141 quit: make(chan interface{}), 142 globalVars: make(map[string]string), 143 } 144 for _, opt := range opts { 145 opt(b) 146 } 147 if cfg != nil { 148 b.tlsCfg = *cfg 149 } 150 if strings.Contains(addr, "sock") { 151 b.scheme = "unix" 152 } 153 go func() { 154 err := b.Start() 155 require.NoError(t, err) 156 }() 157 require.True(t, b.waitCNServerReady()) 158 return func() error { 159 return b.Stop() 160 } 161 } 162 163 func (s *testCNServer) waitCNServerReady() bool { 164 ctx, cancel := context.WithTimeout(s.ctx, time.Second*3) 165 defer cancel() 166 tick := time.NewTicker(time.Millisecond * 100) 167 for { 168 select { 169 case <-ctx.Done(): 170 return false 171 case <-tick.C: 172 s.Lock() 173 started := s.started 174 s.Unlock() 175 conn, err := net.Dial(s.scheme, s.addr) 176 if err == nil && started { 177 _ = conn.Close() 178 return true 179 } 180 if conn != nil { 181 _ = conn.Close() 182 } 183 } 184 } 185 } 186 187 func (s *testCNServer) Start() error { 188 var err error 189 if s.tlsCfg.enabled { 190 s.tlsConfig, err = frontend.ConstructTLSConfig( 191 context.TODO(), 192 s.tlsCfg.caFile, 193 s.tlsCfg.certFile, 194 s.tlsCfg.keyFile, 195 ) 196 if err != nil { 197 return err 198 } 199 } 200 s.listener, err = net.Listen(s.scheme, s.addr) 201 if err != nil { 202 return err 203 } 204 s.Lock() 205 s.started = true 206 s.Unlock() 207 208 for { 209 select { 210 case <-s.ctx.Done(): 211 return nil 212 default: 213 conn, err := s.listener.Accept() 214 if conn == nil { 215 continue 216 } 217 if err != nil { 218 select { 219 case <-s.quit: 220 return nil 221 default: 222 return err 223 } 224 } else { 225 fp := config.FrontendParameters{ 226 EnableTls: s.tlsCfg.enabled, 227 } 228 fp.SetDefaultValues() 229 cid := baseConnID.Add(1) 230 c := goetty.NewIOSession(goetty.WithSessionCodec(frontend.NewSqlCodec()), 231 goetty.WithSessionConn(uint64(cid), conn)) 232 h := &testHandler{ 233 connID: cid, 234 conn: c, 235 mysqlProto: frontend.NewMysqlClientProtocol( 236 cid, c, 0, &fp), 237 sessionVars: make(map[string]string), 238 labels: make(map[string]string), 239 server: s, 240 } 241 if s.beforeHandle != nil { 242 s.beforeHandle() 243 } 244 go func(h *testHandler) { 245 testHandle(h) 246 }(h) 247 } 248 } 249 } 250 } 251 252 func testHandle(h *testHandler) { 253 // read extra info from proxy. 254 extraInfo := proxy.ExtraInfo{} 255 reader := bufio.NewReader(h.conn.RawConn()) 256 _ = extraInfo.Decode(reader) 257 // server writes init handshake. 258 _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeHandshakePayload()) 259 // server reads auth information from client. 260 _, _ = h.conn.Read(goetty.ReadOptions{}) 261 // server writes ok packet. 262 _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeOKPayload(0, uint64(h.connID), 0, 0, "")) 263 for { 264 msg, err := h.conn.Read(goetty.ReadOptions{}) 265 if err != nil { 266 break 267 } 268 packet, ok := msg.(*frontend.Packet) 269 if !ok { 270 return 271 } 272 if packet.Length > 1 && packet.Payload[0] == 3 { 273 if strings.HasPrefix(string(packet.Payload[1:]), "set session") { 274 h.handleSetVar(packet) 275 } else if string(packet.Payload[1:]) == "show session variables" { 276 h.handleShowVar() 277 } else if string(packet.Payload[1:]) == "show global variables" { 278 h.handleShowGlobalVar() 279 } else if string(packet.Payload[1:]) == "begin" { 280 h.handleStartTxn() 281 } else if string(packet.Payload[1:]) == "commit" || string(packet.Payload[1:]) == "rollback" { 282 h.handleStopTxn() 283 } else if strings.HasPrefix(string(packet.Payload[1:]), "kill connection") { 284 h.handleKillConn() 285 } else { 286 h.handleCommon() 287 } 288 } else { 289 h.handleCommon() 290 } 291 } 292 } 293 294 func (h *testHandler) handleCommon() { 295 h.mysqlProto.SetSequenceID(1) 296 // set last insert id as connection id to do test more easily. 297 _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeOKPayload(0, uint64(h.connID), h.status, 0, "")) 298 } 299 300 func (h *testHandler) handleSetVar(packet *frontend.Packet) { 301 words := strings.Split(string(packet.Payload[1:]), " ") 302 v := strings.Split(words[2], "=") 303 h.sessionVars[v[0]] = strings.Trim(v[1], "'") 304 h.mysqlProto.SetSequenceID(1) 305 _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeOKPayload(0, uint64(h.connID), h.status, 0, "")) 306 } 307 308 func (h *testHandler) handleKillConn() { 309 h.server.globalVars["killed"] = "yes" 310 h.mysqlProto.SetSequenceID(1) 311 _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeOKPayload(0, uint64(h.connID), h.status, 0, "")) 312 } 313 314 func (h *testHandler) handleShowVar() { 315 h.mysqlProto.SetSequenceID(1) 316 err := h.mysqlProto.SendColumnCountPacket(2) 317 if err != nil { 318 _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeErrPayload(0, "", err.Error())) 319 return 320 } 321 cols := []*plan.ColDef{ 322 {Typ: plan.Type{Id: int32(types.T_char)}, Name: "Variable_name"}, 323 {Typ: plan.Type{Id: int32(types.T_char)}, Name: "Value"}, 324 } 325 columns := make([]interface{}, len(cols)) 326 res := &frontend.MysqlResultSet{} 327 for i, col := range cols { 328 c := new(frontend.MysqlColumn) 329 c.SetName(col.Name) 330 c.SetOrgName(col.Name) 331 c.SetTable(col.Typ.Table) 332 c.SetOrgTable(col.Typ.Table) 333 c.SetAutoIncr(col.Typ.AutoIncr) 334 c.SetSchema("") 335 c.SetDecimal(col.Typ.Scale) 336 columns[i] = c 337 res.AddColumn(c) 338 } 339 for _, c := range columns { 340 if err := h.mysqlProto.SendColumnDefinitionPacket(context.TODO(), c.(frontend.Column), 3); err != nil { 341 _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeErrPayload(0, "", err.Error())) 342 return 343 } 344 } 345 _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeEOFPayload(0, h.status)) 346 for k, v := range h.sessionVars { 347 row := make([]interface{}, 2) 348 row[0] = k 349 row[1] = v 350 res.AddRow(row) 351 } 352 ses := &frontend.Session{} 353 h.mysqlProto.SetSession(ses) 354 if err := h.mysqlProto.SendResultSetTextBatchRow(res, res.GetRowCount()); err != nil { 355 _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeErrPayload(0, "", err.Error())) 356 return 357 } 358 _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeEOFPayload(0, h.status)) 359 } 360 361 func (h *testHandler) handleShowGlobalVar() { 362 h.mysqlProto.SetSequenceID(1) 363 err := h.mysqlProto.SendColumnCountPacket(2) 364 if err != nil { 365 _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeErrPayload(0, "", err.Error())) 366 return 367 } 368 cols := []*plan.ColDef{ 369 {Typ: plan.Type{Id: int32(types.T_char)}, Name: "Variable_name"}, 370 {Typ: plan.Type{Id: int32(types.T_char)}, Name: "Value"}, 371 } 372 columns := make([]interface{}, len(cols)) 373 res := &frontend.MysqlResultSet{} 374 for i, col := range cols { 375 c := new(frontend.MysqlColumn) 376 c.SetName(col.Name) 377 c.SetOrgName(col.Name) 378 c.SetTable(col.Typ.Table) 379 c.SetOrgTable(col.Typ.Table) 380 c.SetAutoIncr(col.Typ.AutoIncr) 381 c.SetSchema("") 382 c.SetDecimal(col.Typ.Scale) 383 columns[i] = c 384 res.AddColumn(c) 385 } 386 for _, c := range columns { 387 if err := h.mysqlProto.SendColumnDefinitionPacket(context.TODO(), c.(frontend.Column), 3); err != nil { 388 _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeErrPayload(0, "", err.Error())) 389 return 390 } 391 } 392 _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeEOFPayload(0, h.status)) 393 for k, v := range h.server.globalVars { 394 row := make([]interface{}, 2) 395 row[0] = k 396 row[1] = v 397 res.AddRow(row) 398 } 399 ses := &frontend.Session{} 400 h.mysqlProto.SetSession(ses) 401 if err := h.mysqlProto.SendResultSetTextBatchRow(res, res.GetRowCount()); err != nil { 402 _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeErrPayload(0, "", err.Error())) 403 return 404 } 405 _ = h.mysqlProto.WritePacket(h.mysqlProto.MakeEOFPayload(0, h.status)) 406 } 407 408 func (h *testHandler) handleStartTxn() { 409 h.status |= frontend.SERVER_STATUS_IN_TRANS 410 h.handleCommon() 411 } 412 413 func (h *testHandler) handleStopTxn() { 414 h.status &= ^frontend.SERVER_STATUS_IN_TRANS 415 h.handleCommon() 416 } 417 418 func (s *testCNServer) Stop() error { 419 close(s.quit) 420 _ = s.listener.Close() 421 return nil 422 } 423 424 func TestServerConn_Create(t *testing.T) { 425 defer leaktest.AfterTest(t) 426 427 temp := os.TempDir() 428 addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) 429 require.NoError(t, os.RemoveAll(addr)) 430 cn1 := testMakeCNServer("cn11", addr, 0, "", labelInfo{}) 431 cn1.reqLabel = newLabelInfo("t1", map[string]string{ 432 "k1": "v1", 433 "k2": "v2", 434 }) 435 // server not started. 436 sc, err := newServerConn(cn1, nil, nil, 0) 437 require.Error(t, err) 438 require.Nil(t, sc) 439 440 // start server. 441 tp := newTestProxyHandler(t) 442 defer tp.closeFn() 443 stopFn := startTestCNServer(t, tp.ctx, addr, nil) 444 defer func() { 445 require.NoError(t, stopFn()) 446 }() 447 448 sc, err = newServerConn(cn1, nil, nil, 0) 449 require.NoError(t, err) 450 require.NotNil(t, sc) 451 } 452 453 func TestServerConn_Connect(t *testing.T) { 454 defer leaktest.AfterTest(t) 455 temp := os.TempDir() 456 addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) 457 require.NoError(t, os.RemoveAll(addr)) 458 cn1 := testMakeCNServer("cn11", addr, 0, "", labelInfo{}) 459 cn1.reqLabel = newLabelInfo("t1", map[string]string{ 460 "k1": "v1", 461 "k2": "v2", 462 }) 463 tp := newTestProxyHandler(t) 464 defer tp.closeFn() 465 stopFn := startTestCNServer(t, tp.ctx, addr, nil) 466 defer func() { 467 require.NoError(t, stopFn()) 468 }() 469 470 sc, err := newServerConn(cn1, nil, tp.re, 0) 471 require.NoError(t, err) 472 require.NotNil(t, sc) 473 _, err = sc.HandleHandshake(&frontend.Packet{Payload: []byte{1}}, time.Second*3) 474 require.NoError(t, err) 475 require.NotEqual(t, 0, int(sc.ConnID())) 476 err = sc.Close() 477 require.NoError(t, err) 478 } 479 480 func TestFakeCNServer(t *testing.T) { 481 defer leaktest.AfterTest(t) 482 483 tp := newTestProxyHandler(t) 484 defer tp.closeFn() 485 486 temp := os.TempDir() 487 addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) 488 require.NoError(t, os.RemoveAll(addr)) 489 stopFn := startTestCNServer(t, tp.ctx, addr, nil) 490 defer func() { 491 require.NoError(t, stopFn()) 492 }() 493 494 li := labelInfo{} 495 cn1 := testMakeCNServer("cn11", addr, 0, "", labelInfo{}) 496 cn1.reqLabel = newLabelInfo("t1", map[string]string{ 497 "k1": "v1", 498 "k2": "v2", 499 }) 500 501 cleanup := testStartClient(t, tp, clientInfo{labelInfo: li}, cn1) 502 defer cleanup() 503 } 504 505 func TestServerConn_ExecStmt(t *testing.T) { 506 defer leaktest.AfterTest(t) 507 508 temp := os.TempDir() 509 addr := fmt.Sprintf("%s/%d.sock", temp, time.Now().Nanosecond()) 510 require.NoError(t, os.RemoveAll(addr)) 511 cn1 := testMakeCNServer("cn11", addr, 0, "", labelInfo{}) 512 cn1.reqLabel = newLabelInfo("t1", map[string]string{ 513 "k1": "v1", 514 "k2": "v2", 515 }) 516 tp := newTestProxyHandler(t) 517 defer tp.closeFn() 518 stopFn := startTestCNServer(t, tp.ctx, addr, nil) 519 defer func() { 520 require.NoError(t, stopFn()) 521 }() 522 523 sc, err := newServerConn(cn1, nil, tp.re, 0) 524 require.NoError(t, err) 525 require.NotNil(t, sc) 526 _, err = sc.HandleHandshake(&frontend.Packet{Payload: []byte{1}}, time.Second*3) 527 require.NoError(t, err) 528 require.NotEqual(t, 0, int(sc.ConnID())) 529 resp := make(chan []byte, 10) 530 _, err = sc.ExecStmt(internalStmt{cmdType: cmdQuery, s: "kill query"}, resp) 531 require.NoError(t, err) 532 res := <-resp 533 ok := isOKPacket(res) 534 require.True(t, ok) 535 } 536 537 func TestServerConnParseConnID(t *testing.T) { 538 t.Run("too short error", func(t *testing.T) { 539 s := &serverConn{} 540 p := &frontend.Packet{ 541 Payload: []byte{10}, 542 } 543 err := s.parseConnID(p) 544 require.Error(t, err) 545 }) 546 547 t.Run("no string", func(t *testing.T) { 548 s := &serverConn{} 549 p := &frontend.Packet{ 550 Length: 8, 551 Payload: []byte{10}, 552 } 553 p.Payload = append(p.Payload, []byte("v1")...) 554 err := s.parseConnID(p) 555 require.Error(t, err) 556 }) 557 558 t.Run("no conn id", func(t *testing.T) { 559 s := &serverConn{} 560 p := &frontend.Packet{ 561 Length: 5, 562 Payload: []byte{10}, 563 } 564 p.Payload = append(p.Payload, []byte("v1")...) 565 p.Payload = append(p.Payload, []byte{0}...) 566 p.Payload = append(p.Payload, []byte{2, 0, 0, 0}...) 567 err := s.parseConnID(p) 568 require.Error(t, err) 569 }) 570 571 t.Run("no error", func(t *testing.T) { 572 s := &serverConn{} 573 p := &frontend.Packet{ 574 Length: 8, 575 Payload: []byte{10}, 576 } 577 p.Payload = append(p.Payload, []byte("v1")...) 578 p.Payload = append(p.Payload, []byte{0}...) 579 p.Payload = append(p.Payload, []byte{2, 0, 0, 0}...) 580 err := s.parseConnID(p) 581 require.NoError(t, err) 582 }) 583 }