github.com/dolthub/go-mysql-server@v0.18.0/enginetest/sqllogictest/harness/memory_harness.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package harness
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  	"strconv"
    21  	"strings"
    22  	"sync/atomic"
    23  
    24  	"github.com/dolthub/vitess/go/vt/proto/query"
    25  	"github.com/shopspring/decimal"
    26  
    27  	sqle "github.com/dolthub/go-mysql-server"
    28  	"github.com/dolthub/go-mysql-server/enginetest"
    29  	"github.com/dolthub/go-mysql-server/memory"
    30  	"github.com/dolthub/go-mysql-server/sql"
    31  )
    32  
    33  type memoryHarness struct {
    34  	engine  *sqle.Engine
    35  	harness enginetest.VersionedDBHarness
    36  }
    37  
    38  func NewMemoryHarness(harness enginetest.VersionedDBHarness) *memoryHarness {
    39  	return &memoryHarness{
    40  		harness: harness,
    41  	}
    42  }
    43  
    44  func (h *memoryHarness) EngineStr() string {
    45  	return "mysql"
    46  }
    47  
    48  func (h *memoryHarness) Init() error {
    49  	dbs := h.harness.NewDatabases("mydb")
    50  	pro := memory.NewDBProvider(dbs...)
    51  	h.engine = sqle.NewDefault(pro)
    52  	return nil
    53  }
    54  
    55  func (h *memoryHarness) ExecuteStatement(statement string) error {
    56  	ctx := h.newContext()
    57  
    58  	_, rowIter, err := h.engine.Query(ctx, statement)
    59  	if err != nil {
    60  		return err
    61  	}
    62  
    63  	return enginetest.DrainIterator(ctx, rowIter)
    64  }
    65  
    66  var pid uint32
    67  
    68  func (h *memoryHarness) newContext() *sql.Context {
    69  	ctx := h.harness.NewContext()
    70  	ctx.SetCurrentDatabase("mydb")
    71  	ctx.ApplyOpts(sql.WithPid(uint64(atomic.AddUint32(&pid, 1))))
    72  	return ctx
    73  }
    74  
    75  func (h *memoryHarness) ExecuteQuery(statement string) (schema string, results []string, err error) {
    76  	ctx := h.newContext()
    77  
    78  	var sch sql.Schema
    79  	var rowIter sql.RowIter
    80  	defer func() {
    81  		if r := recover(); r != nil {
    82  			// Panics leave the engine in a bad state that we have to clean up
    83  			h.engine.ProcessList.Kill(pid)
    84  			panic(r)
    85  		}
    86  	}()
    87  
    88  	sch, rowIter, err = h.engine.Query(ctx, statement)
    89  	if err != nil {
    90  		return "", nil, err
    91  	}
    92  
    93  	schemaString, err := schemaToSchemaString(sch)
    94  	if err != nil {
    95  		return "", nil, err
    96  	}
    97  
    98  	results, err = rowsToResultStrings(ctx, rowIter)
    99  	if err != nil {
   100  		return "", nil, err
   101  	}
   102  
   103  	return schemaString, results, nil
   104  }
   105  
   106  // Returns the rows in the iterator given as an array of their string representations, as expected by the test files
   107  func rowsToResultStrings(ctx *sql.Context, iter sql.RowIter) ([]string, error) {
   108  	var results []string
   109  	if iter == nil {
   110  		return results, nil
   111  	}
   112  
   113  	for {
   114  		row, err := iter.Next(ctx)
   115  		if err == io.EOF {
   116  			return results, nil
   117  		} else if err != nil {
   118  			enginetest.DrainIteratorIgnoreErrors(ctx, iter)
   119  			return nil, err
   120  		} else {
   121  			for _, col := range row {
   122  				results = append(results, toSqlString(col))
   123  			}
   124  		}
   125  	}
   126  }
   127  
   128  func toSqlString(val interface{}) string {
   129  	if val == nil {
   130  		return "NULL"
   131  	}
   132  
   133  	switch v := val.(type) {
   134  	case float32, float64:
   135  		// exactly 3 decimal points for floats
   136  		return fmt.Sprintf("%.3f", v)
   137  	case decimal.Decimal:
   138  		// exactly 3 decimal points for floats
   139  		return v.StringFixed(3)
   140  	case int:
   141  		return strconv.Itoa(v)
   142  	case uint:
   143  		return strconv.Itoa(int(v))
   144  	case int8:
   145  		return strconv.Itoa(int(v))
   146  	case uint8:
   147  		return strconv.Itoa(int(v))
   148  	case int16:
   149  		return strconv.Itoa(int(v))
   150  	case uint16:
   151  		return strconv.Itoa(int(v))
   152  	case int32:
   153  		return strconv.Itoa(int(v))
   154  	case uint32:
   155  		return strconv.Itoa(int(v))
   156  	case int64:
   157  		return strconv.Itoa(int(v))
   158  	case uint64:
   159  		return strconv.Itoa(int(v))
   160  	case string:
   161  		return v
   162  	// Mysql returns 1 and 0 for boolean values, mimic that
   163  	case bool:
   164  		if v {
   165  			return "1"
   166  		} else {
   167  			return "0"
   168  		}
   169  	default:
   170  		panic(fmt.Sprintf("No conversion for value %v of type %T", val, val))
   171  	}
   172  }
   173  
   174  func schemaToSchemaString(sch sql.Schema) (string, error) {
   175  	b := strings.Builder{}
   176  	for _, col := range sch {
   177  		switch col.Type.Type() {
   178  		case query.Type_INT8, query.Type_INT16, query.Type_INT24, query.Type_INT32, query.Type_INT64,
   179  			query.Type_UINT8, query.Type_UINT16, query.Type_UINT24, query.Type_UINT32, query.Type_UINT64,
   180  			query.Type_BIT:
   181  			b.WriteString("I")
   182  		case query.Type_TEXT, query.Type_VARCHAR:
   183  			b.WriteString("T")
   184  		case query.Type_FLOAT32, query.Type_FLOAT64, query.Type_DECIMAL:
   185  			b.WriteString("R")
   186  		default:
   187  			return "", fmt.Errorf("Unhandled type: %v", col.Type)
   188  		}
   189  	}
   190  	return b.String(), nil
   191  }