github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dsess/dolt_session_test.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package dsess 16 17 import ( 18 "context" 19 "testing" 20 21 "github.com/dolthub/go-mysql-server/sql" 22 _ "github.com/dolthub/go-mysql-server/sql/variables" 23 "github.com/stretchr/testify/assert" 24 "gopkg.in/src-d/go-errors.v1" 25 26 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 27 "github.com/dolthub/dolt/go/libraries/doltcore/env" 28 "github.com/dolthub/dolt/go/libraries/utils/config" 29 "github.com/dolthub/dolt/go/libraries/utils/filesys" 30 "github.com/dolthub/dolt/go/store/types" 31 ) 32 33 func TestDoltSessionInit(t *testing.T) { 34 dsess := DefaultSession(emptyDatabaseProvider()) 35 conf := config.NewMapConfig(make(map[string]string)) 36 assert.Equal(t, conf, dsess.globalsConf) 37 } 38 39 func TestNewPersistedSystemVariables(t *testing.T) { 40 dsess := DefaultSession(emptyDatabaseProvider()) 41 conf := config.NewMapConfig(map[string]string{"max_connections": "1000"}) 42 dsess = dsess.WithGlobals(conf) 43 44 sysVars, err := dsess.SystemVariablesInConfig() 45 assert.NoError(t, err) 46 47 maxConRes := sysVars[0] 48 assert.Equal(t, "max_connections", maxConRes.GetName()) 49 assert.Equal(t, int64(1000), maxConRes.GetDefault()) 50 } 51 52 func TestValidatePeristableSystemVar(t *testing.T) { 53 tests := []struct { 54 Name string 55 Err *errors.Kind 56 }{ 57 { 58 Name: "max_connections", 59 Err: nil, 60 }, 61 { 62 Name: "init_file", 63 Err: sql.ErrSystemVariableReadOnly, 64 }, 65 { 66 Name: "unknown", 67 Err: sql.ErrUnknownSystemVariable, 68 }, 69 } 70 71 for _, tt := range tests { 72 t.Run(tt.Name, func(t *testing.T) { 73 if sysVar, _, err := validatePersistableSysVar(tt.Name); tt.Err != nil { 74 assert.True(t, tt.Err.Is(err)) 75 } else { 76 assert.Equal(t, tt.Name, sysVar.GetName()) 77 78 } 79 }) 80 } 81 } 82 83 func TestSetPersistedValue(t *testing.T) { 84 tests := []struct { 85 Name string 86 Value interface{} 87 ExpectedRes interface{} 88 Err *errors.Kind 89 }{ 90 { 91 Name: "int", 92 Value: 7, 93 }, 94 { 95 Name: "int8", 96 Value: int8(7), 97 }, 98 { 99 Name: "int16", 100 Value: int16(7), 101 }, 102 { 103 Name: "int32", 104 Value: int32(7), 105 }, 106 { 107 Name: "int64", 108 Value: int64(7), 109 }, 110 { 111 Name: "uint", 112 Value: uint(7), 113 }, 114 { 115 Name: "uint8", 116 Value: uint8(7), 117 }, 118 { 119 Name: "uint16", 120 Value: uint16(7), 121 }, 122 { 123 Name: "uint32", 124 Value: uint32(7), 125 }, 126 { 127 Name: "uint64", 128 Value: uint64(7), 129 }, 130 { 131 Name: "float32", 132 Value: float32(7), 133 ExpectedRes: "7.00000000", 134 }, 135 { 136 Name: "float64", 137 Value: float64(7), 138 ExpectedRes: "7.00000000", 139 }, 140 { 141 Name: "string", 142 Value: "7", 143 }, 144 { 145 Name: "bool", 146 Value: true, 147 ExpectedRes: "1", 148 }, 149 { 150 Name: "bool", 151 Value: false, 152 ExpectedRes: "0", 153 }, 154 { 155 Value: complex64(7), 156 Err: sql.ErrInvalidType, 157 }, 158 } 159 160 for _, tt := range tests { 161 t.Run(tt.Name, func(t *testing.T) { 162 conf := config.NewMapConfig(make(map[string]string)) 163 if err := setPersistedValue(conf, "key", tt.Value); tt.Err != nil { 164 assert.True(t, tt.Err.Is(err)) 165 } else if tt.ExpectedRes == nil { 166 assert.Equal(t, "7", conf.GetStringOrDefault("key", "")) 167 } else { 168 assert.Equal(t, tt.ExpectedRes, conf.GetStringOrDefault("key", "")) 169 170 } 171 }) 172 } 173 } 174 175 func TestGetPersistedValue(t *testing.T) { 176 tests := []struct { 177 Name string 178 Value string 179 ExpectedRes interface{} 180 Err bool 181 }{ 182 { 183 Name: "long_query_time", 184 Value: "7", 185 ExpectedRes: float64(7), 186 }, 187 { 188 Name: "tls_ciphersuites", 189 Value: "7", 190 ExpectedRes: "7", 191 }, 192 { 193 Name: "max_connections", 194 Value: "7", 195 ExpectedRes: int64(7), 196 }, 197 { 198 Name: "tmp_table_size", 199 Value: "7", 200 ExpectedRes: uint64(7), 201 }, 202 { 203 Name: "activate_all_roles_on_login", 204 Value: "true", 205 Err: true, 206 }, 207 { 208 Name: "activate_all_roles_on_login", 209 Value: "on", 210 Err: true, 211 }, 212 { 213 Name: "activate_all_roles_on_login", 214 Value: "1", 215 ExpectedRes: int8(1), 216 }, 217 { 218 Name: "activate_all_roles_on_login", 219 Value: "false", 220 Err: true, 221 }, 222 { 223 Name: "activate_all_roles_on_login", 224 Value: "off", 225 Err: true, 226 }, 227 { 228 Name: "activate_all_roles_on_login", 229 Value: "0", 230 ExpectedRes: int8(0), 231 }, 232 } 233 234 for _, tt := range tests { 235 t.Run(tt.Name, func(t *testing.T) { 236 conf := config.NewMapConfig(map[string]string{tt.Name: tt.Value}) 237 if val, err := getPersistedValue(conf, tt.Name); tt.Err { 238 assert.Error(t, err) 239 } else { 240 assert.Equal(t, tt.ExpectedRes, val) 241 } 242 }) 243 } 244 } 245 246 func emptyDatabaseProvider() DoltDatabaseProvider { 247 return emptyRevisionDatabaseProvider{} 248 } 249 250 type emptyRevisionDatabaseProvider struct { 251 sql.DatabaseProvider 252 } 253 254 func (e emptyRevisionDatabaseProvider) DbFactoryUrl() string { 255 return "" 256 } 257 258 func (e emptyRevisionDatabaseProvider) UndropDatabase(ctx *sql.Context, dbName string) error { 259 return nil 260 } 261 262 func (e emptyRevisionDatabaseProvider) ListDroppedDatabases(ctx *sql.Context) ([]string, error) { 263 return nil, nil 264 } 265 266 func (e emptyRevisionDatabaseProvider) PurgeDroppedDatabases(ctx *sql.Context) error { 267 return nil 268 } 269 270 func (e emptyRevisionDatabaseProvider) BaseDatabase(ctx *sql.Context, dbName string) (SqlDatabase, bool) { 271 return nil, false 272 } 273 274 func (e emptyRevisionDatabaseProvider) SessionDatabase(ctx *sql.Context, dbName string) (SqlDatabase, bool, error) { 275 return nil, false, sql.ErrDatabaseNotFound.New(dbName) 276 } 277 278 func (e emptyRevisionDatabaseProvider) DoltDatabases() []SqlDatabase { 279 return nil 280 } 281 282 func (e emptyRevisionDatabaseProvider) DbState(ctx *sql.Context, dbName string, defaultBranch string) (InitialDbState, error) { 283 return InitialDbState{}, sql.ErrDatabaseNotFound.New(dbName) 284 } 285 286 func (e emptyRevisionDatabaseProvider) DropDatabase(ctx *sql.Context, name string) error { 287 return nil 288 } 289 290 func (e emptyRevisionDatabaseProvider) GetRevisionForRevisionDatabase(_ *sql.Context, _ string) (string, string, error) { 291 return "", "", nil 292 } 293 294 func (e emptyRevisionDatabaseProvider) IsRevisionDatabase(_ *sql.Context, _ string) (bool, error) { 295 return false, nil 296 } 297 298 func (e emptyRevisionDatabaseProvider) GetRemoteDB(ctx context.Context, format *types.NomsBinFormat, r env.Remote, withCaching bool) (*doltdb.DoltDB, error) { 299 return nil, nil 300 } 301 302 func (e emptyRevisionDatabaseProvider) FileSystem() filesys.Filesys { 303 return nil 304 } 305 306 func (e emptyRevisionDatabaseProvider) FileSystemForDatabase(dbname string) (filesys.Filesys, error) { 307 return nil, nil 308 } 309 310 func (e emptyRevisionDatabaseProvider) CloneDatabaseFromRemote(ctx *sql.Context, dbName, branch, remoteName, remoteUrl string, depth int, remoteParams map[string]string) error { 311 return nil 312 } 313 314 func (e emptyRevisionDatabaseProvider) CreateDatabase(ctx *sql.Context, dbName string) error { 315 return nil 316 } 317 318 func (e emptyRevisionDatabaseProvider) RevisionDbState(_ *sql.Context, revDB string) (InitialDbState, error) { 319 return InitialDbState{}, sql.ErrDatabaseNotFound.New(revDB) 320 }