github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/set_test.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package rowexec
    16  
    17  import (
    18  	"context"
    19  	"testing"
    20  
    21  	"github.com/dolthub/vitess/go/sqltypes"
    22  	"github.com/stretchr/testify/assert"
    23  	"github.com/stretchr/testify/require"
    24  	"gopkg.in/src-d/go-errors.v1"
    25  
    26  	"github.com/dolthub/go-mysql-server/memory"
    27  	"github.com/dolthub/go-mysql-server/sql"
    28  	"github.com/dolthub/go-mysql-server/sql/expression"
    29  	"github.com/dolthub/go-mysql-server/sql/plan"
    30  	"github.com/dolthub/go-mysql-server/sql/types"
    31  	"github.com/dolthub/go-mysql-server/sql/variables"
    32  )
    33  
    34  func TestSet(t *testing.T) {
    35  	require := require.New(t)
    36  
    37  	ctx := sql.NewContext(context.Background(), sql.WithSession(sql.NewBaseSession()))
    38  
    39  	s := plan.NewSet(
    40  		[]sql.Expression{
    41  			expression.NewSetField(expression.NewUserVar("foo"), expression.NewLiteral("bar", types.LongText)),
    42  			expression.NewSetField(expression.NewUserVar("baz"), expression.NewLiteral(int64(1), types.Int64)),
    43  		},
    44  	)
    45  
    46  	_, err := DefaultBuilder.Build(ctx, s, nil)
    47  	require.NoError(err)
    48  
    49  	typ, v, err := ctx.GetUserVariable(ctx, "foo")
    50  	require.NoError(err)
    51  	require.Equal(types.MustCreateStringWithDefaults(sqltypes.VarChar, 3), typ)
    52  	require.Equal("bar", v)
    53  
    54  	typ, v, err = ctx.GetUserVariable(ctx, "baz")
    55  	require.NoError(err)
    56  	require.Equal(types.Int64, typ)
    57  	require.Equal(int64(1), v)
    58  }
    59  
    60  func newPersistedSqlContext() (*sql.Context, memory.GlobalsMap) {
    61  	ctx, _ := context.WithCancel(context.TODO())
    62  	pro := memory.NewDBProvider()
    63  	sess := memory.NewSession(sql.NewBaseSession(), pro)
    64  
    65  	persistedGlobals := map[string]interface{}{"max_connections": 1000}
    66  	sess.SetGlobals(persistedGlobals)
    67  
    68  	sqlCtx := sql.NewContext(ctx)
    69  	sqlCtx.Session = sess
    70  	return sqlCtx, persistedGlobals
    71  }
    72  
    73  func TestPersistedSessionSetIterator(t *testing.T) {
    74  	setTests := []struct {
    75  		title        string
    76  		name         string
    77  		value        int
    78  		scope        sql.SystemVariableScope
    79  		err          *errors.Kind
    80  		globalCmp    interface{}
    81  		persistedCmp interface{}
    82  	}{
    83  		{"persist var", "max_connections", 10, sql.SystemVariableScope_Persist, nil, int64(10), int64(10)},
    84  		{"persist only", "max_connections", 10, sql.SystemVariableScope_PersistOnly, nil, int64(151), int64(10)},
    85  		{"no persist", "auto_increment_increment", 3300, sql.SystemVariableScope_Global, nil, int64(3300), nil},
    86  		{"persist unknown variable", "nonexistent", 10, sql.SystemVariableScope_Persist, sql.ErrUnknownSystemVariable, nil, nil},
    87  		{"persist only unknown variable", "nonexistent", 10, sql.SystemVariableScope_PersistOnly, sql.ErrUnknownSystemVariable, nil, nil},
    88  	}
    89  
    90  	for _, test := range setTests {
    91  		t.Run(test.title, func(t *testing.T) {
    92  			variables.InitSystemVariables()
    93  			sqlCtx, globals := newPersistedSqlContext()
    94  			s := plan.NewSet(
    95  				[]sql.Expression{
    96  					expression.NewSetField(expression.NewSystemVar(test.name, test.scope, string(test.scope)), expression.NewLiteral(int64(test.value), types.Int64)),
    97  				},
    98  			)
    99  
   100  			_, err := DefaultBuilder.Build(sqlCtx, s, nil)
   101  			if test.err != nil {
   102  				assert.True(t, test.err.Is(err))
   103  				return
   104  			} else {
   105  				assert.NoError(t, err)
   106  			}
   107  
   108  			res := globals[test.name]
   109  			assert.Equal(t, test.persistedCmp, res)
   110  
   111  			_, val, _ := sql.SystemVariables.GetGlobal(test.name)
   112  			assert.Equal(t, test.globalCmp, val)
   113  		})
   114  	}
   115  }