github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dsess/dolt_session_test.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dsess
    16  
    17  import (
    18  	"context"
    19  	"testing"
    20  
    21  	"github.com/dolthub/go-mysql-server/sql"
    22  	_ "github.com/dolthub/go-mysql-server/sql/variables"
    23  	"github.com/stretchr/testify/assert"
    24  	"gopkg.in/src-d/go-errors.v1"
    25  
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    28  	"github.com/dolthub/dolt/go/libraries/utils/config"
    29  	"github.com/dolthub/dolt/go/libraries/utils/filesys"
    30  	"github.com/dolthub/dolt/go/store/types"
    31  )
    32  
    33  func TestDoltSessionInit(t *testing.T) {
    34  	dsess := DefaultSession(emptyDatabaseProvider())
    35  	conf := config.NewMapConfig(make(map[string]string))
    36  	assert.Equal(t, conf, dsess.globalsConf)
    37  }
    38  
    39  func TestNewPersistedSystemVariables(t *testing.T) {
    40  	dsess := DefaultSession(emptyDatabaseProvider())
    41  	conf := config.NewMapConfig(map[string]string{"max_connections": "1000"})
    42  	dsess = dsess.WithGlobals(conf)
    43  
    44  	sysVars, err := dsess.SystemVariablesInConfig()
    45  	assert.NoError(t, err)
    46  
    47  	maxConRes := sysVars[0]
    48  	assert.Equal(t, "max_connections", maxConRes.GetName())
    49  	assert.Equal(t, int64(1000), maxConRes.GetDefault())
    50  }
    51  
    52  func TestValidatePeristableSystemVar(t *testing.T) {
    53  	tests := []struct {
    54  		Name string
    55  		Err  *errors.Kind
    56  	}{
    57  		{
    58  			Name: "max_connections",
    59  			Err:  nil,
    60  		},
    61  		{
    62  			Name: "init_file",
    63  			Err:  sql.ErrSystemVariableReadOnly,
    64  		},
    65  		{
    66  			Name: "unknown",
    67  			Err:  sql.ErrUnknownSystemVariable,
    68  		},
    69  	}
    70  
    71  	for _, tt := range tests {
    72  		t.Run(tt.Name, func(t *testing.T) {
    73  			if sysVar, _, err := validatePersistableSysVar(tt.Name); tt.Err != nil {
    74  				assert.True(t, tt.Err.Is(err))
    75  			} else {
    76  				assert.Equal(t, tt.Name, sysVar.GetName())
    77  
    78  			}
    79  		})
    80  	}
    81  }
    82  
    83  func TestSetPersistedValue(t *testing.T) {
    84  	tests := []struct {
    85  		Name        string
    86  		Value       interface{}
    87  		ExpectedRes interface{}
    88  		Err         *errors.Kind
    89  	}{
    90  		{
    91  			Name:  "int",
    92  			Value: 7,
    93  		},
    94  		{
    95  			Name:  "int8",
    96  			Value: int8(7),
    97  		},
    98  		{
    99  			Name:  "int16",
   100  			Value: int16(7),
   101  		},
   102  		{
   103  			Name:  "int32",
   104  			Value: int32(7),
   105  		},
   106  		{
   107  			Name:  "int64",
   108  			Value: int64(7),
   109  		},
   110  		{
   111  			Name:  "uint",
   112  			Value: uint(7),
   113  		},
   114  		{
   115  			Name:  "uint8",
   116  			Value: uint8(7),
   117  		},
   118  		{
   119  			Name:  "uint16",
   120  			Value: uint16(7),
   121  		},
   122  		{
   123  			Name:  "uint32",
   124  			Value: uint32(7),
   125  		},
   126  		{
   127  			Name:  "uint64",
   128  			Value: uint64(7),
   129  		},
   130  		{
   131  			Name:        "float32",
   132  			Value:       float32(7),
   133  			ExpectedRes: "7.00000000",
   134  		},
   135  		{
   136  			Name:        "float64",
   137  			Value:       float64(7),
   138  			ExpectedRes: "7.00000000",
   139  		},
   140  		{
   141  			Name:  "string",
   142  			Value: "7",
   143  		},
   144  		{
   145  			Name:        "bool",
   146  			Value:       true,
   147  			ExpectedRes: "1",
   148  		},
   149  		{
   150  			Name:        "bool",
   151  			Value:       false,
   152  			ExpectedRes: "0",
   153  		},
   154  		{
   155  			Value: complex64(7),
   156  			Err:   sql.ErrInvalidType,
   157  		},
   158  	}
   159  
   160  	for _, tt := range tests {
   161  		t.Run(tt.Name, func(t *testing.T) {
   162  			conf := config.NewMapConfig(make(map[string]string))
   163  			if err := setPersistedValue(conf, "key", tt.Value); tt.Err != nil {
   164  				assert.True(t, tt.Err.Is(err))
   165  			} else if tt.ExpectedRes == nil {
   166  				assert.Equal(t, "7", conf.GetStringOrDefault("key", ""))
   167  			} else {
   168  				assert.Equal(t, tt.ExpectedRes, conf.GetStringOrDefault("key", ""))
   169  
   170  			}
   171  		})
   172  	}
   173  }
   174  
   175  func TestGetPersistedValue(t *testing.T) {
   176  	tests := []struct {
   177  		Name        string
   178  		Value       string
   179  		ExpectedRes interface{}
   180  		Err         bool
   181  	}{
   182  		{
   183  			Name:        "long_query_time",
   184  			Value:       "7",
   185  			ExpectedRes: float64(7),
   186  		},
   187  		{
   188  			Name:        "tls_ciphersuites",
   189  			Value:       "7",
   190  			ExpectedRes: "7",
   191  		},
   192  		{
   193  			Name:        "max_connections",
   194  			Value:       "7",
   195  			ExpectedRes: int64(7),
   196  		},
   197  		{
   198  			Name:        "tmp_table_size",
   199  			Value:       "7",
   200  			ExpectedRes: uint64(7),
   201  		},
   202  		{
   203  			Name:  "activate_all_roles_on_login",
   204  			Value: "true",
   205  			Err:   true,
   206  		},
   207  		{
   208  			Name:  "activate_all_roles_on_login",
   209  			Value: "on",
   210  			Err:   true,
   211  		},
   212  		{
   213  			Name:        "activate_all_roles_on_login",
   214  			Value:       "1",
   215  			ExpectedRes: int8(1),
   216  		},
   217  		{
   218  			Name:  "activate_all_roles_on_login",
   219  			Value: "false",
   220  			Err:   true,
   221  		},
   222  		{
   223  			Name:  "activate_all_roles_on_login",
   224  			Value: "off",
   225  			Err:   true,
   226  		},
   227  		{
   228  			Name:        "activate_all_roles_on_login",
   229  			Value:       "0",
   230  			ExpectedRes: int8(0),
   231  		},
   232  	}
   233  
   234  	for _, tt := range tests {
   235  		t.Run(tt.Name, func(t *testing.T) {
   236  			conf := config.NewMapConfig(map[string]string{tt.Name: tt.Value})
   237  			if val, err := getPersistedValue(conf, tt.Name); tt.Err {
   238  				assert.Error(t, err)
   239  			} else {
   240  				assert.Equal(t, tt.ExpectedRes, val)
   241  			}
   242  		})
   243  	}
   244  }
   245  
   246  func emptyDatabaseProvider() DoltDatabaseProvider {
   247  	return emptyRevisionDatabaseProvider{}
   248  }
   249  
   250  type emptyRevisionDatabaseProvider struct {
   251  	sql.DatabaseProvider
   252  }
   253  
   254  func (e emptyRevisionDatabaseProvider) DbFactoryUrl() string {
   255  	return ""
   256  }
   257  
   258  func (e emptyRevisionDatabaseProvider) UndropDatabase(ctx *sql.Context, dbName string) error {
   259  	return nil
   260  }
   261  
   262  func (e emptyRevisionDatabaseProvider) ListDroppedDatabases(ctx *sql.Context) ([]string, error) {
   263  	return nil, nil
   264  }
   265  
   266  func (e emptyRevisionDatabaseProvider) PurgeDroppedDatabases(ctx *sql.Context) error {
   267  	return nil
   268  }
   269  
   270  func (e emptyRevisionDatabaseProvider) BaseDatabase(ctx *sql.Context, dbName string) (SqlDatabase, bool) {
   271  	return nil, false
   272  }
   273  
   274  func (e emptyRevisionDatabaseProvider) SessionDatabase(ctx *sql.Context, dbName string) (SqlDatabase, bool, error) {
   275  	return nil, false, sql.ErrDatabaseNotFound.New(dbName)
   276  }
   277  
   278  func (e emptyRevisionDatabaseProvider) DoltDatabases() []SqlDatabase {
   279  	return nil
   280  }
   281  
   282  func (e emptyRevisionDatabaseProvider) DbState(ctx *sql.Context, dbName string, defaultBranch string) (InitialDbState, error) {
   283  	return InitialDbState{}, sql.ErrDatabaseNotFound.New(dbName)
   284  }
   285  
   286  func (e emptyRevisionDatabaseProvider) DropDatabase(ctx *sql.Context, name string) error {
   287  	return nil
   288  }
   289  
   290  func (e emptyRevisionDatabaseProvider) GetRevisionForRevisionDatabase(_ *sql.Context, _ string) (string, string, error) {
   291  	return "", "", nil
   292  }
   293  
   294  func (e emptyRevisionDatabaseProvider) IsRevisionDatabase(_ *sql.Context, _ string) (bool, error) {
   295  	return false, nil
   296  }
   297  
   298  func (e emptyRevisionDatabaseProvider) GetRemoteDB(ctx context.Context, format *types.NomsBinFormat, r env.Remote, withCaching bool) (*doltdb.DoltDB, error) {
   299  	return nil, nil
   300  }
   301  
   302  func (e emptyRevisionDatabaseProvider) FileSystem() filesys.Filesys {
   303  	return nil
   304  }
   305  
   306  func (e emptyRevisionDatabaseProvider) FileSystemForDatabase(dbname string) (filesys.Filesys, error) {
   307  	return nil, nil
   308  }
   309  
   310  func (e emptyRevisionDatabaseProvider) CloneDatabaseFromRemote(ctx *sql.Context, dbName, branch, remoteName, remoteUrl string, depth int, remoteParams map[string]string) error {
   311  	return nil
   312  }
   313  
   314  func (e emptyRevisionDatabaseProvider) CreateDatabase(ctx *sql.Context, dbName string) error {
   315  	return nil
   316  }
   317  
   318  func (e emptyRevisionDatabaseProvider) RevisionDbState(_ *sql.Context, revDB string) (InitialDbState, error) {
   319  	return InitialDbState{}, sql.ErrDatabaseNotFound.New(revDB)
   320  }