github.com/dolthub/go-mysql-server@v0.18.0/enginetest/mysqlshim/iter.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package mysqlshim 16 17 import ( 18 dsql "database/sql" 19 "io" 20 "reflect" 21 "time" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 ) 25 26 // mysqlIter wraps an iterator returned by the MySQL connection. 27 type mysqlIter struct { 28 rows *dsql.Rows 29 types []reflect.Type 30 } 31 32 var _ sql.RowIter = mysqlIter{} 33 34 // newMySQLIter returns a new mysqlIter. 35 func newMySQLIter(rows *dsql.Rows) mysqlIter { 36 columnTypes, err := rows.ColumnTypes() 37 if err != nil { 38 panic(err) 39 } 40 types := make([]reflect.Type, len(columnTypes)) 41 for i, columnType := range columnTypes { 42 scanType := columnType.ScanType() 43 switch scanType { 44 case reflect.TypeOf(dsql.RawBytes{}): 45 scanType = reflect.TypeOf("") 46 case reflect.TypeOf(dsql.NullBool{}): 47 scanType = reflect.TypeOf(true) 48 case reflect.TypeOf(dsql.NullByte{}): 49 scanType = reflect.TypeOf(byte(0)) 50 case reflect.TypeOf(dsql.NullFloat64{}): 51 scanType = reflect.TypeOf(float64(0)) 52 case reflect.TypeOf(dsql.NullInt16{}): 53 scanType = reflect.TypeOf(int16(0)) 54 case reflect.TypeOf(dsql.NullInt32{}): 55 scanType = reflect.TypeOf(int32(0)) 56 case reflect.TypeOf(dsql.NullInt64{}): 57 scanType = reflect.TypeOf(int64(0)) 58 case reflect.TypeOf(dsql.NullString{}): 59 scanType = reflect.TypeOf("") 60 case reflect.TypeOf(dsql.NullTime{}): 61 scanType = reflect.TypeOf(time.Time{}) 62 } 63 types[i] = scanType 64 } 65 return mysqlIter{rows, types} 66 } 67 68 // Next implements the interface sql.RowIter. 69 func (m mysqlIter) Next(ctx *sql.Context) (sql.Row, error) { 70 if m.rows.Next() { 71 output := make(sql.Row, len(m.types)) 72 for i, typ := range m.types { 73 output[i] = reflect.New(typ).Interface() 74 } 75 err := m.rows.Scan(output...) 76 if err != nil { 77 return nil, err 78 } 79 for i, val := range output { 80 reflectVal := reflect.ValueOf(val) 81 if reflectVal.IsNil() { 82 output[i] = nil 83 } else { 84 output[i] = reflectVal.Elem().Interface() 85 if byteSlice, ok := val.([]byte); ok { 86 output[i] = string(byteSlice) 87 } 88 } 89 } 90 return output, nil 91 } 92 return nil, io.EOF 93 } 94 95 // Close implements the interface sql.RowIter. 96 func (m mysqlIter) Close(ctx *sql.Context) error { 97 return m.rows.Close() 98 }