github.com/matrixorigin/matrixone@v1.2.0/pkg/proxy/client_conn_test.go (about) 1 // Copyright 2021 - 2023 Matrix Origin 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package proxy 16 17 import ( 18 "context" 19 "encoding/binary" 20 "net" 21 "strings" 22 "sync" 23 "testing" 24 "time" 25 26 "github.com/fagongzi/goetty/v2" 27 "github.com/fagongzi/goetty/v2/buf" 28 "github.com/lni/goutils/leaktest" 29 "github.com/matrixorigin/matrixone/pkg/common/moerr" 30 "github.com/matrixorigin/matrixone/pkg/common/runtime" 31 "github.com/matrixorigin/matrixone/pkg/frontend" 32 "github.com/stretchr/testify/require" 33 ) 34 35 type mockNetConn struct { 36 localIP string 37 localPort int 38 remoteIP string 39 remotePort int 40 c net.Conn 41 } 42 43 func newMockNetConn( 44 localIP string, localPort int, remoteIP string, remotePort int, c net.Conn, 45 ) *mockNetConn { 46 return &mockNetConn{ 47 localIP: localIP, 48 localPort: localPort, 49 remoteIP: remoteIP, 50 remotePort: remotePort, 51 c: c, 52 } 53 } 54 55 func (c *mockNetConn) SetRemote(addr string) { 56 c.remoteIP = addr 57 } 58 59 func (c *mockNetConn) Read(b []byte) (n int, err error) { 60 return c.c.Read(b) 61 } 62 63 func (c *mockNetConn) Write(b []byte) (n int, err error) { 64 return c.c.Write(b) 65 } 66 67 func (c *mockNetConn) Close() error { 68 return nil 69 } 70 71 func (c *mockNetConn) LocalAddr() net.Addr { 72 return &net.TCPAddr{ 73 IP: []byte(c.localIP), 74 Port: c.localPort, 75 } 76 } 77 78 func (c *mockNetConn) RemoteAddr() net.Addr { 79 return &net.TCPAddr{ 80 IP: []byte(c.remoteIP), 81 Port: c.remotePort, 82 } 83 } 84 85 func (c *mockNetConn) SetDeadline(t time.Time) error { 86 return nil 87 } 88 89 func (c *mockNetConn) SetReadDeadline(t time.Time) error { 90 return nil 91 } 92 93 func (c *mockNetConn) SetWriteDeadline(t time.Time) error { 94 return nil 95 } 96 97 type mockClientConn struct { 98 conn net.Conn 99 tenant Tenant 100 clientInfo clientInfo // need to set it explicitly 101 router Router 102 tun *tunnel 103 redoStmts []internalStmt 104 } 105 106 var _ ClientConn = (*mockClientConn)(nil) 107 108 func newMockClientConn( 109 conn net.Conn, tenant Tenant, ci clientInfo, router Router, tun *tunnel, 110 ) ClientConn { 111 c := &mockClientConn{ 112 conn: conn, 113 tenant: tenant, 114 clientInfo: ci, 115 router: router, 116 tun: tun, 117 } 118 return c 119 } 120 121 func (c *mockClientConn) ConnID() uint32 { return 0 } 122 func (c *mockClientConn) GetSalt() []byte { return nil } 123 func (c *mockClientConn) GetHandshakePack() *frontend.Packet { return nil } 124 func (c *mockClientConn) RawConn() net.Conn { return c.conn } 125 func (c *mockClientConn) GetTenant() Tenant { return c.tenant } 126 func (c *mockClientConn) SendErrToClient(err error) {} 127 func (c *mockClientConn) BuildConnWithServer(_ string) (ServerConn, error) { 128 cn, err := c.router.Route(context.TODO(), c.clientInfo, nil) 129 if err != nil { 130 return nil, err 131 } 132 cn.salt = testSlat 133 sc, _, err := c.router.Connect(cn, testPacket, c.tun) 134 if err != nil { 135 return nil, err 136 } 137 // Set the use defined variables, including session variables and user variables. 138 for _, stmt := range c.redoStmts { 139 if _, err := sc.ExecStmt(stmt, nil); err != nil { 140 return nil, err 141 } 142 } 143 return sc, nil 144 } 145 146 func (c *mockClientConn) HandleEvent(ctx context.Context, e IEvent, resp chan<- []byte) error { 147 switch ev := e.(type) { 148 case *killQueryEvent: 149 cn, err := c.router.SelectByConnID(ev.connID) 150 if err != nil { 151 sendResp([]byte(err.Error()), resp) 152 return err 153 } 154 sendResp([]byte(cn.addr), resp) 155 return nil 156 case *setVarEvent: 157 c.redoStmts = append(c.redoStmts, internalStmt{cmdType: cmdQuery, s: ev.stmt}) 158 sendResp([]byte("ok"), resp) 159 return nil 160 default: 161 sendResp([]byte("type not supported"), resp) 162 return moerr.NewInternalErrorNoCtx("type not supported") 163 } 164 } 165 func (c *mockClientConn) Close() error { return nil } 166 167 func testStartClient(t *testing.T, tp *testProxyHandler, ci clientInfo, cn *CNServer) func() { 168 if cn.salt == nil || len(cn.salt) != 20 { 169 cn.salt = testSlat 170 } 171 clientProxy, client := net.Pipe() 172 go func(ctx context.Context) { 173 b := make([]byte, 10) 174 for { 175 select { 176 case <-ctx.Done(): 177 return 178 default: 179 } 180 _, _ = client.Read(b) 181 } 182 }(tp.ctx) 183 tu := newTunnel(tp.ctx, tp.logger, tp.counterSet) 184 sc, _, err := tp.ru.Connect(cn, testPacket, tu) 185 require.NoError(t, err) 186 cc := newMockClientConn(clientProxy, "t1", ci, tp.ru, tu) 187 err = tu.run(cc, sc) 188 require.NoError(t, err) 189 select { 190 case err := <-tu.errC: 191 t.Fatalf("tunnel error: %v", err) 192 default: 193 } 194 return func() { 195 _ = tu.Close() 196 } 197 } 198 199 func testStartNClients(t *testing.T, tp *testProxyHandler, ci clientInfo, cn *CNServer, n int) func() { 200 var cleanFns []func() 201 for i := 0; i < n; i++ { 202 c := testStartClient(t, tp, ci, cn) 203 cleanFns = append(cleanFns, c) 204 } 205 return func() { 206 for _, f := range cleanFns { 207 f() 208 } 209 } 210 } 211 212 func TestAccountParser(t *testing.T) { 213 cases := []struct { 214 str string 215 tenant string 216 username string 217 hasErr bool 218 }{ 219 { 220 str: "t1:u1", 221 tenant: "t1", 222 username: "u1", 223 hasErr: false, 224 }, 225 { 226 str: "t1#u1", 227 tenant: "t1", 228 username: "u1", 229 hasErr: false, 230 }, 231 { 232 str: ":u1", 233 tenant: "", 234 username: "", 235 hasErr: true, 236 }, 237 { 238 str: "a:", 239 tenant: "", 240 username: "", 241 hasErr: true, 242 }, 243 { 244 str: "u1", 245 tenant: frontend.GetDefaultTenant(), 246 username: "u1", 247 hasErr: false, 248 }, 249 { 250 str: "t1:u1?a=1", 251 tenant: "t1", 252 username: "u1", 253 hasErr: false, 254 }, 255 } 256 for _, item := range cases { 257 a := clientInfo{} 258 err := a.parse(item.str) 259 if item.hasErr { 260 require.Error(t, err) 261 } else { 262 require.NoError(t, err) 263 } 264 require.Equal(t, string(a.labelInfo.Tenant), item.tenant) 265 require.Equal(t, a.username, item.username) 266 } 267 } 268 269 func createNewClientConn(t *testing.T) (ClientConn, func()) { 270 s := goetty.NewIOSession(goetty.WithSessionConn(1, 271 newMockNetConn("127.0.0.1", 30001, 272 "127.0.0.1", 30010, nil)), 273 goetty.WithSessionCodec(WithProxyProtocolCodec(frontend.NewSqlCodec()))) 274 ctx, cancel := context.WithCancel(context.Background()) 275 clientBaseConnID = 90 276 rt := runtime.DefaultRuntime() 277 logger := rt.Logger() 278 cs := newCounterSet() 279 cc, err := newClientConn(ctx, &Config{}, logger, cs, s, nil, nil, nil, nil, nil) 280 require.NoError(t, err) 281 require.NotNil(t, cc) 282 return cc, func() { 283 cancel() 284 _ = cc.Close() 285 } 286 } 287 288 func TestNewClientConn(t *testing.T) { 289 cc, cleanup := createNewClientConn(t) 290 defer cleanup() 291 require.Equal(t, 91, int(cc.ConnID())) 292 require.Equal(t, 20, len(cc.GetSalt())) 293 require.NotNil(t, cc.RawConn()) 294 } 295 296 func makeClientHandshakeResp() []byte { 297 payload := make([]byte, 200) 298 pos := 0 299 copy(payload[pos:], []byte{141, 162, 10, 0}) // Capabilities Flags 300 pos += 4 301 copy(payload[pos:], []byte{0, 0, 0, 0}) // maximum packet size 302 pos += 4 303 payload[pos] = 45 // client charset 304 pos += 1 305 pos += 23 // filler 306 username := "tenant1:user1" 307 copy(payload[pos:], username) // login username 308 pos += len(username) 309 payload[pos] = 0 // the end of username 310 pos += 1 311 payload[pos] = 20 // length of auth response 312 pos += 1 313 pos += 20 // auth response 314 dbname := "db1" 315 copy(payload[pos:], dbname) // db name 316 pos += len(dbname) 317 payload[pos] = 0 // end of db name 318 pos += 1 319 plugin := "mysql_native_password" 320 copy(payload[pos:], plugin) 321 pos += 1 + len(plugin) 322 data := make([]byte, pos+4) 323 data[0] = uint8(pos) 324 data[1] = uint8(pos >> 8) 325 data[2] = uint8(pos >> 16) 326 data[3] = 1 327 copy(data[4:], payload) 328 return data 329 } 330 331 func TestClientConn_ConnectToBackend(t *testing.T) { 332 defer leaktest.AfterTest(t)() 333 334 runtime.SetupProcessLevelRuntime(runtime.DefaultRuntime()) 335 rt := runtime.DefaultRuntime() 336 logger := rt.Logger() 337 338 t.Run("cannot connect", func(t *testing.T) { 339 nilC := (*clientConn)(nil) 340 require.Equal(t, "", string(nilC.GetTenant())) 341 require.Nil(t, nilC.RawConn()) 342 343 cc := &clientConn{ 344 log: logger, 345 } 346 cc.testHelper.connectToBackend = func() (ServerConn, error) { 347 return nil, moerr.NewInternalErrorNoCtx("123 456") 348 } 349 350 sc, err := cc.BuildConnWithServer("aaa") 351 require.ErrorContains(t, err, "123 456") 352 require.Nil(t, sc) 353 }) 354 355 t.Run("ok connect", func(t *testing.T) { 356 local, remote := net.Pipe() 357 require.NotNil(t, local) 358 require.NotNil(t, remote) 359 360 cc, cleanup := createNewClientConn(t) 361 defer cleanup() 362 c, ok := cc.(*clientConn) 363 require.True(t, ok) 364 require.NotNil(t, c) 365 c.conn.UseConn(local) 366 require.Equal(t, "", string(cc.GetTenant())) 367 368 var wg sync.WaitGroup 369 wg.Add(1) 370 go func() { 371 defer wg.Done() 372 b := make([]byte, 100) 373 // client reads init handshake. 374 n, err := remote.Read(b) 375 require.NoError(t, err) 376 require.NotEqual(t, 0, n) 377 378 // client sends handshake resp. 379 resp := makeClientHandshakeResp() 380 n, err = remote.Write(resp) 381 require.NoError(t, err) 382 require.Equal(t, len(resp), n) 383 }() 384 385 _, err := cc.BuildConnWithServer("") 386 require.Error(t, err) // just test client, no router set 387 require.Equal(t, "tenant1", string(cc.GetTenant())) 388 require.NotNil(t, cc.GetHandshakePack()) 389 wg.Wait() 390 }) 391 } 392 393 func TestClientConn_ReadPacket(t *testing.T) { 394 defer leaktest.AfterTest(t)() 395 396 cc, cleanup := createNewClientConn(t) 397 defer cleanup() 398 c, ok := cc.(*clientConn) 399 require.True(t, ok) 400 require.NotNil(t, c) 401 402 local, remote := net.Pipe() 403 require.NotNil(t, local) 404 require.NotNil(t, remote) 405 406 var wg sync.WaitGroup 407 wg.Add(1) 408 go func() { 409 defer wg.Done() 410 addr := &ProxyAddr{ 411 SourceAddress: []byte{10, 10, 10, 10}, 412 SourcePort: 1000, 413 TargetAddress: []byte{20, 20, 20, 20}, 414 TargetPort: 2000, 415 } 416 417 b := buf.NewByteBuf(1000) 418 419 b.WriteString(ProxyProtocolV2Signature) 420 err := b.WriteByte(0) 421 require.NoError(t, err) 422 err = b.WriteByte(0) 423 require.NoError(t, err) 424 b.WriteUint16(12) 425 n, err := b.Write(addr.SourceAddress) 426 require.Equal(t, 4, n) 427 require.NoError(t, err) 428 n, err = b.Write(addr.TargetAddress) 429 require.Equal(t, 4, n) 430 require.NoError(t, err) 431 b.WriteUint16(addr.SourcePort) 432 b.WriteUint16(addr.TargetPort) 433 434 n, d := b.ReadAll() 435 require.Equal(t, 28, n) 436 err = binary.Write(remote, binary.BigEndian, d) 437 require.NoError(t, err) 438 439 // little endian 440 err = b.WriteByte(9) 441 require.NoError(t, err) 442 err = b.WriteByte(0) 443 require.NoError(t, err) 444 err = b.WriteByte(0) 445 require.NoError(t, err) 446 err = b.WriteByte(0) 447 require.NoError(t, err) 448 err = b.WriteByte(3) 449 require.NoError(t, err) 450 b.WriteString("select 1") 451 452 n, d = b.ReadAll() 453 require.Equal(t, 13, n) 454 err = binary.Write(remote, binary.LittleEndian, d) 455 require.NoError(t, err) 456 }() 457 458 c.conn.UseConn(local) 459 ret, err := c.readPacket() 460 require.NoError(t, err) 461 require.NotNil(t, ret) 462 require.Equal(t, 9, int(ret.Length)) 463 require.Equal(t, 0, int(ret.SequenceID)) 464 require.Equal(t, 3, int(ret.Payload[0])) 465 require.Equal(t, "select 1", string(ret.Payload[1:])) 466 467 wg.Wait() 468 } 469 470 func TestClientConn_ConnID(t *testing.T) { 471 parallel := 100 472 clientBaseConnID = 1 473 var wg sync.WaitGroup 474 for i := 0; i < parallel; i++ { 475 wg.Add(1) 476 go func() { 477 nextClientConnID() 478 defer wg.Done() 479 }() 480 } 481 wg.Wait() 482 require.Equal(t, 101, int(clientBaseConnID)) 483 } 484 485 func TestClientConn_SendErrToClient(t *testing.T) { 486 local, remote := net.Pipe() 487 require.NotNil(t, local) 488 require.NotNil(t, remote) 489 490 cc, cleanup := createNewClientConn(t) 491 defer cleanup() 492 c, ok := cc.(*clientConn) 493 require.True(t, ok) 494 require.NotNil(t, c) 495 c.conn.UseConn(local) 496 require.Equal(t, "", string(cc.GetTenant())) 497 498 var wg sync.WaitGroup 499 wg.Add(1) 500 go func() { 501 defer wg.Done() 502 b := make([]byte, 100) 503 // client reads init handshake. 504 n, err := remote.Read(b) 505 require.NoError(t, err) 506 require.NotEqual(t, 0, n) 507 508 // client sends handshake resp. 509 resp := makeClientHandshakeResp() 510 n, err = remote.Write(resp) 511 require.NoError(t, err) 512 require.Equal(t, len(resp), n) 513 514 n, err = remote.Read(b) 515 require.NoError(t, err) 516 require.Equal(t, 33, n) 517 require.True(t, strings.Contains(string(b[4+1+2+1+5:n]), "internal error: msg1")) 518 }() 519 520 _, err := cc.BuildConnWithServer("") 521 require.Error(t, err) // just test client, no router set 522 require.Equal(t, "tenant1", string(cc.GetTenant())) 523 require.NotNil(t, cc.GetHandshakePack()) 524 cc.SendErrToClient(moerr.NewInternalErrorNoCtx("msg1")) 525 wg.Wait() 526 }