github.com/dolthub/go-mysql-server@v0.18.0/processlist_test.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package sqle 16 17 import ( 18 "context" 19 "sort" 20 "testing" 21 22 "github.com/stretchr/testify/require" 23 24 "github.com/dolthub/go-mysql-server/sql" 25 ) 26 27 func TestProcessList(t *testing.T) { 28 require := require.New(t) 29 30 clientHostOne := "127.0.0.1:34567" 31 clientHostTwo := "127.0.0.1:34568" 32 p := NewProcessList() 33 p.AddConnection(1, clientHostOne) 34 sess := sql.NewBaseSessionWithClientServer("0.0.0.0:3306", sql.Client{Address: clientHostOne, User: "foo"}, 1) 35 sess.SetCurrentDatabase("test_db") 36 p.ConnectionReady(sess) 37 ctx := sql.NewContext(context.Background(), sql.WithPid(1), sql.WithSession(sess)) 38 ctx, err := p.BeginQuery(ctx, "SELECT foo") 39 require.NoError(err) 40 41 require.Equal(uint64(1), ctx.Pid()) 42 require.Len(p.procs, 1) 43 44 p.AddTableProgress(ctx.Pid(), "a", 5) 45 p.AddTableProgress(ctx.Pid(), "b", 6) 46 47 expectedProcess := &sql.Process{ 48 QueryPid: 1, 49 Connection: 1, 50 Host: clientHostOne, 51 Progress: map[string]sql.TableProgress{ 52 "a": {sql.Progress{Name: "a", Done: 0, Total: 5}, map[string]sql.PartitionProgress{}}, 53 "b": {sql.Progress{Name: "b", Done: 0, Total: 6}, map[string]sql.PartitionProgress{}}, 54 }, 55 User: "foo", 56 Query: "SELECT foo", 57 Command: sql.ProcessCommandQuery, 58 StartedAt: p.procs[1].StartedAt, 59 Database: "test_db", 60 } 61 require.NotNil(p.procs[1].Kill) 62 p.procs[1].Kill = nil 63 require.Equal(expectedProcess, p.procs[1]) 64 65 p.AddPartitionProgress(ctx.Pid(), "b", "b-1", -1) 66 p.AddPartitionProgress(ctx.Pid(), "b", "b-2", -1) 67 p.AddPartitionProgress(ctx.Pid(), "b", "b-3", -1) 68 69 p.UpdatePartitionProgress(ctx.Pid(), "b", "b-2", 1) 70 71 p.RemovePartitionProgress(ctx.Pid(), "b", "b-3") 72 73 expectedProgress := map[string]sql.TableProgress{ 74 "a": {sql.Progress{Name: "a", Total: 5}, map[string]sql.PartitionProgress{}}, 75 "b": {sql.Progress{Name: "b", Total: 6}, map[string]sql.PartitionProgress{ 76 "b-1": {sql.Progress{Name: "b-1", Done: 0, Total: -1}}, 77 "b-2": {sql.Progress{Name: "b-2", Done: 1, Total: -1}}, 78 }}, 79 } 80 require.Equal(expectedProgress, p.procs[1].Progress) 81 82 p.AddConnection(2, clientHostTwo) 83 sess = sql.NewBaseSessionWithClientServer("0.0.0.0:3306", sql.Client{Address: clientHostTwo, User: "foo"}, 2) 84 p.ConnectionReady(sess) 85 ctx = sql.NewContext(context.Background(), sql.WithPid(2), sql.WithSession(sess)) 86 ctx, err = p.BeginQuery(ctx, "SELECT bar") 87 require.NoError(err) 88 89 p.AddTableProgress(ctx.Pid(), "foo", 2) 90 91 require.Equal(uint64(2), ctx.Pid()) 92 require.Len(p.procs, 2) 93 94 p.UpdateTableProgress(1, "a", 3) 95 p.UpdateTableProgress(1, "a", 1) 96 p.UpdateTableProgress(1, "b", 2) 97 p.UpdateTableProgress(2, "foo", 1) 98 99 require.Equal(int64(4), p.procs[1].Progress["a"].Done) 100 require.Equal(int64(2), p.procs[1].Progress["b"].Done) 101 require.Equal(int64(1), p.procs[2].Progress["foo"].Done) 102 103 var expected []sql.Process 104 for _, p := range p.procs { 105 np := *p 106 np.Kill = nil 107 expected = append(expected, np) 108 } 109 110 result := p.Processes() 111 for i := range result { 112 result[i].Kill = nil 113 } 114 115 sortById(expected) 116 sortById(result) 117 require.Equal(expected, result) 118 119 p.EndQuery(ctx) 120 121 require.Len(p.procs, 2) 122 proc, ok := p.procs[2] 123 require.True(ok) 124 require.Equal(sql.ProcessCommandSleep, proc.Command) 125 } 126 127 func sortById(slice []sql.Process) { 128 sort.Slice(slice, func(i, j int) bool { 129 return slice[i].Connection < slice[j].Connection 130 }) 131 } 132 133 func TestKillConnection(t *testing.T) { 134 pl := NewProcessList() 135 136 pl.AddConnection(1, "") 137 pl.AddConnection(2, "") 138 s1 := sql.NewBaseSessionWithClientServer("", sql.Client{}, 1) 139 s2 := sql.NewBaseSessionWithClientServer("", sql.Client{}, 2) 140 pl.ConnectionReady(s1) 141 pl.ConnectionReady(s2) 142 143 _, err := pl.BeginQuery( 144 sql.NewContext(context.Background(), sql.WithPid(3), sql.WithSession(s1)), 145 "foo", 146 ) 147 require.NoError(t, err) 148 149 _, err = pl.BeginQuery( 150 sql.NewContext(context.Background(), sql.WithPid(4), sql.WithSession(s2)), 151 "foo", 152 ) 153 require.NoError(t, err) 154 155 var killed = make(map[uint64]bool) 156 157 pl.procs[1].Kill = func() { 158 killed[1] = true 159 } 160 pl.procs[2].Kill = func() { 161 killed[2] = true 162 } 163 164 pl.Kill(1) 165 require.Len(t, pl.procs, 2) 166 167 require.True(t, killed[1]) 168 require.False(t, killed[2]) 169 }