github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/exchange_test.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rowexec 16 17 import ( 18 "context" 19 "fmt" 20 "io" 21 "sync/atomic" 22 "testing" 23 24 "github.com/stretchr/testify/assert" 25 "github.com/stretchr/testify/require" 26 27 "github.com/dolthub/go-mysql-server/sql" 28 "github.com/dolthub/go-mysql-server/sql/expression" 29 "github.com/dolthub/go-mysql-server/sql/plan" 30 "github.com/dolthub/go-mysql-server/sql/types" 31 ) 32 33 func TestExchange(t *testing.T) { 34 children := plan.NewProject( 35 []sql.Expression{ 36 expression.NewGetField(0, types.Text, "partition", false), 37 expression.NewArithmetic( 38 expression.NewGetField(1, types.Int64, "val", false), 39 expression.NewLiteral(int64(1), types.Int64), 40 "+", 41 ), 42 }, 43 plan.NewFilter( 44 expression.NewLessThan( 45 expression.NewGetField(1, types.Int64, "val", false), 46 expression.NewLiteral(int64(4), types.Int64), 47 ), 48 &partitionable{nil, 3, 6}, 49 ), 50 ) 51 52 expected := []sql.Row{ 53 {"1", int64(2)}, 54 {"1", int64(3)}, 55 {"1", int64(4)}, 56 {"2", int64(2)}, 57 {"2", int64(3)}, 58 {"2", int64(4)}, 59 {"3", int64(2)}, 60 {"3", int64(3)}, 61 {"3", int64(4)}, 62 } 63 64 for i := 1; i <= 4; i++ { 65 t.Run(fmt.Sprint(i), func(t *testing.T) { 66 require := require.New(t) 67 68 exchange := plan.NewExchange(i, children) 69 ctx := sql.NewEmptyContext() 70 iter, err := DefaultBuilder.Build(ctx, exchange, nil) 71 require.NoError(err) 72 73 rows, err := sql.RowIterToRows(ctx, iter) 74 require.NoError(err) 75 require.ElementsMatch(expected, rows) 76 }) 77 } 78 } 79 80 func TestExchangeCancelled(t *testing.T) { 81 children := plan.NewProject( 82 []sql.Expression{ 83 expression.NewGetField(0, types.Text, "partition", false), 84 expression.NewArithmetic( 85 expression.NewGetField(1, types.Int64, "val", false), 86 expression.NewLiteral(int64(1), types.Int64), 87 "+", 88 ), 89 }, 90 plan.NewFilter( 91 expression.NewLessThan( 92 expression.NewGetField(1, types.Int64, "val", false), 93 expression.NewLiteral(int64(4), types.Int64), 94 ), 95 &partitionable{nil, 3, 2048}, 96 ), 97 ) 98 99 exchange := plan.NewExchange(3, children) 100 require := require.New(t) 101 102 c, cancel := context.WithCancel(context.Background()) 103 ctx := sql.NewContext(c) 104 cancel() 105 106 iter, err := DefaultBuilder.Build(ctx, exchange, nil) 107 require.NoError(err) 108 109 _, err = iter.Next(ctx) 110 require.Equal(context.Canceled, err) 111 } 112 113 func TestExchangeIterPartitionsPanic(t *testing.T) { 114 ctx := sql.NewContext(context.Background()) 115 piter, err := (&partitionable{nil, 3, 2048}).Partitions(ctx) 116 assert.NoError(t, err) 117 closedCh := make(chan sql.Partition) 118 close(closedCh) 119 err = iterPartitions(ctx, piter, closedCh) 120 assert.Error(t, err) 121 assert.Contains(t, err.Error(), "panic") 122 123 openCh := make(chan sql.Partition) 124 err = iterPartitions(ctx, &partitionPanic{}, openCh) 125 assert.Error(t, err) 126 assert.Contains(t, err.Error(), "panic") 127 } 128 129 func TestExchangeIterPartitionRowsPanic(t *testing.T) { 130 ctx := sql.NewContext(context.Background()) 131 partitions := make(chan sql.Partition, 1) 132 partitions <- Partition("test") 133 err := iterPartitionRows(ctx, func(*sql.Context, sql.Partition) (sql.RowIter, error) { 134 return &rowIterPanic{}, nil 135 }, partitions, nil) 136 assert.Error(t, err) 137 assert.Contains(t, err.Error(), "panic") 138 139 closedCh := make(chan sql.Row) 140 close(closedCh) 141 partitions <- Partition("test") 142 err = iterPartitionRows(ctx, func(*sql.Context, sql.Partition) (sql.RowIter, error) { 143 return &partitionRows{Partition("test"), 10}, nil 144 }, partitions, closedCh) 145 assert.Error(t, err) 146 assert.Contains(t, err.Error(), "panic") 147 } 148 149 type partitionable struct { 150 sql.Node 151 partitions int 152 rowsPerPartition int 153 } 154 155 var _ sql.Table = partitionable{} 156 var _ sql.CollationCoercible = partitionable{} 157 158 // WithChildren implements the Node interface. 159 func (p *partitionable) WithChildren(children ...sql.Node) (sql.Node, error) { 160 if len(children) != 0 { 161 return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0) 162 } 163 164 return p, nil 165 } 166 167 // CheckPrivileges implements the interface sql.Node. 168 func (p *partitionable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool { 169 return p.Node.CheckPrivileges(ctx, opChecker) 170 } 171 172 // CollationCoercibility implements the interface sql.CollationCoercible. 173 func (p partitionable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) { 174 return sql.GetCoercibility(ctx, p.Node) 175 } 176 177 func (partitionable) Children() []sql.Node { return nil } 178 179 func (p partitionable) Partitions(*sql.Context) (sql.PartitionIter, error) { 180 return &exchangePartitionIter{int32(p.partitions)}, nil 181 } 182 183 func (p partitionable) PartitionRows(_ *sql.Context, part sql.Partition) (sql.RowIter, error) { 184 return &partitionRows{part, int32(p.rowsPerPartition)}, nil 185 } 186 187 func (partitionable) Schema() sql.Schema { 188 return sql.Schema{ 189 {Name: "partition", Type: types.Text, Source: "foo"}, 190 {Name: "val", Type: types.Int64, Source: "foo"}, 191 } 192 } 193 194 func (partitionable) Collation() sql.CollationID { 195 return sql.Collation_Default 196 } 197 198 func (partitionable) Name() string { return "partitionable" } 199 200 type Partition string 201 202 func (p Partition) Key() []byte { 203 return []byte(p) 204 } 205 206 type exchangePartitionIter struct { 207 num int32 208 } 209 210 func (i *exchangePartitionIter) Next(*sql.Context) (sql.Partition, error) { 211 new := atomic.AddInt32(&i.num, -1) 212 if new < 0 { 213 return nil, io.EOF 214 } 215 216 return Partition(fmt.Sprint(new + 1)), nil 217 } 218 219 func (i *exchangePartitionIter) Close(*sql.Context) error { 220 atomic.StoreInt32(&i.num, -1) 221 return nil 222 } 223 224 type partitionRows struct { 225 sql.Partition 226 num int32 227 } 228 229 func (r *partitionRows) Next(*sql.Context) (sql.Row, error) { 230 new := atomic.AddInt32(&r.num, -1) 231 if new < 0 { 232 return nil, io.EOF 233 } 234 235 return sql.NewRow(string(r.Key()), int64(new+1)), nil 236 } 237 238 func (r *partitionRows) Close(*sql.Context) error { 239 atomic.StoreInt32(&r.num, -1) 240 return nil 241 } 242 243 type rowIterPanic struct { 244 } 245 246 func (*rowIterPanic) Next(*sql.Context) (sql.Row, error) { 247 panic("i panic") 248 } 249 250 func (*rowIterPanic) Close(*sql.Context) error { 251 return nil 252 } 253 254 type partitionPanic struct { 255 sql.Partition 256 closed bool 257 } 258 259 func (*partitionPanic) Next(*sql.Context) (sql.Partition, error) { 260 panic("partitionPanic.Next") 261 } 262 263 func (p *partitionPanic) Close(_ *sql.Context) error { 264 p.closed = true 265 return nil 266 }