github.com/hasnat/dolt/go@v0.0.0-20210628190320-9eb5d843fbb7/libraries/doltcore/sqle/sqlfmt/row_fmt.go (about)

     1  // Copyright 2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqlfmt
    16  
    17  import (
    18  	"bytes"
    19  	"fmt"
    20  	"strings"
    21  
    22  	"github.com/dolthub/vitess/go/sqltypes"
    23  
    24  	"github.com/dolthub/dolt/go/libraries/doltcore/row"
    25  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    26  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
    27  	"github.com/dolthub/dolt/go/store/types"
    28  )
    29  
    30  const singleQuote = `'`
    31  
    32  // Quotes the identifier given with backticks.
    33  func QuoteIdentifier(s string) string {
    34  	return "`" + s + "`"
    35  }
    36  
    37  // QuoteComment quotes the given string with apostrophes, and escapes any contained within the string.
    38  func QuoteComment(s string) string {
    39  	return `'` + strings.ReplaceAll(s, `'`, `\'`) + `'`
    40  }
    41  
    42  func RowAsInsertStmt(r row.Row, tableName string, tableSch schema.Schema) (string, error) {
    43  	var b strings.Builder
    44  	b.WriteString("INSERT INTO ")
    45  	b.WriteString(QuoteIdentifier(tableName))
    46  	b.WriteString(" ")
    47  
    48  	b.WriteString("(")
    49  	seenOne := false
    50  	err := tableSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
    51  		if seenOne {
    52  			b.WriteRune(',')
    53  		}
    54  		b.WriteString(QuoteIdentifier(col.Name))
    55  		seenOne = true
    56  		return false, nil
    57  	})
    58  
    59  	if err != nil {
    60  		return "", err
    61  	}
    62  
    63  	b.WriteString(")")
    64  
    65  	b.WriteString(" VALUES (")
    66  	seenOne = false
    67  	_, err = r.IterSchema(tableSch, func(tag uint64, val types.Value) (stop bool, err error) {
    68  		if seenOne {
    69  			b.WriteRune(',')
    70  		}
    71  		col, _ := tableSch.GetAllCols().GetByTag(tag)
    72  		sqlString, err := valueAsSqlString(col.TypeInfo, val)
    73  		if err != nil {
    74  			return true, err
    75  		}
    76  		b.WriteString(sqlString)
    77  		seenOne = true
    78  		return false, nil
    79  	})
    80  
    81  	if err != nil {
    82  		return "", err
    83  	}
    84  
    85  	b.WriteString(");")
    86  
    87  	return b.String(), nil
    88  }
    89  
    90  func RowAsDeleteStmt(r row.Row, tableName string, tableSch schema.Schema) (string, error) {
    91  	var b strings.Builder
    92  	b.WriteString("DELETE FROM ")
    93  	b.WriteString(QuoteIdentifier(tableName))
    94  
    95  	b.WriteString(" WHERE (")
    96  	seenOne := false
    97  	_, err := r.IterSchema(tableSch, func(tag uint64, val types.Value) (stop bool, err error) {
    98  		col, _ := tableSch.GetAllCols().GetByTag(tag)
    99  		if col.IsPartOfPK {
   100  			if seenOne {
   101  				b.WriteString(" AND ")
   102  			}
   103  			sqlString, err := valueAsSqlString(col.TypeInfo, val)
   104  			if err != nil {
   105  				return true, err
   106  			}
   107  			b.WriteString(QuoteIdentifier(col.Name))
   108  			b.WriteRune('=')
   109  			b.WriteString(sqlString)
   110  			seenOne = true
   111  		}
   112  		return false, nil
   113  	})
   114  
   115  	if err != nil {
   116  		return "", err
   117  	}
   118  
   119  	b.WriteString(");")
   120  	return b.String(), nil
   121  }
   122  
   123  func RowAsUpdateStmt(r row.Row, tableName string, tableSch schema.Schema) (string, error) {
   124  	var b strings.Builder
   125  	b.WriteString("UPDATE ")
   126  	b.WriteString(QuoteIdentifier(tableName))
   127  	b.WriteString(" ")
   128  
   129  	b.WriteString("SET ")
   130  	seenOne := false
   131  	_, err := r.IterSchema(tableSch, func(tag uint64, val types.Value) (stop bool, err error) {
   132  		col, _ := tableSch.GetAllCols().GetByTag(tag)
   133  		if !col.IsPartOfPK {
   134  			if seenOne {
   135  				b.WriteRune(',')
   136  			}
   137  			sqlString, err := valueAsSqlString(col.TypeInfo, val)
   138  			if err != nil {
   139  				return true, err
   140  			}
   141  			b.WriteString(QuoteIdentifier(col.Name))
   142  			b.WriteRune('=')
   143  			b.WriteString(sqlString)
   144  			seenOne = true
   145  		}
   146  		return false, nil
   147  	})
   148  
   149  	if err != nil {
   150  		return "", err
   151  	}
   152  
   153  	b.WriteString(" WHERE (")
   154  	seenOne = false
   155  	_, err = r.IterSchema(tableSch, func(tag uint64, val types.Value) (stop bool, err error) {
   156  		col, _ := tableSch.GetAllCols().GetByTag(tag)
   157  		if col.IsPartOfPK {
   158  			if seenOne {
   159  				b.WriteString(" AND ")
   160  			}
   161  			sqlString, err := valueAsSqlString(col.TypeInfo, val)
   162  			if err != nil {
   163  				return true, err
   164  			}
   165  			b.WriteString(QuoteIdentifier(col.Name))
   166  			b.WriteRune('=')
   167  			b.WriteString(sqlString)
   168  			seenOne = true
   169  		}
   170  		return false, nil
   171  	})
   172  
   173  	if err != nil {
   174  		return "", err
   175  	}
   176  
   177  	b.WriteString(");")
   178  	return b.String(), nil
   179  }
   180  
   181  func valueAsSqlString(ti typeinfo.TypeInfo, value types.Value) (string, error) {
   182  	if types.IsNull(value) {
   183  		return "NULL", nil
   184  	}
   185  
   186  	str, err := ti.FormatValue(value)
   187  
   188  	if err != nil {
   189  		return "", err
   190  	}
   191  
   192  	switch ti.GetTypeIdentifier() {
   193  	case typeinfo.BoolTypeIdentifier:
   194  		// todo: unclear if we want this to output with "TRUE/FALSE" or 1/0
   195  		if value.(types.Bool) {
   196  			return "TRUE", nil
   197  		}
   198  		return "FALSE", nil
   199  	case typeinfo.UuidTypeIdentifier:
   200  		return singleQuote + *str + singleQuote, nil
   201  	case typeinfo.TimeTypeIdentifier:
   202  		return singleQuote + *str + singleQuote, nil
   203  	case typeinfo.YearTypeIdentifier:
   204  		return singleQuote + *str + singleQuote, nil
   205  	case typeinfo.DatetimeTypeIdentifier:
   206  		return singleQuote + *str + singleQuote, nil
   207  	case typeinfo.VarStringTypeIdentifier:
   208  		s, ok := value.(types.String)
   209  		if !ok {
   210  			return "", fmt.Errorf("typeinfo.VarStringTypeIdentifier is not types.String")
   211  		}
   212  		return quoteAndEscapeString(string(s)), nil
   213  	default:
   214  		return *str, nil
   215  	}
   216  }
   217  
   218  // todo: this is a hack, varstring should handle this
   219  func quoteAndEscapeString(s string) string {
   220  	buf := &bytes.Buffer{}
   221  	v, err := sqltypes.NewValue(sqltypes.VarChar, []byte(s))
   222  	if err != nil {
   223  		panic(err)
   224  	}
   225  	v.EncodeSQL(buf)
   226  	return buf.String()
   227  }