github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/testutil.go (about) 1 // Copyright 2019 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package sqle 16 17 import ( 18 "context" 19 "errors" 20 "fmt" 21 "io" 22 "strings" 23 24 sqle "github.com/dolthub/go-mysql-server" 25 "github.com/dolthub/go-mysql-server/sql" 26 "github.com/dolthub/vitess/go/vt/sqlparser" 27 28 "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" 29 "github.com/dolthub/dolt/go/libraries/doltcore/env" 30 ) 31 32 // ExecuteSql executes all the SQL non-select statements given in the string against the root value given and returns 33 // the updated root, or an error. Statements in the input string are split by `;\n` 34 func ExecuteSql(dEnv *env.DoltEnv, root *doltdb.RootValue, statements string) (*doltdb.RootValue, error) { 35 db := NewDatabase("dolt", dEnv.DbData()) 36 37 engine, ctx, err := NewTestEngine(context.Background(), db, root) 38 DSessFromSess(ctx.Session).EnableBatchedMode() 39 err = ctx.Session.SetSessionVariable(ctx, sql.AutoCommitSessionVar, true) 40 if err != nil { 41 return nil, err 42 } 43 44 for _, query := range strings.Split(statements, ";\n") { 45 if len(strings.Trim(query, " ")) == 0 { 46 continue 47 } 48 49 sqlStatement, err := sqlparser.Parse(query) 50 if err != nil { 51 return nil, err 52 } 53 54 var execErr error 55 switch sqlStatement.(type) { 56 case *sqlparser.Show: 57 return nil, errors.New("Show statements aren't handled") 58 case *sqlparser.Select, *sqlparser.OtherRead: 59 return nil, errors.New("Select statements aren't handled") 60 case *sqlparser.Insert: 61 var rowIter sql.RowIter 62 _, rowIter, execErr = engine.Query(ctx, query) 63 if execErr == nil { 64 execErr = drainIter(ctx, rowIter) 65 } 66 case *sqlparser.DDL, *sqlparser.MultiAlterDDL: 67 var rowIter sql.RowIter 68 _, rowIter, execErr = engine.Query(ctx, query) 69 if execErr == nil { 70 execErr = drainIter(ctx, rowIter) 71 } 72 if err = db.Flush(ctx); err != nil { 73 return nil, err 74 } 75 default: 76 return nil, fmt.Errorf("Unsupported SQL statement: '%v'.", query) 77 } 78 79 if execErr != nil { 80 return nil, execErr 81 } 82 } 83 84 if err := db.Flush(ctx); err == nil { 85 return db.GetRoot(ctx) 86 } else { 87 return nil, err 88 } 89 } 90 91 // NewTestSQLCtx returns a new *sql.Context with a default DoltSession, a new IndexRegistry, and a new ViewRegistry 92 func NewTestSQLCtx(ctx context.Context) *sql.Context { 93 session := DefaultDoltSession() 94 sqlCtx := sql.NewContext( 95 ctx, 96 sql.WithSession(session), 97 sql.WithIndexRegistry(sql.NewIndexRegistry()), 98 sql.WithViewRegistry(sql.NewViewRegistry()), 99 ).WithCurrentDB("dolt") 100 101 return sqlCtx 102 } 103 104 // NewTestEngine creates a new default engine, and a *sql.Context and initializes indexes and schema fragments. 105 func NewTestEngine(ctx context.Context, db Database, root *doltdb.RootValue) (*sqle.Engine, *sql.Context, error) { 106 engine := sqle.NewDefault() 107 engine.AddDatabase(db) 108 109 sqlCtx := NewTestSQLCtx(ctx) 110 DSessFromSess(sqlCtx.Session).AddDB(sqlCtx, db, db.DbData()) 111 sqlCtx.SetCurrentDatabase(db.Name()) 112 err := db.SetRoot(sqlCtx, root) 113 114 if err != nil { 115 return nil, nil, err 116 } 117 118 err = RegisterSchemaFragments(sqlCtx, db, root) 119 120 if err != nil { 121 return nil, nil, err 122 } 123 124 return engine, sqlCtx, nil 125 } 126 127 // ExecuteSelect executes the select statement given and returns the resulting rows, or an error if one is encountered. 128 func ExecuteSelect(dEnv *env.DoltEnv, ddb *doltdb.DoltDB, root *doltdb.RootValue, query string) ([]sql.Row, error) { 129 130 dbData := env.DbData{ 131 Ddb: ddb, 132 Rsw: dEnv.RepoStateWriter(), 133 Rsr: dEnv.RepoStateReader(), 134 Drw: dEnv.DocsReadWriter(), 135 } 136 137 db := NewDatabase("dolt", dbData) 138 engine, ctx, err := NewTestEngine(context.Background(), db, root) 139 if err != nil { 140 return nil, err 141 } 142 143 _, rowIter, err := engine.Query(ctx, query) 144 if err != nil { 145 return nil, err 146 } 147 148 var ( 149 rows []sql.Row 150 rowErr error 151 row sql.Row 152 ) 153 for row, rowErr = rowIter.Next(); rowErr == nil; row, rowErr = rowIter.Next() { 154 rows = append(rows, row) 155 } 156 157 if rowErr != io.EOF { 158 return nil, rowErr 159 } 160 161 return rows, nil 162 } 163 164 func drainIter(ctx *sql.Context, iter sql.RowIter) error { 165 for { 166 _, err := iter.Next() 167 if err == io.EOF { 168 break 169 } else if err != nil { 170 closeErr := iter.Close(ctx) 171 if closeErr != nil { 172 panic(fmt.Errorf("%v\n%v", err, closeErr)) 173 } 174 return err 175 } 176 } 177 return iter.Close(ctx) 178 }