github.com/dolthub/go-mysql-server@v0.18.0/sql/rowexec/set_test.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package rowexec 16 17 import ( 18 "context" 19 "testing" 20 21 "github.com/dolthub/vitess/go/sqltypes" 22 "github.com/stretchr/testify/assert" 23 "github.com/stretchr/testify/require" 24 "gopkg.in/src-d/go-errors.v1" 25 26 "github.com/dolthub/go-mysql-server/memory" 27 "github.com/dolthub/go-mysql-server/sql" 28 "github.com/dolthub/go-mysql-server/sql/expression" 29 "github.com/dolthub/go-mysql-server/sql/plan" 30 "github.com/dolthub/go-mysql-server/sql/types" 31 "github.com/dolthub/go-mysql-server/sql/variables" 32 ) 33 34 func TestSet(t *testing.T) { 35 require := require.New(t) 36 37 ctx := sql.NewContext(context.Background(), sql.WithSession(sql.NewBaseSession())) 38 39 s := plan.NewSet( 40 []sql.Expression{ 41 expression.NewSetField(expression.NewUserVar("foo"), expression.NewLiteral("bar", types.LongText)), 42 expression.NewSetField(expression.NewUserVar("baz"), expression.NewLiteral(int64(1), types.Int64)), 43 }, 44 ) 45 46 _, err := DefaultBuilder.Build(ctx, s, nil) 47 require.NoError(err) 48 49 typ, v, err := ctx.GetUserVariable(ctx, "foo") 50 require.NoError(err) 51 require.Equal(types.MustCreateStringWithDefaults(sqltypes.VarChar, 3), typ) 52 require.Equal("bar", v) 53 54 typ, v, err = ctx.GetUserVariable(ctx, "baz") 55 require.NoError(err) 56 require.Equal(types.Int64, typ) 57 require.Equal(int64(1), v) 58 } 59 60 func newPersistedSqlContext() (*sql.Context, memory.GlobalsMap) { 61 ctx, _ := context.WithCancel(context.TODO()) 62 pro := memory.NewDBProvider() 63 sess := memory.NewSession(sql.NewBaseSession(), pro) 64 65 persistedGlobals := map[string]interface{}{"max_connections": 1000} 66 sess.SetGlobals(persistedGlobals) 67 68 sqlCtx := sql.NewContext(ctx) 69 sqlCtx.Session = sess 70 return sqlCtx, persistedGlobals 71 } 72 73 func TestPersistedSessionSetIterator(t *testing.T) { 74 setTests := []struct { 75 title string 76 name string 77 value int 78 scope sql.SystemVariableScope 79 err *errors.Kind 80 globalCmp interface{} 81 persistedCmp interface{} 82 }{ 83 {"persist var", "max_connections", 10, sql.SystemVariableScope_Persist, nil, int64(10), int64(10)}, 84 {"persist only", "max_connections", 10, sql.SystemVariableScope_PersistOnly, nil, int64(151), int64(10)}, 85 {"no persist", "auto_increment_increment", 3300, sql.SystemVariableScope_Global, nil, int64(3300), nil}, 86 {"persist unknown variable", "nonexistent", 10, sql.SystemVariableScope_Persist, sql.ErrUnknownSystemVariable, nil, nil}, 87 {"persist only unknown variable", "nonexistent", 10, sql.SystemVariableScope_PersistOnly, sql.ErrUnknownSystemVariable, nil, nil}, 88 } 89 90 for _, test := range setTests { 91 t.Run(test.title, func(t *testing.T) { 92 variables.InitSystemVariables() 93 sqlCtx, globals := newPersistedSqlContext() 94 s := plan.NewSet( 95 []sql.Expression{ 96 expression.NewSetField(expression.NewSystemVar(test.name, test.scope, string(test.scope)), expression.NewLiteral(int64(test.value), types.Int64)), 97 }, 98 ) 99 100 _, err := DefaultBuilder.Build(sqlCtx, s, nil) 101 if test.err != nil { 102 assert.True(t, test.err.Is(err)) 103 return 104 } else { 105 assert.NoError(t, err) 106 } 107 108 res := globals[test.name] 109 assert.Equal(t, test.persistedCmp, res) 110 111 _, val, _ := sql.SystemVariables.GetGlobal(test.name) 112 assert.Equal(t, test.globalCmp, val) 113 }) 114 } 115 }