github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/compile/scope_test.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package compile
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"testing"
    21  
    22  	"github.com/matrixorigin/matrixone/pkg/container/batch"
    23  	"github.com/matrixorigin/matrixone/pkg/pb/plan"
    24  	"github.com/matrixorigin/matrixone/pkg/sql/parsers/dialect/mysql"
    25  	plan2 "github.com/matrixorigin/matrixone/pkg/sql/plan"
    26  	"github.com/matrixorigin/matrixone/pkg/testutil"
    27  	"github.com/matrixorigin/matrixone/pkg/testutil/testengine"
    28  	"github.com/matrixorigin/matrixone/pkg/vm"
    29  	"github.com/stretchr/testify/require"
    30  )
    31  
    32  func TestScopeSerialization(t *testing.T) {
    33  	testCases := []string{
    34  		"select 1",
    35  		"select * from R",
    36  		"select count(*) from R",
    37  		"select * from R limit 2, 1",
    38  		"select * from R left join S on R.uid = S.uid",
    39  	}
    40  
    41  	var sourceScopes = generateScopeCases(t, testCases)
    42  
    43  	for i, sourceScope := range sourceScopes {
    44  		data, errEncode := encodeScope(sourceScope)
    45  		require.NoError(t, errEncode)
    46  		targetScope, errDecode := decodeScope(data, sourceScope.Proc, false)
    47  		require.NoError(t, errDecode)
    48  
    49  		// Just do simple check
    50  		require.Equal(t, len(sourceScope.PreScopes), len(targetScope.PreScopes), fmt.Sprintf("related SQL is '%s'", testCases[i]))
    51  		require.Equal(t, len(sourceScope.Instructions), len(targetScope.Instructions), fmt.Sprintf("related SQL is '%s'", testCases[i]))
    52  		for j := 0; j < len(sourceScope.Instructions); j++ {
    53  			require.Equal(t, sourceScope.Instructions[j].Op, targetScope.Instructions[j].Op)
    54  		}
    55  		if sourceScope.DataSource == nil {
    56  			require.Nil(t, targetScope.DataSource)
    57  		} else {
    58  			require.Equal(t, sourceScope.DataSource.SchemaName, targetScope.DataSource.SchemaName)
    59  			require.Equal(t, sourceScope.DataSource.RelationName, targetScope.DataSource.RelationName)
    60  			require.Equal(t, sourceScope.DataSource.PushdownId, targetScope.DataSource.PushdownId)
    61  			require.Equal(t, sourceScope.DataSource.PushdownAddr, targetScope.DataSource.PushdownAddr)
    62  		}
    63  		require.Equal(t, sourceScope.NodeInfo.Addr, targetScope.NodeInfo.Addr)
    64  		require.Equal(t, sourceScope.NodeInfo.Id, targetScope.NodeInfo.Id)
    65  	}
    66  
    67  }
    68  
    69  func generateScopeCases(t *testing.T, testCases []string) []*Scope {
    70  	// getScope method generate and return the scope of a SQL string.
    71  	getScope := func(t1 *testing.T, sql string) *Scope {
    72  		proc := testutil.NewProcess()
    73  		e, _, compilerCtx := testengine.New(context.Background())
    74  		opt := plan2.NewBaseOptimizer(compilerCtx)
    75  		ctx := compilerCtx.GetContext()
    76  		stmts, err := mysql.Parse(ctx, sql)
    77  		require.NoError(t1, err)
    78  		qry, err := opt.Optimize(stmts[0])
    79  		require.NoError(t1, err)
    80  		c := New("", "test", sql, "", context.Background(), e, proc, nil)
    81  		err = c.Compile(ctx, &plan.Plan{Plan: &plan.Plan_Query{Query: qry}}, nil, func(a any, batch *batch.Batch) error {
    82  			return nil
    83  		})
    84  		require.NoError(t1, err)
    85  		// ignore the last operator if it's output
    86  		if c.scope.Instructions[len(c.scope.Instructions)-1].Op == vm.Output {
    87  			c.scope.Instructions = c.scope.Instructions[:len(c.scope.Instructions)-1]
    88  		}
    89  		return c.scope
    90  	}
    91  
    92  	result := make([]*Scope, len(testCases))
    93  	for i, sql := range testCases {
    94  		result[i] = getScope(t, sql)
    95  	}
    96  	return result
    97  }