github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/altertests/common_test.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package altertests
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"io"
    21  	"os"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  	"github.com/stretchr/testify/assert"
    27  	"github.com/stretchr/testify/require"
    28  
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/dtestutils"
    31  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    32  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle"
    33  )
    34  
    35  type ModifyTypeTest struct {
    36  	FromType     string
    37  	ToType       string
    38  	InsertValues string
    39  	SelectRes    []interface{}
    40  	ExpectedErr  bool
    41  }
    42  
    43  func RunModifyTypeTests(t *testing.T, tests []ModifyTypeTest) {
    44  	for _, test := range tests {
    45  		name := fmt.Sprintf("%s -> %s: %s", test.FromType, test.ToType, test.InsertValues)
    46  		if len(name) > 200 {
    47  			name = name[:200]
    48  		}
    49  		t.Run(name, func(t *testing.T) {
    50  			ctx := context.Background()
    51  			dEnv := dtestutils.CreateTestEnv()
    52  			root, err := dEnv.WorkingRoot(ctx)
    53  			require.NoError(t, err)
    54  			root, err = executeModify(ctx, dEnv, root, fmt.Sprintf("CREATE TABLE test(pk BIGINT PRIMARY KEY, v1 %s);", test.FromType))
    55  			require.NoError(t, err)
    56  			root, err = executeModify(ctx, dEnv, root, fmt.Sprintf("INSERT INTO test VALUES %s;", test.InsertValues))
    57  			require.NoError(t, err)
    58  			root, err = executeModify(ctx, dEnv, root, fmt.Sprintf("ALTER TABLE test MODIFY v1 %s;", test.ToType))
    59  			if test.ExpectedErr {
    60  				assert.Error(t, err)
    61  				return
    62  			}
    63  			require.NoError(t, err)
    64  			res, err := executeSelect(ctx, dEnv, root, "SELECT v1 FROM test ORDER BY pk;")
    65  			require.NoError(t, err)
    66  			assert.Equal(t, test.SelectRes, res)
    67  		})
    68  	}
    69  }
    70  
    71  func SkipByDefaultInCI(t *testing.T) {
    72  	if os.Getenv("CI") != "" && os.Getenv("DOLT_TEST_RUN_NON_RACE_TESTS") == "" {
    73  		t.Skip()
    74  	}
    75  }
    76  
    77  func widenValue(v interface{}) interface{} {
    78  	switch x := v.(type) {
    79  	case int:
    80  		return int64(x)
    81  	case int8:
    82  		return int64(x)
    83  	case int16:
    84  		return int64(x)
    85  	case int32:
    86  		return int64(x)
    87  	case uint:
    88  		return uint64(x)
    89  	case uint8:
    90  		return uint64(x)
    91  	case uint16:
    92  		return uint64(x)
    93  	case uint32:
    94  		return uint64(x)
    95  	case float32:
    96  		return float64(x)
    97  	default:
    98  		return v
    99  	}
   100  }
   101  
   102  func parseTime(timestampLayout bool, value string) time.Time {
   103  	var t time.Time
   104  	var err error
   105  	if timestampLayout {
   106  		t, err = time.Parse("2006-01-02 15:04:05.999999", value)
   107  	} else {
   108  		t, err = time.Parse("2006-01-02", value)
   109  	}
   110  	if err != nil {
   111  		panic(err)
   112  	}
   113  	return t.UTC()
   114  }
   115  
   116  func executeSelect(ctx context.Context, dEnv *env.DoltEnv, root *doltdb.RootValue, query string) ([]interface{}, error) {
   117  	var err error
   118  	db := sqle.NewDatabase("dolt", dEnv.DbData())
   119  	engine, sqlCtx, err := sqle.NewTestEngine(ctx, db, root)
   120  	if err != nil {
   121  		return nil, err
   122  	}
   123  	_, iter, err := engine.Query(sqlCtx, query)
   124  	if err != nil {
   125  		return nil, err
   126  	}
   127  	var vals []interface{}
   128  	var r sql.Row
   129  	for r, err = iter.Next(); err == nil; r, err = iter.Next() {
   130  		if len(r) == 1 {
   131  			// widen the values since we're testing values rather than types
   132  			vals = append(vals, widenValue(r[0]))
   133  		} else if len(r) > 1 {
   134  			return nil, fmt.Errorf("expected return of single value from select: %q", query)
   135  		} else { // no values
   136  			vals = append(vals, nil)
   137  		}
   138  	}
   139  	if err != io.EOF {
   140  		return nil, err
   141  	}
   142  	return vals, nil
   143  }
   144  
   145  func executeModify(ctx context.Context, dEnv *env.DoltEnv, root *doltdb.RootValue, query string) (*doltdb.RootValue, error) {
   146  	db := sqle.NewDatabase("dolt", dEnv.DbData())
   147  	engine, sqlCtx, err := sqle.NewTestEngine(ctx, db, root)
   148  	if err != nil {
   149  		return nil, err
   150  	}
   151  	_, iter, err := engine.Query(sqlCtx, query)
   152  	if err != nil {
   153  		return nil, err
   154  	}
   155  	for {
   156  		_, err := iter.Next()
   157  		if err == io.EOF {
   158  			break
   159  		}
   160  		if err != nil {
   161  			return nil, err
   162  		}
   163  	}
   164  	err = iter.Close(sqlCtx)
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  	return db.GetRoot(sqlCtx)
   169  }