github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/testutil.go (about)

     1  // Copyright 2019 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqle
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  	"strings"
    23  
    24  	sqle "github.com/dolthub/go-mysql-server"
    25  	"github.com/dolthub/go-mysql-server/sql"
    26  	"github.com/dolthub/vitess/go/vt/sqlparser"
    27  
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/env"
    30  )
    31  
    32  // ExecuteSql executes all the SQL non-select statements given in the string against the root value given and returns
    33  // the updated root, or an error. Statements in the input string are split by `;\n`
    34  func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*doltdb.RootValue, error) {
    35  	db := NewDatabase("dolt", dEnv.DbData())
    36  
    37  	engine, ctx, err := NewTestEngine(context.Background(), db, root)
    38  	DSessFromSess(ctx.Session).EnableBatchedMode()
    39  	err = ctx.Session.SetSessionVariable(ctx, sql.AutoCommitSessionVar, true)
    40  	if err != nil {
    41  		return nil, err
    42  	}
    43  
    44  	for _, query := range strings.Split(statements, ";\n") {
    45  		if len(strings.Trim(query, " ")) == 0 {
    46  			continue
    47  		}
    48  
    49  		sqlStatement, err := sqlparser.Parse(query)
    50  		if err != nil {
    51  			return nil, err
    52  		}
    53  
    54  		var execErr error
    55  		switch sqlStatement.(type) {
    56  		case *sqlparser.Show:
    57  			return nil, errors.New("Show statements aren't handled")
    58  		case *sqlparser.Select, *sqlparser.OtherRead:
    59  			return nil, errors.New("Select statements aren't handled")
    60  		case *sqlparser.Insert:
    61  			var rowIter sql.RowIter
    62  			_, rowIter, execErr = engine.Query(ctx, query)
    63  			if execErr == nil {
    64  				execErr = drainIter(ctx, rowIter)
    65  			}
    66  		case *sqlparser.DDL, *sqlparser.MultiAlterDDL:
    67  			var rowIter sql.RowIter
    68  			_, rowIter, execErr = engine.Query(ctx, query)
    69  			if execErr == nil {
    70  				execErr = drainIter(ctx, rowIter)
    71  			}
    72  			if err = db.Flush(ctx); err != nil {
    73  				return nil, err
    74  			}
    75  		default:
    76  			return nil, fmt.Errorf("Unsupported SQL statement: '%v'.", query)
    77  		}
    78  
    79  		if execErr != nil {
    80  			return nil, execErr
    81  		}
    82  	}
    83  
    84  	if err := db.Flush(ctx); err == nil {
    85  		return db.GetRoot(ctx)
    86  	} else {
    87  		return nil, err
    88  	}
    89  }
    90  
    91  // NewTestSQLCtx returns a new *sql.Context with a default DoltSession, a new IndexRegistry, and a new ViewRegistry
    92  func NewTestSQLCtx(ctx context.Context) *sql.Context {
    93  	session := DefaultDoltSession()
    94  	sqlCtx := sql.NewContext(
    95  		ctx,
    96  		sql.WithSession(session),
    97  		sql.WithIndexRegistry(sql.NewIndexRegistry()),
    98  		sql.WithViewRegistry(sql.NewViewRegistry()),
    99  	).WithCurrentDB("dolt")
   100  
   101  	return sqlCtx
   102  }
   103  
   104  // NewTestEngine creates a new default engine, and a *sql.Context and initializes indexes and schema fragments.
   105  func NewTestEngine(ctx context.Context, db Database, root *doltdb.RootValue) (*sqle.Engine, *sql.Context, error) {
   106  	engine := sqle.NewDefault()
   107  	engine.AddDatabase(db)
   108  
   109  	sqlCtx := NewTestSQLCtx(ctx)
   110  	DSessFromSess(sqlCtx.Session).AddDB(sqlCtx, db, db.DbData())
   111  	sqlCtx.SetCurrentDatabase(db.Name())
   112  	err := db.SetRoot(sqlCtx, root)
   113  
   114  	if err != nil {
   115  		return nil, nil, err
   116  	}
   117  
   118  	err = RegisterSchemaFragments(sqlCtx, db, root)
   119  
   120  	if err != nil {
   121  		return nil, nil, err
   122  	}
   123  
   124  	return engine, sqlCtx, nil
   125  }
   126  
   127  // ExecuteSelect executes the select statement given and returns the resulting rows, or an error if one is encountered.
   128  func ExecuteSelect(dEnv *env.DoltEnv, ddb *doltdb.DoltDB, root *doltdb.RootValue, query string) ([]sql.Row, error) {
   129  
   130  	dbData := env.DbData{
   131  		Ddb: ddb,
   132  		Rsw: dEnv.RepoStateWriter(),
   133  		Rsr: dEnv.RepoStateReader(),
   134  		Drw: dEnv.DocsReadWriter(),
   135  	}
   136  
   137  	db := NewDatabase("dolt", dbData)
   138  	engine, ctx, err := NewTestEngine(context.Background(), db, root)
   139  	if err != nil {
   140  		return nil, err
   141  	}
   142  
   143  	_, rowIter, err := engine.Query(ctx, query)
   144  	if err != nil {
   145  		return nil, err
   146  	}
   147  
   148  	var (
   149  		rows   []sql.Row
   150  		rowErr error
   151  		row    sql.Row
   152  	)
   153  	for row, rowErr = rowIter.Next(); rowErr == nil; row, rowErr = rowIter.Next() {
   154  		rows = append(rows, row)
   155  	}
   156  
   157  	if rowErr != io.EOF {
   158  		return nil, rowErr
   159  	}
   160  
   161  	return rows, nil
   162  }
   163  
   164  func drainIter(ctx *sql.Context, iter sql.RowIter) error {
   165  	for {
   166  		_, err := iter.Next()
   167  		if err == io.EOF {
   168  			break
   169  		} else if err != nil {
   170  			closeErr := iter.Close(ctx)
   171  			if closeErr != nil {
   172  				panic(fmt.Errorf("%v\n%v", err, closeErr))
   173  			}
   174  			return err
   175  		}
   176  	}
   177  	return iter.Close(ctx)
   178  }