github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/sqlfmt/row_fmt.go (about)

     1  // Copyright 2020 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package sqlfmt
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/hex"
    20  	"fmt"
    21  	"strings"
    22  
    23  	"github.com/dolthub/go-mysql-server/sql"
    24  	"github.com/dolthub/vitess/go/sqltypes"
    25  	"github.com/dolthub/vitess/go/vt/sqlparser"
    26  
    27  	"github.com/dolthub/dolt/go/libraries/doltcore/row"
    28  	"github.com/dolthub/dolt/go/libraries/doltcore/schema"
    29  	"github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo"
    30  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil"
    31  	"github.com/dolthub/dolt/go/libraries/utils/set"
    32  	"github.com/dolthub/dolt/go/store/types"
    33  )
    34  
    35  const singleQuote = `'`
    36  
    37  // Quotes the identifier given with backticks.
    38  func QuoteIdentifier(s string) string {
    39  	return "`" + s + "`"
    40  }
    41  
    42  // QuoteComment quotes the given string with apostrophes, and escapes any contained within the string.
    43  func QuoteComment(s string) string {
    44  	return `'` + strings.ReplaceAll(s, `'`, `\'`) + `'`
    45  }
    46  
    47  func RowAsInsertStmt(r row.Row, tableName string, tableSch schema.Schema) (string, error) {
    48  	var b strings.Builder
    49  	b.WriteString("INSERT INTO ")
    50  	b.WriteString(QuoteIdentifier(tableName))
    51  	b.WriteString(" ")
    52  
    53  	b.WriteString("(")
    54  	seenOne := false
    55  	err := tableSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
    56  		if seenOne {
    57  			b.WriteRune(',')
    58  		}
    59  		b.WriteString(QuoteIdentifier(col.Name))
    60  		seenOne = true
    61  		return false, nil
    62  	})
    63  
    64  	if err != nil {
    65  		return "", err
    66  	}
    67  
    68  	b.WriteString(")")
    69  
    70  	b.WriteString(" VALUES (")
    71  	seenOne = false
    72  	_, err = r.IterSchema(tableSch, func(tag uint64, val types.Value) (stop bool, err error) {
    73  		if seenOne {
    74  			b.WriteRune(',')
    75  		}
    76  		col, _ := tableSch.GetAllCols().GetByTag(tag)
    77  		sqlString, err := ValueAsSqlString(col.TypeInfo, val)
    78  		if err != nil {
    79  			return true, err
    80  		}
    81  		b.WriteString(sqlString)
    82  		seenOne = true
    83  		return false, nil
    84  	})
    85  
    86  	if err != nil {
    87  		return "", err
    88  	}
    89  
    90  	b.WriteString(");")
    91  
    92  	return b.String(), nil
    93  }
    94  
    95  func RowAsDeleteStmt(r row.Row, tableName string, tableSch schema.Schema) (string, error) {
    96  	var b strings.Builder
    97  	b.WriteString("DELETE FROM ")
    98  	b.WriteString(QuoteIdentifier(tableName))
    99  
   100  	b.WriteString(" WHERE (")
   101  	seenOne := false
   102  	isKeyless := tableSch.GetPKCols().Size() == 0
   103  	_, err := r.IterSchema(tableSch, func(tag uint64, val types.Value) (stop bool, err error) {
   104  		col, _ := tableSch.GetAllCols().GetByTag(tag)
   105  		if col.IsPartOfPK || isKeyless {
   106  			if seenOne {
   107  				b.WriteString(" AND ")
   108  			}
   109  			sqlString, err := ValueAsSqlString(col.TypeInfo, val)
   110  			if err != nil {
   111  				return true, err
   112  			}
   113  			b.WriteString(QuoteIdentifier(col.Name))
   114  			b.WriteRune('=')
   115  			b.WriteString(sqlString)
   116  			seenOne = true
   117  		}
   118  		return false, nil
   119  	})
   120  
   121  	if err != nil {
   122  		return "", err
   123  	}
   124  
   125  	b.WriteString(");")
   126  	return b.String(), nil
   127  }
   128  
   129  func RowAsUpdateStmt(r row.Row, tableName string, tableSch schema.Schema, colsToUpdate *set.StrSet) (string, error) {
   130  	var b strings.Builder
   131  	b.WriteString("UPDATE ")
   132  	b.WriteString(QuoteIdentifier(tableName))
   133  	b.WriteString(" ")
   134  
   135  	b.WriteString("SET ")
   136  	seenOne := false
   137  	_, err := r.IterSchema(tableSch, func(tag uint64, val types.Value) (stop bool, err error) {
   138  		col, _ := tableSch.GetAllCols().GetByTag(tag)
   139  		exists := colsToUpdate.Contains(col.Name)
   140  		if !col.IsPartOfPK && exists {
   141  			if seenOne {
   142  				b.WriteRune(',')
   143  			}
   144  			sqlString, err := ValueAsSqlString(col.TypeInfo, val)
   145  			if err != nil {
   146  				return true, err
   147  			}
   148  			b.WriteString(QuoteIdentifier(col.Name))
   149  			b.WriteRune('=')
   150  			b.WriteString(sqlString)
   151  			seenOne = true
   152  		}
   153  		return false, nil
   154  	})
   155  
   156  	if err != nil {
   157  		return "", err
   158  	}
   159  
   160  	b.WriteString(" WHERE (")
   161  	seenOne = false
   162  	_, err = r.IterSchema(tableSch, func(tag uint64, val types.Value) (stop bool, err error) {
   163  		col, _ := tableSch.GetAllCols().GetByTag(tag)
   164  		if col.IsPartOfPK {
   165  			if seenOne {
   166  				b.WriteString(" AND ")
   167  			}
   168  			sqlString, err := ValueAsSqlString(col.TypeInfo, val)
   169  			if err != nil {
   170  				return true, err
   171  			}
   172  			b.WriteString(QuoteIdentifier(col.Name))
   173  			b.WriteRune('=')
   174  			b.WriteString(sqlString)
   175  			seenOne = true
   176  		}
   177  		return false, nil
   178  	})
   179  
   180  	if err != nil {
   181  		return "", err
   182  	}
   183  
   184  	b.WriteString(");")
   185  	return b.String(), nil
   186  }
   187  
   188  // RowAsTupleString converts a row into it's tuple string representation for SQL insert statements.
   189  func RowAsTupleString(r row.Row, tableSch schema.Schema) (string, error) {
   190  	var b strings.Builder
   191  
   192  	b.WriteString("(")
   193  	seenOne := false
   194  	_, err := r.IterSchema(tableSch, func(tag uint64, val types.Value) (stop bool, err error) {
   195  		if seenOne {
   196  			b.WriteRune(',')
   197  		}
   198  		col, _ := tableSch.GetAllCols().GetByTag(tag)
   199  		sqlString, err := ValueAsSqlString(col.TypeInfo, val)
   200  		if err != nil {
   201  			return true, err
   202  		}
   203  
   204  		b.WriteString(sqlString)
   205  		seenOne = true
   206  		return false, err
   207  	})
   208  
   209  	if err != nil {
   210  		return "", err
   211  	}
   212  
   213  	b.WriteString(")")
   214  
   215  	return b.String(), nil
   216  }
   217  
   218  // InsertStatementPrefix returns the first part of an SQL insert query for a given table
   219  func InsertStatementPrefix(tableName string, tableSch schema.Schema) (string, error) {
   220  	var b strings.Builder
   221  
   222  	b.WriteString("INSERT INTO ")
   223  	b.WriteString(QuoteIdentifier(tableName))
   224  	b.WriteString(" (")
   225  
   226  	seenOne := false
   227  	err := tableSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) {
   228  		if seenOne {
   229  			b.WriteRune(',')
   230  		}
   231  		b.WriteString(QuoteIdentifier(col.Name))
   232  		seenOne = true
   233  		return false, nil
   234  	})
   235  
   236  	if err != nil {
   237  		return "", err
   238  	}
   239  
   240  	b.WriteString(") VALUES ")
   241  	return b.String(), nil
   242  }
   243  
   244  // SqlRowAsCreateProcStmt Converts a Row into either a CREATE PROCEDURE statement
   245  // This function expects a row from the dolt_procedures table.
   246  func SqlRowAsCreateProcStmt(r sql.Row) (string, error) {
   247  	var b strings.Builder
   248  
   249  	// Write create procedure
   250  	prefix := "CREATE PROCEDURE "
   251  	b.WriteString(prefix)
   252  
   253  	// Write procedure name
   254  	nameStr := r[0].(string)
   255  	b.WriteString(QuoteIdentifier(nameStr))
   256  	b.WriteString(" ") // add a space
   257  
   258  	// Write definition
   259  	defStmt, err := sqlparser.Parse(r[1].(string))
   260  	if err != nil {
   261  		return "", err
   262  	}
   263  	defStr := sqlparser.String(defStmt)
   264  	defStr = defStr[len(prefix)+len(nameStr)+1:]
   265  	b.WriteString(defStr)
   266  
   267  	b.WriteString(";")
   268  	return b.String(), nil
   269  }
   270  
   271  // SqlRowAsCreateFragStmt Converts a Row into either a CREATE TRIGGER or CREATE VIEW statement
   272  // This function expects a row from the dolt_schemas table
   273  func SqlRowAsCreateFragStmt(r sql.Row) (string, error) {
   274  	var b strings.Builder
   275  
   276  	// If type is view, add DROP VIEW IF EXISTS statement before CREATE VIEW STATEMENT
   277  	typeStr := strings.ToUpper(r[0].(string))
   278  	if typeStr == "VIEW" {
   279  		nameStr := r[1].(string)
   280  		dropStmt := fmt.Sprintf("DROP VIEW IF EXISTS `%s`", nameStr)
   281  		b.WriteString(dropStmt)
   282  		b.WriteString(";\n")
   283  	}
   284  
   285  	// Parse statement to extract definition (and remove any weird whitespace issues)
   286  	defStmt, err := sqlparser.Parse(r[2].(string))
   287  	if err != nil {
   288  		return "", err
   289  	}
   290  
   291  	defStr := sqlparser.String(defStmt)
   292  
   293  	// TODO: this is temporary fix for create statements
   294  	if typeStr == "TRIGGER" {
   295  		nameStr := r[1].(string)
   296  		defStr = fmt.Sprintf("CREATE TRIGGER `%s` %s", nameStr, defStr[len("CREATE TRIGGER ")+len(nameStr)+1:])
   297  	} else {
   298  		defStr = strings.Replace(defStr, "create ", "CREATE ", -1)
   299  		defStr = strings.Replace(defStr, " view ", " VIEW ", -1)
   300  		defStr = strings.Replace(defStr, " as ", " AS ", -1)
   301  	}
   302  
   303  	b.WriteString(defStr)
   304  
   305  	b.WriteString(";")
   306  	return b.String(), nil
   307  }
   308  
   309  func SqlRowAsInsertStmt(r sql.Row, tableName string, tableSch schema.Schema) (string, error) {
   310  	var b strings.Builder
   311  
   312  	// Write insert prefix
   313  	prefix, err := InsertStatementPrefix(tableName, tableSch)
   314  	if err != nil {
   315  		return "", err
   316  	}
   317  	b.WriteString(prefix)
   318  
   319  	// Write single insert
   320  	str, err := SqlRowAsTupleString(r, tableSch)
   321  	if err != nil {
   322  		return "", err
   323  	}
   324  	b.WriteString(str)
   325  
   326  	b.WriteString(";")
   327  	return b.String(), nil
   328  }
   329  
   330  // SqlRowAsTupleString converts a sql row into it's tuple string representation for SQL insert statements.
   331  func SqlRowAsTupleString(r sql.Row, tableSch schema.Schema) (string, error) {
   332  	var b strings.Builder
   333  	var err error
   334  
   335  	b.WriteString("(")
   336  	seenOne := false
   337  	for i, val := range r {
   338  		if seenOne {
   339  			b.WriteRune(',')
   340  		}
   341  		col := tableSch.GetAllCols().GetByIndex(i)
   342  		str := "NULL"
   343  		if val != nil {
   344  			str, err = interfaceValueAsSqlString(col.TypeInfo, val)
   345  			if err != nil {
   346  				return "", err
   347  			}
   348  		}
   349  
   350  		b.WriteString(str)
   351  		seenOne = true
   352  	}
   353  	b.WriteString(")")
   354  
   355  	return b.String(), nil
   356  }
   357  
   358  // SqlRowAsStrings returns the string representation for each column of |r|
   359  // which should have schema |sch|.
   360  func SqlRowAsStrings(r sql.Row, sch sql.Schema) ([]string, error) {
   361  	out := make([]string, len(r))
   362  	for i := range out {
   363  		v := r[i]
   364  		sqlType := sch[i].Type
   365  		s, err := sqlutil.SqlColToStr(sqlType, v)
   366  		if err != nil {
   367  			return nil, err
   368  		}
   369  		out[i] = s
   370  	}
   371  	return out, nil
   372  }
   373  
   374  // SqlRowAsDeleteStmt generates a sql statement. Non-zero |limit| adds a limit clause.
   375  func SqlRowAsDeleteStmt(r sql.Row, tableName string, tableSch schema.Schema, limit uint64) (string, error) {
   376  	var b strings.Builder
   377  	b.WriteString("DELETE FROM ")
   378  	b.WriteString(QuoteIdentifier(tableName))
   379  
   380  	b.WriteString(" WHERE ")
   381  	seenOne := false
   382  	i := 0
   383  	isKeyless := schema.IsKeyless(tableSch)
   384  
   385  	err := tableSch.GetAllCols().Iter(func(_ uint64, col schema.Column) (stop bool, err error) {
   386  		if col.IsPartOfPK || isKeyless {
   387  			if seenOne {
   388  				b.WriteString(" AND ")
   389  			}
   390  			sqlString, err := interfaceValueAsSqlString(col.TypeInfo, r[i])
   391  			if err != nil {
   392  				return true, err
   393  			}
   394  			b.WriteString(QuoteIdentifier(col.Name))
   395  			b.WriteRune('=')
   396  			b.WriteString(sqlString)
   397  			seenOne = true
   398  		}
   399  		i++
   400  		return false, nil
   401  	})
   402  
   403  	if err != nil {
   404  		return "", err
   405  	}
   406  
   407  	if limit != 0 {
   408  		b.WriteString(" LIMIT ")
   409  		s, err := interfaceValueAsSqlString(typeinfo.FromKind(types.UintKind), limit)
   410  		if err != nil {
   411  			return "", err
   412  		}
   413  		b.WriteString(s)
   414  	}
   415  
   416  	b.WriteString(";")
   417  	return b.String(), nil
   418  }
   419  
   420  func SqlRowAsUpdateStmt(r sql.Row, tableName string, tableSch schema.Schema, colsToUpdate *set.StrSet) (string, error) {
   421  	var b strings.Builder
   422  	b.WriteString("UPDATE ")
   423  	b.WriteString(QuoteIdentifier(tableName))
   424  	b.WriteString(" ")
   425  
   426  	b.WriteString("SET ")
   427  
   428  	i := 0
   429  	seenOne := false
   430  	err := tableSch.GetAllCols().Iter(func(_ uint64, col schema.Column) (stop bool, err error) {
   431  		if colsToUpdate.Contains(col.Name) {
   432  			if seenOne {
   433  				b.WriteRune(',')
   434  			}
   435  			seenOne = true
   436  
   437  			sqlString, err := interfaceValueAsSqlString(col.TypeInfo, r[i])
   438  			if err != nil {
   439  				return true, err
   440  			}
   441  			b.WriteString(QuoteIdentifier(col.Name))
   442  			b.WriteRune('=')
   443  			b.WriteString(sqlString)
   444  		}
   445  		i++
   446  		return false, nil
   447  	})
   448  
   449  	if err != nil {
   450  		return "", err
   451  	}
   452  
   453  	b.WriteString(" WHERE ")
   454  
   455  	i = 0
   456  	seenOne = false
   457  	err = tableSch.GetAllCols().Iter(func(_ uint64, col schema.Column) (stop bool, err error) {
   458  		if col.IsPartOfPK {
   459  			if seenOne {
   460  				b.WriteString(" AND ")
   461  			}
   462  			seenOne = true
   463  
   464  			sqlString, err := interfaceValueAsSqlString(col.TypeInfo, r[i])
   465  			if err != nil {
   466  				return true, err
   467  			}
   468  			b.WriteString(QuoteIdentifier(col.Name))
   469  			b.WriteRune('=')
   470  			b.WriteString(sqlString)
   471  		}
   472  		i++
   473  		return false, nil
   474  	})
   475  
   476  	if err != nil {
   477  		return "", err
   478  	}
   479  
   480  	b.WriteString(";")
   481  	return b.String(), nil
   482  }
   483  
   484  func ValueAsSqlString(ti typeinfo.TypeInfo, value types.Value) (string, error) {
   485  	if types.IsNull(value) {
   486  		return "NULL", nil
   487  	}
   488  
   489  	str, err := ti.FormatValue(value)
   490  
   491  	if err != nil {
   492  		return "", err
   493  	}
   494  
   495  	switch ti.GetTypeIdentifier() {
   496  	case typeinfo.BoolTypeIdentifier:
   497  		// todo: unclear if we want this to output with "TRUE/FALSE" or 1/0
   498  		if value.(types.Bool) {
   499  			return "TRUE", nil
   500  		}
   501  		return "FALSE", nil
   502  	case typeinfo.UuidTypeIdentifier, typeinfo.TimeTypeIdentifier, typeinfo.YearTypeIdentifier, typeinfo.DatetimeTypeIdentifier:
   503  		return singleQuote + *str + singleQuote, nil
   504  	case typeinfo.BlobStringTypeIdentifier, typeinfo.VarBinaryTypeIdentifier, typeinfo.InlineBlobTypeIdentifier, typeinfo.JSONTypeIdentifier, typeinfo.EnumTypeIdentifier, typeinfo.SetTypeIdentifier:
   505  		return quoteAndEscapeString(*str), nil
   506  	case typeinfo.VarStringTypeIdentifier:
   507  		s, ok := value.(types.String)
   508  		if !ok {
   509  			return "", fmt.Errorf("typeinfo.VarStringTypeIdentifier is not types.String")
   510  		}
   511  		return quoteAndEscapeString(string(s)), nil
   512  	default:
   513  		return *str, nil
   514  	}
   515  }
   516  
   517  func interfaceValueAsSqlString(ti typeinfo.TypeInfo, value interface{}) (string, error) {
   518  	if value == nil {
   519  		return "NULL", nil
   520  	}
   521  
   522  	str, err := sqlutil.SqlColToStr(ti.ToSqlType(), value)
   523  	if err != nil {
   524  		return "", err
   525  	}
   526  
   527  	switch ti.GetTypeIdentifier() {
   528  	case typeinfo.BoolTypeIdentifier:
   529  		if value.(bool) {
   530  			return "1", nil
   531  		}
   532  		return "0", nil
   533  	case typeinfo.UuidTypeIdentifier, typeinfo.TimeTypeIdentifier, typeinfo.YearTypeIdentifier:
   534  		return singleQuote + str + singleQuote, nil
   535  	case typeinfo.DatetimeTypeIdentifier:
   536  		return singleQuote + str + singleQuote, nil
   537  	case typeinfo.InlineBlobTypeIdentifier, typeinfo.VarBinaryTypeIdentifier:
   538  		switch v := value.(type) {
   539  		case []byte:
   540  			return hexEncodeBytes(v), nil
   541  		case string:
   542  			return hexEncodeBytes([]byte(v)), nil
   543  		default:
   544  			return "", fmt.Errorf("unexpected type for binary value: %T (SQL type info: %v)", value, ti)
   545  		}
   546  	case typeinfo.JSONTypeIdentifier, typeinfo.EnumTypeIdentifier, typeinfo.SetTypeIdentifier, typeinfo.BlobStringTypeIdentifier:
   547  		return quoteAndEscapeString(str), nil
   548  	case typeinfo.VarStringTypeIdentifier:
   549  		s, ok := value.(string)
   550  		if !ok {
   551  			return "", fmt.Errorf("typeinfo.VarStringTypeIdentifier is not types.String")
   552  		}
   553  		return quoteAndEscapeString(s), nil
   554  	case typeinfo.GeometryTypeIdentifier,
   555  		typeinfo.PointTypeIdentifier,
   556  		typeinfo.LineStringTypeIdentifier,
   557  		typeinfo.PolygonTypeIdentifier,
   558  		typeinfo.MultiPointTypeIdentifier,
   559  		typeinfo.MultiLineStringTypeIdentifier,
   560  		typeinfo.MultiPolygonTypeIdentifier,
   561  		typeinfo.GeometryCollectionTypeIdentifier:
   562  		return singleQuote + str + singleQuote, nil
   563  	default:
   564  		return str, nil
   565  	}
   566  }
   567  
   568  func quoteAndEscapeString(s string) string {
   569  	buf := &bytes.Buffer{}
   570  	v, err := sqltypes.NewValue(sqltypes.VarChar, []byte(s))
   571  	if err != nil {
   572  		panic(err)
   573  	}
   574  	v.EncodeSQL(buf)
   575  	return buf.String()
   576  }
   577  
   578  func hexEncodeBytes(bytes []byte) string {
   579  	return "0x" + hex.EncodeToString(bytes)
   580  }