github.com/matrixorigin/matrixone@v1.2.0/pkg/sql/plan/function/ctl/cmd_rpc_version_test.go (about)

     1  // Copyright 2023 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package ctl
    16  
    17  import (
    18  	"fmt"
    19  	"testing"
    20  
    21  	"github.com/google/uuid"
    22  	"github.com/matrixorigin/matrixone/pkg/common/moerr"
    23  	"github.com/matrixorigin/matrixone/pkg/common/morpc"
    24  	"github.com/matrixorigin/matrixone/pkg/common/runtime"
    25  	"github.com/matrixorigin/matrixone/pkg/defines"
    26  	"github.com/matrixorigin/matrixone/pkg/queryservice"
    27  	qclient "github.com/matrixorigin/matrixone/pkg/queryservice/client"
    28  	"github.com/matrixorigin/matrixone/pkg/util/trace"
    29  	"github.com/matrixorigin/matrixone/pkg/vm/process"
    30  	"github.com/stretchr/testify/require"
    31  )
    32  
    33  func TestHandleGetProtocolVersion(t *testing.T) {
    34  	var arguments struct {
    35  		proc    *process.Process
    36  		service serviceType
    37  	}
    38  
    39  	trace.InitMOCtledSpan()
    40  
    41  	id := uuid.New().String()
    42  	addr := "127.0.0.1:7777"
    43  	initRuntime([]string{id}, []string{addr})
    44  	qs, err := queryservice.NewQueryService(id, addr, morpc.Config{})
    45  	require.NoError(t, err)
    46  	qt, err := qclient.NewQueryClient(id, morpc.Config{})
    47  	require.NoError(t, err)
    48  
    49  	arguments.proc = new(process.Process)
    50  	arguments.proc.QueryClient = qt
    51  	arguments.service = cn
    52  
    53  	err = qs.Start()
    54  	require.NoError(t, err)
    55  
    56  	defer func() {
    57  		qs.Close()
    58  	}()
    59  
    60  	ret, err := handleGetProtocolVersion(arguments.proc, arguments.service, "", nil)
    61  	require.NoError(t, err)
    62  	require.Equal(t, ret, Result{
    63  		Method: GetProtocolVersionMethod,
    64  		Data:   fmt.Sprintf("%s:%d", id, defines.MORPCLatestVersion),
    65  	})
    66  }
    67  
    68  func TestHandleSetProtocolVersion(t *testing.T) {
    69  	trace.InitMOCtledSpan()
    70  	proc := new(process.Process)
    71  	id := uuid.New().String()
    72  	addr := "127.0.0.1:7777"
    73  	initRuntime([]string{id}, []string{addr})
    74  	requireVersionValue(t, defines.MORPCLatestVersion)
    75  	qs, err := queryservice.NewQueryService(id, addr, morpc.Config{})
    76  	require.NoError(t, err)
    77  	qt, err := qclient.NewQueryClient(id, morpc.Config{})
    78  	require.NoError(t, err)
    79  	proc.QueryClient = qt
    80  
    81  	cases := []struct {
    82  		service serviceType
    83  		version int64
    84  
    85  		expectedErr error
    86  	}{
    87  		{service: cn, version: 1},
    88  		{service: cn, version: 2},
    89  		{service: tn, version: 1, expectedErr: moerr.NewInternalError(proc.Ctx, "no such tn service")},
    90  	}
    91  
    92  	err = qs.Start()
    93  	require.NoError(t, err)
    94  	defer func() {
    95  		qs.Close()
    96  	}()
    97  
    98  	for _, c := range cases {
    99  		var parameter string
   100  		if c.service == tn {
   101  			parameter = fmt.Sprintf("%d", c.version)
   102  		} else {
   103  			parameter = fmt.Sprintf("%s:%d", id, c.version)
   104  		}
   105  
   106  		ret, err := handleSetProtocolVersion(proc, c.service, parameter, nil)
   107  		if c.expectedErr != nil {
   108  			require.Equal(t, c.expectedErr, err)
   109  			continue
   110  		} else {
   111  			require.NoError(t, err)
   112  		}
   113  		require.Equal(t, ret, Result{
   114  			Method: SetProtocolVersionMethod,
   115  			Data:   fmt.Sprintf("%s:%d", id, c.version),
   116  		})
   117  		requireVersionValue(t, c.version)
   118  	}
   119  }
   120  
   121  func requireVersionValue(t *testing.T, version int64) {
   122  	v, ok := runtime.ProcessLevelRuntime().GetGlobalVariables(runtime.MOProtocolVersion)
   123  	require.True(t, ok)
   124  	require.EqualValues(t, version, v)
   125  }