github.com/dolthub/go-mysql-server@v0.18.0/sql/rows.go (about)

     1  // Copyright 2020-2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sql
    16  
    17  import (
    18  	"fmt"
    19  	"io"
    20  	"strings"
    21  
    22  	"github.com/dolthub/vitess/go/vt/proto/query"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql/values"
    25  )
    26  
    27  // Row is a tuple of values.
    28  type Row []interface{}
    29  
    30  // NewRow creates a row from the given values.
    31  func NewRow(values ...interface{}) Row {
    32  	row := make([]interface{}, len(values))
    33  	copy(row, values)
    34  	return row
    35  }
    36  
    37  // Copy creates a new row with the same values as the current one.
    38  func (r Row) Copy() Row {
    39  	return NewRow(r...)
    40  }
    41  
    42  // Append appends all the values in r2 to this row and returns the result
    43  func (r Row) Append(r2 Row) Row {
    44  	row := make(Row, len(r)+len(r2))
    45  	copy(row, r)
    46  	for i := range r2 {
    47  		row[i+len(r)] = r2[i]
    48  	}
    49  	return row
    50  }
    51  
    52  // Equals checks whether two rows are equal given a schema.
    53  func (r Row) Equals(row Row, schema Schema) (bool, error) {
    54  	if len(row) != len(r) || len(row) != len(schema) {
    55  		return false, nil
    56  	}
    57  
    58  	for i, colLeft := range r {
    59  		colRight := row[i]
    60  		cmp, err := schema[i].Type.Compare(colLeft, colRight)
    61  		if err != nil {
    62  			return false, err
    63  		}
    64  		if cmp != 0 {
    65  			return false, nil
    66  		}
    67  	}
    68  
    69  	return true, nil
    70  }
    71  
    72  // FormatRow returns a formatted string representing this row's values
    73  func FormatRow(row Row) string {
    74  	var sb strings.Builder
    75  	sb.WriteRune('[')
    76  	for i, v := range row {
    77  		if i > 0 {
    78  			sb.WriteRune(',')
    79  		}
    80  		sb.WriteString(fmt.Sprintf("%v", v))
    81  	}
    82  	sb.WriteRune(']')
    83  	return sb.String()
    84  }
    85  
    86  // RowIter is an iterator that produces rows.
    87  // TODO: most row iters need to be Disposable for CachedResult safety
    88  type RowIter interface {
    89  	// Next retrieves the next row. It will return io.EOF if it's the last row.
    90  	// After retrieving the last row, Close will be automatically closed.
    91  	Next(ctx *Context) (Row, error)
    92  	Closer
    93  }
    94  
    95  // RowIterToRows converts a row iterator to a slice of rows.
    96  func RowIterToRows(ctx *Context, i RowIter) ([]Row, error) {
    97  	var rows []Row
    98  	for {
    99  		row, err := i.Next(ctx)
   100  		if err == io.EOF {
   101  			break
   102  		}
   103  
   104  		if err != nil {
   105  			i.Close(ctx)
   106  			return nil, err
   107  		}
   108  
   109  		rows = append(rows, row)
   110  	}
   111  
   112  	return rows, i.Close(ctx)
   113  }
   114  
   115  func rowFromRow2(sch Schema, r Row2) Row {
   116  	row := make(Row, len(sch))
   117  	for i, col := range sch {
   118  		switch col.Type.Type() {
   119  		case query.Type_INT8:
   120  			row[i] = values.ReadInt8(r.GetField(i).Val)
   121  		case query.Type_UINT8:
   122  			row[i] = values.ReadUint8(r.GetField(i).Val)
   123  		case query.Type_INT16:
   124  			row[i] = values.ReadInt16(r.GetField(i).Val)
   125  		case query.Type_UINT16:
   126  			row[i] = values.ReadUint16(r.GetField(i).Val)
   127  		case query.Type_INT32:
   128  			row[i] = values.ReadInt32(r.GetField(i).Val)
   129  		case query.Type_UINT32:
   130  			row[i] = values.ReadUint32(r.GetField(i).Val)
   131  		case query.Type_INT64:
   132  			row[i] = values.ReadInt64(r.GetField(i).Val)
   133  		case query.Type_UINT64:
   134  			row[i] = values.ReadUint64(r.GetField(i).Val)
   135  		case query.Type_FLOAT32:
   136  			row[i] = values.ReadFloat32(r.GetField(i).Val)
   137  		case query.Type_FLOAT64:
   138  			row[i] = values.ReadFloat64(r.GetField(i).Val)
   139  		case query.Type_TEXT, query.Type_VARCHAR, query.Type_CHAR:
   140  			row[i] = values.ReadString(r.GetField(i).Val, values.ByteOrderCollation)
   141  		case query.Type_BLOB, query.Type_VARBINARY, query.Type_BINARY:
   142  			row[i] = values.ReadBytes(r.GetField(i).Val, values.ByteOrderCollation)
   143  		case query.Type_BIT:
   144  			fallthrough
   145  		case query.Type_ENUM:
   146  			fallthrough
   147  		case query.Type_SET:
   148  			fallthrough
   149  		case query.Type_TUPLE:
   150  			fallthrough
   151  		case query.Type_GEOMETRY:
   152  			fallthrough
   153  		case query.Type_JSON:
   154  			fallthrough
   155  		case query.Type_EXPRESSION:
   156  			fallthrough
   157  		case query.Type_INT24:
   158  			fallthrough
   159  		case query.Type_UINT24:
   160  			fallthrough
   161  		case query.Type_TIMESTAMP:
   162  			fallthrough
   163  		case query.Type_DATE:
   164  			fallthrough
   165  		case query.Type_TIME:
   166  			fallthrough
   167  		case query.Type_DATETIME:
   168  			fallthrough
   169  		case query.Type_YEAR:
   170  			fallthrough
   171  		case query.Type_DECIMAL:
   172  			panic(fmt.Sprintf("Unimplemented type conversion: %T", col.Type))
   173  		default:
   174  			panic(fmt.Sprintf("unknown type %T", col.Type))
   175  		}
   176  	}
   177  	return row
   178  }
   179  
   180  // RowsToRowIter creates a RowIter that iterates over the given rows.
   181  func RowsToRowIter(rows ...Row) RowIter {
   182  	return &sliceRowIter{rows: rows}
   183  }
   184  
   185  type sliceRowIter struct {
   186  	rows []Row
   187  	idx  int
   188  }
   189  
   190  func (i *sliceRowIter) Next(*Context) (Row, error) {
   191  	if i.idx >= len(i.rows) {
   192  		return nil, io.EOF
   193  	}
   194  
   195  	r := i.rows[i.idx]
   196  	i.idx++
   197  	return r.Copy(), nil
   198  }
   199  
   200  func (i *sliceRowIter) Close(*Context) error {
   201  	i.rows = nil
   202  	return nil
   203  }