github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/compile/scope_test.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package compile
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"testing"
    21  	"time"
    22  
    23  	"github.com/matrixorigin/matrixone/pkg/catalog"
    24  	"github.com/matrixorigin/matrixone/pkg/defines"
    25  
    26  	"github.com/matrixorigin/matrixone/pkg/common/buffer"
    27  	"github.com/matrixorigin/matrixone/pkg/common/morpc"
    28  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    29  	"github.com/matrixorigin/matrixone/pkg/pb/pipeline"
    30  	"github.com/matrixorigin/matrixone/pkg/pb/plan"
    31  	"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect/mysql"
    32  	plan2 "github.com/matrixorigin/matrixone/pkg/sql/plan"
    33  	"github.com/matrixorigin/matrixone/pkg/testutil"
    34  	"github.com/matrixorigin/matrixone/pkg/testutil/testengine"
    35  	"github.com/matrixorigin/matrixone/pkg/vm"
    36  	"github.com/stretchr/testify/require"
    37  )
    38  
    39  func TestScopeSerialization(t *testing.T) {
    40  	testCases := []string{
    41  		"select 1",
    42  		"select * from R",
    43  		"select count(*) from R",
    44  		"select * from R limit 2, 1",
    45  		"select * from R left join S on R.uid = S.uid",
    46  	}
    47  
    48  	var sourceScopes = generateScopeCases(t, testCases)
    49  
    50  	for i, sourceScope := range sourceScopes {
    51  		data, errEncode := encodeScope(sourceScope)
    52  		require.NoError(t, errEncode)
    53  		targetScope, errDecode := decodeScope(data, sourceScope.Proc, false, nil)
    54  		require.NoError(t, errDecode)
    55  
    56  		// Just do simple check
    57  		require.Equal(t, len(sourceScope.PreScopes), len(targetScope.PreScopes), fmt.Sprintf("related SQL is '%s'", testCases[i]))
    58  		require.Equal(t, len(sourceScope.Instructions), len(targetScope.Instructions), fmt.Sprintf("related SQL is '%s'", testCases[i]))
    59  		for j := 0; j < len(sourceScope.Instructions); j++ {
    60  			require.Equal(t, sourceScope.Instructions[j].Op, targetScope.Instructions[j].Op)
    61  		}
    62  		if sourceScope.DataSource == nil {
    63  			require.Nil(t, targetScope.DataSource)
    64  		} else {
    65  			require.Equal(t, sourceScope.DataSource.SchemaName, targetScope.DataSource.SchemaName)
    66  			require.Equal(t, sourceScope.DataSource.RelationName, targetScope.DataSource.RelationName)
    67  			require.Equal(t, sourceScope.DataSource.PushdownId, targetScope.DataSource.PushdownId)
    68  			require.Equal(t, sourceScope.DataSource.PushdownAddr, targetScope.DataSource.PushdownAddr)
    69  		}
    70  		require.Equal(t, sourceScope.NodeInfo.Addr, targetScope.NodeInfo.Addr)
    71  		require.Equal(t, sourceScope.NodeInfo.Id, targetScope.NodeInfo.Id)
    72  	}
    73  
    74  }
    75  
    76  func generateScopeCases(t *testing.T, testCases []string) []*Scope {
    77  	// getScope method generate and return the scope of a SQL string.
    78  	getScope := func(t1 *testing.T, sql string) *Scope {
    79  		proc := testutil.NewProcess()
    80  		proc.SessionInfo.Buf = buffer.New()
    81  		e, _, compilerCtx := testengine.New(defines.AttachAccountId(context.Background(), catalog.System_Account))
    82  		opt := plan2.NewBaseOptimizer(compilerCtx)
    83  		ctx := compilerCtx.GetContext()
    84  		stmts, err := mysql.Parse(ctx, sql, 1, 0)
    85  		require.NoError(t1, err)
    86  		qry, err := opt.Optimize(stmts[0], false)
    87  		require.NoError(t1, err)
    88  		proc.Ctx = ctx
    89  		c := NewCompile("test", "test", sql, "", "", context.Background(), e, proc, nil, false, nil, time.Now())
    90  		err = c.Compile(ctx, &plan.Plan{Plan: &plan.Plan_Query{Query: qry}}, func(batch *batch.Batch) error {
    91  			return nil
    92  		})
    93  		require.NoError(t1, err)
    94  		// ignore the last operator if it's output
    95  		if c.scope[0].Instructions[len(c.scope[0].Instructions)-1].Op == vm.Output {
    96  			c.scope[0].Instructions = c.scope[0].Instructions[:len(c.scope[0].Instructions)-1]
    97  		}
    98  		return c.scope[0]
    99  	}
   100  
   101  	result := make([]*Scope, len(testCases))
   102  	for i, sql := range testCases {
   103  		result[i] = getScope(t, sql)
   104  	}
   105  	return result
   106  }
   107  
   108  func TestMessageSenderOnClientReceive(t *testing.T) {
   109  	sender := new(messageSenderOnClient)
   110  	sender.receiveCh = make(chan morpc.Message, 1)
   111  
   112  	// case 1: use source context, and source context is canceled
   113  	{
   114  		sourceCtx, sourceCancel := context.WithCancel(context.Background())
   115  		sender.ctx = sourceCtx
   116  		sender.ctxCancel = sourceCancel
   117  		sourceCancel()
   118  		v, err := sender.receiveMessage()
   119  		require.NoError(t, err)
   120  		require.Equal(t, nil, v)
   121  	}
   122  
   123  	// case 2: use derived context, and source context is canceled
   124  	{
   125  		sourceCtx, sourceCancel := context.WithCancel(context.Background())
   126  		receiveCtx, receiveCancel := context.WithTimeout(sourceCtx, 3*time.Second)
   127  		sender.ctx = receiveCtx
   128  		sender.ctxCancel = receiveCancel
   129  		sourceCancel()
   130  
   131  		startTime := time.Now()
   132  		v, err := sender.receiveMessage()
   133  		require.NoError(t, err)
   134  		require.Equal(t, nil, v)
   135  		require.True(t, time.Since(startTime) < 3*time.Second)
   136  		receiveCancel()
   137  	}
   138  
   139  	// case 3: receive a nil message
   140  	{
   141  		sourceCtx, sourceCancel := context.WithCancel(context.Background())
   142  		sender.ctx = sourceCtx
   143  		sender.ctxCancel = sourceCancel
   144  		sender.receiveCh <- nil
   145  		_, err := sender.receiveMessage()
   146  		require.NotNil(t, err)
   147  		sourceCancel()
   148  	}
   149  
   150  	// case 4: receive a message
   151  	{
   152  		sourceCtx, sourceCancel := context.WithCancel(context.Background())
   153  		sender.ctx = sourceCtx
   154  		sender.ctxCancel = sourceCancel
   155  		data := &pipeline.Message{}
   156  		sender.receiveCh <- data
   157  		v, err := sender.receiveMessage()
   158  		require.NoError(t, err)
   159  		require.Equal(t, data, v)
   160  		sourceCancel()
   161  	}
   162  
   163  	// case 5: channel is closed
   164  	{
   165  		sourceCtx, sourceCancel := context.WithCancel(context.Background())
   166  		sender.ctx = sourceCtx
   167  		sender.ctxCancel = sourceCancel
   168  		close(sender.receiveCh)
   169  		_, err := sender.receiveMessage()
   170  		require.NotNil(t, err)
   171  	}
   172  }