github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/sqlutil/sql_row.go (about)

     1  // Copyright 2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqlutil
    16  
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"strconv"
    22  	"time"
    23  
    24  	"github.com/dolthub/go-mysql-server/sql"
    25  
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/row"
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    28  	"github.com/dolthub/dolt/go/store/types"
    29  )
    30  
    31  // DoltRowToSqlRow constructs a go-mysql-server sql.Row from a Dolt row.Row.
    32  func DoltRowToSqlRow(doltRow row.Row, sch schema.Schema) (sql.Row, error) {
    33  	colVals := make(sql.Row, sch.GetAllCols().Size())
    34  	i := 0
    35  
    36  	_, err := doltRow.IterSchema(sch, func(tag uint64, val types.Value) (stop bool, err error) {
    37  		col, _ := sch.GetAllCols().GetByTag(tag)
    38  		colVals[i], err = col.TypeInfo.ConvertNomsValueToValue(val)
    39  		i++
    40  
    41  		stop = err != nil
    42  		return
    43  	})
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	return sql.NewRow(colVals...), nil
    49  }
    50  
    51  // SqlRowToDoltRow constructs a Dolt row.Row from a go-mysql-server sql.Row.
    52  func SqlRowToDoltRow(ctx context.Context, vrw types.ValueReadWriter, r sql.Row, doltSchema schema.Schema) (row.Row, error) {
    53  	if schema.IsKeyless(doltSchema) {
    54  		return keylessDoltRowFromSqlRow(ctx, vrw, r, doltSchema)
    55  	}
    56  	return pkDoltRowFromSqlRow(ctx, vrw, r, doltSchema)
    57  }
    58  
    59  // DoltKeyValueAndMappingFromSqlRow converts a sql.Row to key and value tuples and keeps a mapping from tag to value that
    60  // can be used to speed up index key generation for foreign key checks.
    61  func DoltKeyValueAndMappingFromSqlRow(ctx context.Context, vrw types.ValueReadWriter, r sql.Row, doltSchema schema.Schema) (types.Tuple, types.Tuple, map[uint64]types.Value, error) {
    62  	allCols := doltSchema.GetAllCols()
    63  	nonPKCols := doltSchema.GetNonPKCols()
    64  
    65  	numCols := allCols.Size()
    66  	vals := make([]types.Value, numCols*2)
    67  	tagToVal := make(map[uint64]types.Value, numCols)
    68  
    69  	numNonPKVals := nonPKCols.Size() * 2
    70  	nonPKVals := vals[:numNonPKVals]
    71  	pkVals := vals[numNonPKVals:]
    72  
    73  	// values for the pk tuple are in schema order
    74  	pkIdx := 0
    75  	for i := 0; i < numCols; i++ {
    76  		schCol := allCols.GetAtIndex(i)
    77  		val := r[i]
    78  		if val == nil {
    79  			if !schCol.IsNullable() {
    80  				return types.Tuple{}, types.Tuple{}, nil, fmt.Errorf("column <%v> received nil but is non-nullable", schCol.Name)
    81  			}
    82  			continue
    83  		}
    84  
    85  		tag := schCol.Tag
    86  		nomsVal, err := schCol.TypeInfo.ConvertValueToNomsValue(ctx, vrw, val)
    87  
    88  		if err != nil {
    89  			return types.Tuple{}, types.Tuple{}, nil, err
    90  		}
    91  
    92  		tagToVal[tag] = nomsVal
    93  
    94  		if schCol.IsPartOfPK {
    95  			pkVals[pkIdx] = types.Uint(tag)
    96  			pkVals[pkIdx+1] = nomsVal
    97  			pkIdx += 2
    98  		}
    99  	}
   100  
   101  	// no nulls in keys
   102  	if pkIdx != len(pkVals) {
   103  		return types.Tuple{}, types.Tuple{}, nil, errors.New("not all pk columns have a value")
   104  	}
   105  
   106  	// non pk values in tag sorted order
   107  	nonPKIdx := 0
   108  	nonPKTags := len(tagToVal) - (pkIdx / 2)
   109  	for i := 0; i < len(nonPKCols.SortedTags) && nonPKIdx < (nonPKTags*2); i++ {
   110  		tag := nonPKCols.SortedTags[i]
   111  		val, ok := tagToVal[tag]
   112  
   113  		if ok {
   114  			nonPKVals[nonPKIdx] = types.Uint(tag)
   115  			nonPKVals[nonPKIdx+1] = val
   116  			nonPKIdx += 2
   117  		}
   118  	}
   119  
   120  	nonPKVals = nonPKVals[:nonPKIdx]
   121  
   122  	nbf := vrw.Format()
   123  	keyTuple, err := types.NewTuple(nbf, pkVals...)
   124  
   125  	if err != nil {
   126  		return types.Tuple{}, types.Tuple{}, nil, err
   127  	}
   128  
   129  	valTuple, err := types.NewTuple(nbf, nonPKVals...)
   130  
   131  	if err != nil {
   132  		return types.Tuple{}, types.Tuple{}, nil, err
   133  	}
   134  
   135  	return keyTuple, valTuple, tagToVal, nil
   136  }
   137  
   138  func pkDoltRowFromSqlRow(ctx context.Context, vrw types.ValueReadWriter, r sql.Row, doltSchema schema.Schema) (row.Row, error) {
   139  	taggedVals := make(row.TaggedValues)
   140  	allCols := doltSchema.GetAllCols()
   141  	for i, val := range r {
   142  		tag := allCols.Tags[i]
   143  		schCol := allCols.TagToCol[tag]
   144  		if val != nil {
   145  			var err error
   146  			taggedVals[tag], err = schCol.TypeInfo.ConvertValueToNomsValue(ctx, vrw, val)
   147  			if err != nil {
   148  				return nil, err
   149  			}
   150  		} else if !schCol.IsNullable() {
   151  			// TODO: this isn't an error in the case of result set construction (where non-null columns can indeed be null)
   152  			return nil, fmt.Errorf("column <%v> received nil but is non-nullable", schCol.Name)
   153  		}
   154  	}
   155  	return row.New(vrw.Format(), doltSchema, taggedVals)
   156  }
   157  
   158  func keylessDoltRowFromSqlRow(ctx context.Context, vrw types.ValueReadWriter, sqlRow sql.Row, sch schema.Schema) (row.Row, error) {
   159  	j := 0
   160  	vals := make([]types.Value, sch.GetAllCols().Size()*2)
   161  
   162  	for idx, val := range sqlRow {
   163  		if val != nil {
   164  			col := sch.GetAllCols().GetByIndex(idx)
   165  			nv, err := col.TypeInfo.ConvertValueToNomsValue(ctx, vrw, val)
   166  			if err != nil {
   167  				return nil, err
   168  			}
   169  
   170  			vals[j] = types.Uint(col.Tag)
   171  			vals[j+1] = nv
   172  			j += 2
   173  		}
   174  	}
   175  
   176  	return row.KeylessRow(vrw.Format(), vals[:j]...)
   177  }
   178  
   179  // SqlColToStr is a utility function for converting a sql column of type interface{} to a string
   180  func SqlColToStr(ctx context.Context, col interface{}) string {
   181  	if col != nil {
   182  		switch typedCol := col.(type) {
   183  		case int:
   184  			return strconv.FormatInt(int64(typedCol), 10)
   185  		case int32:
   186  			return strconv.FormatInt(int64(typedCol), 10)
   187  		case int64:
   188  			return strconv.FormatInt(int64(typedCol), 10)
   189  		case int16:
   190  			return strconv.FormatInt(int64(typedCol), 10)
   191  		case int8:
   192  			return strconv.FormatInt(int64(typedCol), 10)
   193  		case uint:
   194  			return strconv.FormatUint(uint64(typedCol), 10)
   195  		case uint32:
   196  			return strconv.FormatUint(uint64(typedCol), 10)
   197  		case uint64:
   198  			return strconv.FormatUint(uint64(typedCol), 10)
   199  		case uint16:
   200  			return strconv.FormatUint(uint64(typedCol), 10)
   201  		case uint8:
   202  			return strconv.FormatUint(uint64(typedCol), 10)
   203  		case float64:
   204  			return strconv.FormatFloat(float64(typedCol), 'g', -1, 64)
   205  		case float32:
   206  			return strconv.FormatFloat(float64(typedCol), 'g', -1, 32)
   207  		case string:
   208  			return typedCol
   209  		case bool:
   210  			if typedCol {
   211  				return "true"
   212  			} else {
   213  				return "false"
   214  			}
   215  		case time.Time:
   216  			return typedCol.Format("2006-01-02 15:04:05.999999 -0700 MST")
   217  		case sql.JSONValue:
   218  			s, err := typedCol.ToString(sql.NewContext(ctx))
   219  			if err != nil {
   220  				s = err.Error()
   221  			}
   222  			return s
   223  		default:
   224  			return fmt.Sprintf("%v", typedCol)
   225  		}
   226  	}
   227  
   228  	return ""
   229  }