github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/exchange_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"io"
    21  	"sync/atomic"
    22  	"testing"
    23  
    24  	"github.com/stretchr/testify/assert"
    25  	"github.com/stretchr/testify/require"
    26  
    27  	"github.com/dolthub/go-mysql-server/sql"
    28  	"github.com/dolthub/go-mysql-server/sql/expression"
    29  	"github.com/dolthub/go-mysql-server/sql/plan"
    30  	"github.com/dolthub/go-mysql-server/sql/types"
    31  )
    32  
    33  func TestExchange(t *testing.T) {
    34  	children := plan.NewProject(
    35  		[]sql.Expression{
    36  			expression.NewGetField(0, types.Text, "partition", false),
    37  			expression.NewArithmetic(
    38  				expression.NewGetField(1, types.Int64, "val", false),
    39  				expression.NewLiteral(int64(1), types.Int64),
    40  				"+",
    41  			),
    42  		},
    43  		plan.NewFilter(
    44  			expression.NewLessThan(
    45  				expression.NewGetField(1, types.Int64, "val", false),
    46  				expression.NewLiteral(int64(4), types.Int64),
    47  			),
    48  			&partitionable{nil, 3, 6},
    49  		),
    50  	)
    51  
    52  	expected := []sql.Row{
    53  		{"1", int64(2)},
    54  		{"1", int64(3)},
    55  		{"1", int64(4)},
    56  		{"2", int64(2)},
    57  		{"2", int64(3)},
    58  		{"2", int64(4)},
    59  		{"3", int64(2)},
    60  		{"3", int64(3)},
    61  		{"3", int64(4)},
    62  	}
    63  
    64  	for i := 1; i <= 4; i++ {
    65  		t.Run(fmt.Sprint(i), func(t *testing.T) {
    66  			require := require.New(t)
    67  
    68  			exchange := plan.NewExchange(i, children)
    69  			ctx := sql.NewEmptyContext()
    70  			iter, err := DefaultBuilder.Build(ctx, exchange, nil)
    71  			require.NoError(err)
    72  
    73  			rows, err := sql.RowIterToRows(ctx, iter)
    74  			require.NoError(err)
    75  			require.ElementsMatch(expected, rows)
    76  		})
    77  	}
    78  }
    79  
    80  func TestExchangeCancelled(t *testing.T) {
    81  	children := plan.NewProject(
    82  		[]sql.Expression{
    83  			expression.NewGetField(0, types.Text, "partition", false),
    84  			expression.NewArithmetic(
    85  				expression.NewGetField(1, types.Int64, "val", false),
    86  				expression.NewLiteral(int64(1), types.Int64),
    87  				"+",
    88  			),
    89  		},
    90  		plan.NewFilter(
    91  			expression.NewLessThan(
    92  				expression.NewGetField(1, types.Int64, "val", false),
    93  				expression.NewLiteral(int64(4), types.Int64),
    94  			),
    95  			&partitionable{nil, 3, 2048},
    96  		),
    97  	)
    98  
    99  	exchange := plan.NewExchange(3, children)
   100  	require := require.New(t)
   101  
   102  	c, cancel := context.WithCancel(context.Background())
   103  	ctx := sql.NewContext(c)
   104  	cancel()
   105  
   106  	iter, err := DefaultBuilder.Build(ctx, exchange, nil)
   107  	require.NoError(err)
   108  
   109  	_, err = iter.Next(ctx)
   110  	require.Equal(context.Canceled, err)
   111  }
   112  
   113  func TestExchangeIterPartitionsPanic(t *testing.T) {
   114  	ctx := sql.NewContext(context.Background())
   115  	piter, err := (&partitionable{nil, 3, 2048}).Partitions(ctx)
   116  	assert.NoError(t, err)
   117  	closedCh := make(chan sql.Partition)
   118  	close(closedCh)
   119  	err = iterPartitions(ctx, piter, closedCh)
   120  	assert.Error(t, err)
   121  	assert.Contains(t, err.Error(), "panic")
   122  
   123  	openCh := make(chan sql.Partition)
   124  	err = iterPartitions(ctx, &partitionPanic{}, openCh)
   125  	assert.Error(t, err)
   126  	assert.Contains(t, err.Error(), "panic")
   127  }
   128  
   129  func TestExchangeIterPartitionRowsPanic(t *testing.T) {
   130  	ctx := sql.NewContext(context.Background())
   131  	partitions := make(chan sql.Partition, 1)
   132  	partitions <- Partition("test")
   133  	err := iterPartitionRows(ctx, func(*sql.Context, sql.Partition) (sql.RowIter, error) {
   134  		return &rowIterPanic{}, nil
   135  	}, partitions, nil)
   136  	assert.Error(t, err)
   137  	assert.Contains(t, err.Error(), "panic")
   138  
   139  	closedCh := make(chan sql.Row)
   140  	close(closedCh)
   141  	partitions <- Partition("test")
   142  	err = iterPartitionRows(ctx, func(*sql.Context, sql.Partition) (sql.RowIter, error) {
   143  		return &partitionRows{Partition("test"), 10}, nil
   144  	}, partitions, closedCh)
   145  	assert.Error(t, err)
   146  	assert.Contains(t, err.Error(), "panic")
   147  }
   148  
   149  type partitionable struct {
   150  	sql.Node
   151  	partitions       int
   152  	rowsPerPartition int
   153  }
   154  
   155  var _ sql.Table = partitionable{}
   156  var _ sql.CollationCoercible = partitionable{}
   157  
   158  // WithChildren implements the Node interface.
   159  func (p *partitionable) WithChildren(children ...sql.Node) (sql.Node, error) {
   160  	if len(children) != 0 {
   161  		return nil, sql.ErrInvalidChildrenNumber.New(p, len(children), 0)
   162  	}
   163  
   164  	return p, nil
   165  }
   166  
   167  // CheckPrivileges implements the interface sql.Node.
   168  func (p *partitionable) CheckPrivileges(ctx *sql.Context, opChecker sql.PrivilegedOperationChecker) bool {
   169  	return p.Node.CheckPrivileges(ctx, opChecker)
   170  }
   171  
   172  // CollationCoercibility implements the interface sql.CollationCoercible.
   173  func (p partitionable) CollationCoercibility(ctx *sql.Context) (collation sql.CollationID, coercibility byte) {
   174  	return sql.GetCoercibility(ctx, p.Node)
   175  }
   176  
   177  func (partitionable) Children() []sql.Node { return nil }
   178  
   179  func (p partitionable) Partitions(*sql.Context) (sql.PartitionIter, error) {
   180  	return &exchangePartitionIter{int32(p.partitions)}, nil
   181  }
   182  
   183  func (p partitionable) PartitionRows(_ *sql.Context, part sql.Partition) (sql.RowIter, error) {
   184  	return &partitionRows{part, int32(p.rowsPerPartition)}, nil
   185  }
   186  
   187  func (partitionable) Schema() sql.Schema {
   188  	return sql.Schema{
   189  		{Name: "partition", Type: types.Text, Source: "foo"},
   190  		{Name: "val", Type: types.Int64, Source: "foo"},
   191  	}
   192  }
   193  
   194  func (partitionable) Collation() sql.CollationID {
   195  	return sql.Collation_Default
   196  }
   197  
   198  func (partitionable) Name() string { return "partitionable" }
   199  
   200  type Partition string
   201  
   202  func (p Partition) Key() []byte {
   203  	return []byte(p)
   204  }
   205  
   206  type exchangePartitionIter struct {
   207  	num int32
   208  }
   209  
   210  func (i *exchangePartitionIter) Next(*sql.Context) (sql.Partition, error) {
   211  	new := atomic.AddInt32(&i.num, -1)
   212  	if new < 0 {
   213  		return nil, io.EOF
   214  	}
   215  
   216  	return Partition(fmt.Sprint(new + 1)), nil
   217  }
   218  
   219  func (i *exchangePartitionIter) Close(*sql.Context) error {
   220  	atomic.StoreInt32(&i.num, -1)
   221  	return nil
   222  }
   223  
   224  type partitionRows struct {
   225  	sql.Partition
   226  	num int32
   227  }
   228  
   229  func (r *partitionRows) Next(*sql.Context) (sql.Row, error) {
   230  	new := atomic.AddInt32(&r.num, -1)
   231  	if new < 0 {
   232  		return nil, io.EOF
   233  	}
   234  
   235  	return sql.NewRow(string(r.Key()), int64(new+1)), nil
   236  }
   237  
   238  func (r *partitionRows) Close(*sql.Context) error {
   239  	atomic.StoreInt32(&r.num, -1)
   240  	return nil
   241  }
   242  
   243  type rowIterPanic struct {
   244  }
   245  
   246  func (*rowIterPanic) Next(*sql.Context) (sql.Row, error) {
   247  	panic("i panic")
   248  }
   249  
   250  func (*rowIterPanic) Close(*sql.Context) error {
   251  	return nil
   252  }
   253  
   254  type partitionPanic struct {
   255  	sql.Partition
   256  	closed bool
   257  }
   258  
   259  func (*partitionPanic) Next(*sql.Context) (sql.Partition, error) {
   260  	panic("partitionPanic.Next")
   261  }
   262  
   263  func (p *partitionPanic) Close(_ *sql.Context) error {
   264  	p.closed = true
   265  	return nil
   266  }