github.com/dolthub/go-mysql-server@v0.18.0/processlist_test.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqle
    16  
    17  import (
    18  	"context"
    19  	"sort"
    20  	"testing"
    21  
    22  	"github.com/stretchr/testify/require"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  )
    26  
    27  func TestProcessList(t *testing.T) {
    28  	require := require.New(t)
    29  
    30  	clientHostOne := "127.0.0.1:34567"
    31  	clientHostTwo := "127.0.0.1:34568"
    32  	p := NewProcessList()
    33  	p.AddConnection(1, clientHostOne)
    34  	sess := sql.NewBaseSessionWithClientServer("0.0.0.0:3306", sql.Client{Address: clientHostOne, User: "foo"}, 1)
    35  	sess.SetCurrentDatabase("test_db")
    36  	p.ConnectionReady(sess)
    37  	ctx := sql.NewContext(context.Background(), sql.WithPid(1), sql.WithSession(sess))
    38  	ctx, err := p.BeginQuery(ctx, "SELECT foo")
    39  	require.NoError(err)
    40  
    41  	require.Equal(uint64(1), ctx.Pid())
    42  	require.Len(p.procs, 1)
    43  
    44  	p.AddTableProgress(ctx.Pid(), "a", 5)
    45  	p.AddTableProgress(ctx.Pid(), "b", 6)
    46  
    47  	expectedProcess := &sql.Process{
    48  		QueryPid:   1,
    49  		Connection: 1,
    50  		Host:       clientHostOne,
    51  		Progress: map[string]sql.TableProgress{
    52  			"a": {sql.Progress{Name: "a", Done: 0, Total: 5}, map[string]sql.PartitionProgress{}},
    53  			"b": {sql.Progress{Name: "b", Done: 0, Total: 6}, map[string]sql.PartitionProgress{}},
    54  		},
    55  		User:      "foo",
    56  		Query:     "SELECT foo",
    57  		Command:   sql.ProcessCommandQuery,
    58  		StartedAt: p.procs[1].StartedAt,
    59  		Database:  "test_db",
    60  	}
    61  	require.NotNil(p.procs[1].Kill)
    62  	p.procs[1].Kill = nil
    63  	require.Equal(expectedProcess, p.procs[1])
    64  
    65  	p.AddPartitionProgress(ctx.Pid(), "b", "b-1", -1)
    66  	p.AddPartitionProgress(ctx.Pid(), "b", "b-2", -1)
    67  	p.AddPartitionProgress(ctx.Pid(), "b", "b-3", -1)
    68  
    69  	p.UpdatePartitionProgress(ctx.Pid(), "b", "b-2", 1)
    70  
    71  	p.RemovePartitionProgress(ctx.Pid(), "b", "b-3")
    72  
    73  	expectedProgress := map[string]sql.TableProgress{
    74  		"a": {sql.Progress{Name: "a", Total: 5}, map[string]sql.PartitionProgress{}},
    75  		"b": {sql.Progress{Name: "b", Total: 6}, map[string]sql.PartitionProgress{
    76  			"b-1": {sql.Progress{Name: "b-1", Done: 0, Total: -1}},
    77  			"b-2": {sql.Progress{Name: "b-2", Done: 1, Total: -1}},
    78  		}},
    79  	}
    80  	require.Equal(expectedProgress, p.procs[1].Progress)
    81  
    82  	p.AddConnection(2, clientHostTwo)
    83  	sess = sql.NewBaseSessionWithClientServer("0.0.0.0:3306", sql.Client{Address: clientHostTwo, User: "foo"}, 2)
    84  	p.ConnectionReady(sess)
    85  	ctx = sql.NewContext(context.Background(), sql.WithPid(2), sql.WithSession(sess))
    86  	ctx, err = p.BeginQuery(ctx, "SELECT bar")
    87  	require.NoError(err)
    88  
    89  	p.AddTableProgress(ctx.Pid(), "foo", 2)
    90  
    91  	require.Equal(uint64(2), ctx.Pid())
    92  	require.Len(p.procs, 2)
    93  
    94  	p.UpdateTableProgress(1, "a", 3)
    95  	p.UpdateTableProgress(1, "a", 1)
    96  	p.UpdateTableProgress(1, "b", 2)
    97  	p.UpdateTableProgress(2, "foo", 1)
    98  
    99  	require.Equal(int64(4), p.procs[1].Progress["a"].Done)
   100  	require.Equal(int64(2), p.procs[1].Progress["b"].Done)
   101  	require.Equal(int64(1), p.procs[2].Progress["foo"].Done)
   102  
   103  	var expected []sql.Process
   104  	for _, p := range p.procs {
   105  		np := *p
   106  		np.Kill = nil
   107  		expected = append(expected, np)
   108  	}
   109  
   110  	result := p.Processes()
   111  	for i := range result {
   112  		result[i].Kill = nil
   113  	}
   114  
   115  	sortById(expected)
   116  	sortById(result)
   117  	require.Equal(expected, result)
   118  
   119  	p.EndQuery(ctx)
   120  
   121  	require.Len(p.procs, 2)
   122  	proc, ok := p.procs[2]
   123  	require.True(ok)
   124  	require.Equal(sql.ProcessCommandSleep, proc.Command)
   125  }
   126  
   127  func sortById(slice []sql.Process) {
   128  	sort.Slice(slice, func(i, j int) bool {
   129  		return slice[i].Connection < slice[j].Connection
   130  	})
   131  }
   132  
   133  func TestKillConnection(t *testing.T) {
   134  	pl := NewProcessList()
   135  
   136  	pl.AddConnection(1, "")
   137  	pl.AddConnection(2, "")
   138  	s1 := sql.NewBaseSessionWithClientServer("", sql.Client{}, 1)
   139  	s2 := sql.NewBaseSessionWithClientServer("", sql.Client{}, 2)
   140  	pl.ConnectionReady(s1)
   141  	pl.ConnectionReady(s2)
   142  
   143  	_, err := pl.BeginQuery(
   144  		sql.NewContext(context.Background(), sql.WithPid(3), sql.WithSession(s1)),
   145  		"foo",
   146  	)
   147  	require.NoError(t, err)
   148  
   149  	_, err = pl.BeginQuery(
   150  		sql.NewContext(context.Background(), sql.WithPid(4), sql.WithSession(s2)),
   151  		"foo",
   152  	)
   153  	require.NoError(t, err)
   154  
   155  	var killed = make(map[uint64]bool)
   156  
   157  	pl.procs[1].Kill = func() {
   158  		killed[1] = true
   159  	}
   160  	pl.procs[2].Kill = func() {
   161  		killed[2] = true
   162  	}
   163  
   164  	pl.Kill(1)
   165  	require.Len(t, pl.procs, 2)
   166  
   167  	require.True(t, killed[1])
   168  	require.False(t, killed[2])
   169  }