github.com/dolthub/go-mysql-server@v0.18.0/enginetest/sqllogictest/harness/memory_harness.go (about) 1 // Copyright 2020-2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package harness 16 17 import ( 18 "fmt" 19 "io" 20 "strconv" 21 "strings" 22 "sync/atomic" 23 24 "github.com/dolthub/vitess/go/vt/proto/query" 25 "github.com/shopspring/decimal" 26 27 sqle "github.com/dolthub/go-mysql-server" 28 "github.com/dolthub/go-mysql-server/enginetest" 29 "github.com/dolthub/go-mysql-server/memory" 30 "github.com/dolthub/go-mysql-server/sql" 31 ) 32 33 type memoryHarness struct { 34 engine *sqle.Engine 35 harness enginetest.VersionedDBHarness 36 } 37 38 func NewMemoryHarness(harness enginetest.VersionedDBHarness) *memoryHarness { 39 return &memoryHarness{ 40 harness: harness, 41 } 42 } 43 44 func (h *memoryHarness) EngineStr() string { 45 return "mysql" 46 } 47 48 func (h *memoryHarness) Init() error { 49 dbs := h.harness.NewDatabases("mydb") 50 pro := memory.NewDBProvider(dbs...) 51 h.engine = sqle.NewDefault(pro) 52 return nil 53 } 54 55 func (h *memoryHarness) ExecuteStatement(statement string) error { 56 ctx := h.newContext() 57 58 _, rowIter, err := h.engine.Query(ctx, statement) 59 if err != nil { 60 return err 61 } 62 63 return enginetest.DrainIterator(ctx, rowIter) 64 } 65 66 var pid uint32 67 68 func (h *memoryHarness) newContext() *sql.Context { 69 ctx := h.harness.NewContext() 70 ctx.SetCurrentDatabase("mydb") 71 ctx.ApplyOpts(sql.WithPid(uint64(atomic.AddUint32(&pid, 1)))) 72 return ctx 73 } 74 75 func (h *memoryHarness) ExecuteQuery(statement string) (schema string, results []string, err error) { 76 ctx := h.newContext() 77 78 var sch sql.Schema 79 var rowIter sql.RowIter 80 defer func() { 81 if r := recover(); r != nil { 82 // Panics leave the engine in a bad state that we have to clean up 83 h.engine.ProcessList.Kill(pid) 84 panic(r) 85 } 86 }() 87 88 sch, rowIter, err = h.engine.Query(ctx, statement) 89 if err != nil { 90 return "", nil, err 91 } 92 93 schemaString, err := schemaToSchemaString(sch) 94 if err != nil { 95 return "", nil, err 96 } 97 98 results, err = rowsToResultStrings(ctx, rowIter) 99 if err != nil { 100 return "", nil, err 101 } 102 103 return schemaString, results, nil 104 } 105 106 // Returns the rows in the iterator given as an array of their string representations, as expected by the test files 107 func rowsToResultStrings(ctx *sql.Context, iter sql.RowIter) ([]string, error) { 108 var results []string 109 if iter == nil { 110 return results, nil 111 } 112 113 for { 114 row, err := iter.Next(ctx) 115 if err == io.EOF { 116 return results, nil 117 } else if err != nil { 118 enginetest.DrainIteratorIgnoreErrors(ctx, iter) 119 return nil, err 120 } else { 121 for _, col := range row { 122 results = append(results, toSqlString(col)) 123 } 124 } 125 } 126 } 127 128 func toSqlString(val interface{}) string { 129 if val == nil { 130 return "NULL" 131 } 132 133 switch v := val.(type) { 134 case float32, float64: 135 // exactly 3 decimal points for floats 136 return fmt.Sprintf("%.3f", v) 137 case decimal.Decimal: 138 // exactly 3 decimal points for floats 139 return v.StringFixed(3) 140 case int: 141 return strconv.Itoa(v) 142 case uint: 143 return strconv.Itoa(int(v)) 144 case int8: 145 return strconv.Itoa(int(v)) 146 case uint8: 147 return strconv.Itoa(int(v)) 148 case int16: 149 return strconv.Itoa(int(v)) 150 case uint16: 151 return strconv.Itoa(int(v)) 152 case int32: 153 return strconv.Itoa(int(v)) 154 case uint32: 155 return strconv.Itoa(int(v)) 156 case int64: 157 return strconv.Itoa(int(v)) 158 case uint64: 159 return strconv.Itoa(int(v)) 160 case string: 161 return v 162 // Mysql returns 1 and 0 for boolean values, mimic that 163 case bool: 164 if v { 165 return "1" 166 } else { 167 return "0" 168 } 169 default: 170 panic(fmt.Sprintf("No conversion for value %v of type %T", val, val)) 171 } 172 } 173 174 func schemaToSchemaString(sch sql.Schema) (string, error) { 175 b := strings.Builder{} 176 for _, col := range sch { 177 switch col.Type.Type() { 178 case query.Type_INT8, query.Type_INT16, query.Type_INT24, query.Type_INT32, query.Type_INT64, 179 query.Type_UINT8, query.Type_UINT16, query.Type_UINT24, query.Type_UINT32, query.Type_UINT64, 180 query.Type_BIT: 181 b.WriteString("I") 182 case query.Type_TEXT, query.Type_VARCHAR: 183 b.WriteString("T") 184 case query.Type_FLOAT32, query.Type_FLOAT64, query.Type_DECIMAL: 185 b.WriteString("R") 186 default: 187 return "", fmt.Errorf("Unhandled type: %v", col.Type) 188 } 189 } 190 return b.String(), nil 191 }