github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/sqlfmt/row_fmt.go (about) 1 // Copyright 2020 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package sqlfmt 16 17 import ( 18 "bytes" 19 "encoding/hex" 20 "fmt" 21 "strings" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/vitess/go/sqltypes" 25 "github.com/dolthub/vitess/go/vt/sqlparser" 26 27 "github.com/dolthub/dolt/go/libraries/doltcore/row" 28 "github.com/dolthub/dolt/go/libraries/doltcore/schema" 29 "github.com/dolthub/dolt/go/libraries/doltcore/schema/typeinfo" 30 "github.com/dolthub/dolt/go/libraries/doltcore/sqle/sqlutil" 31 "github.com/dolthub/dolt/go/libraries/utils/set" 32 "github.com/dolthub/dolt/go/store/types" 33 ) 34 35 const singleQuote = `'` 36 37 // Quotes the identifier given with backticks. 38 func QuoteIdentifier(s string) string { 39 return "`" + s + "`" 40 } 41 42 // QuoteComment quotes the given string with apostrophes, and escapes any contained within the string. 43 func QuoteComment(s string) string { 44 return `'` + strings.ReplaceAll(s, `'`, `\'`) + `'` 45 } 46 47 func RowAsInsertStmt(r row.Row, tableName string, tableSch schema.Schema) (string, error) { 48 var b strings.Builder 49 b.WriteString("INSERT INTO ") 50 b.WriteString(QuoteIdentifier(tableName)) 51 b.WriteString(" ") 52 53 b.WriteString("(") 54 seenOne := false 55 err := tableSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { 56 if seenOne { 57 b.WriteRune(',') 58 } 59 b.WriteString(QuoteIdentifier(col.Name)) 60 seenOne = true 61 return false, nil 62 }) 63 64 if err != nil { 65 return "", err 66 } 67 68 b.WriteString(")") 69 70 b.WriteString(" VALUES (") 71 seenOne = false 72 _, err = r.IterSchema(tableSch, func(tag uint64, val types.Value) (stop bool, err error) { 73 if seenOne { 74 b.WriteRune(',') 75 } 76 col, _ := tableSch.GetAllCols().GetByTag(tag) 77 sqlString, err := ValueAsSqlString(col.TypeInfo, val) 78 if err != nil { 79 return true, err 80 } 81 b.WriteString(sqlString) 82 seenOne = true 83 return false, nil 84 }) 85 86 if err != nil { 87 return "", err 88 } 89 90 b.WriteString(");") 91 92 return b.String(), nil 93 } 94 95 func RowAsDeleteStmt(r row.Row, tableName string, tableSch schema.Schema) (string, error) { 96 var b strings.Builder 97 b.WriteString("DELETE FROM ") 98 b.WriteString(QuoteIdentifier(tableName)) 99 100 b.WriteString(" WHERE (") 101 seenOne := false 102 isKeyless := tableSch.GetPKCols().Size() == 0 103 _, err := r.IterSchema(tableSch, func(tag uint64, val types.Value) (stop bool, err error) { 104 col, _ := tableSch.GetAllCols().GetByTag(tag) 105 if col.IsPartOfPK || isKeyless { 106 if seenOne { 107 b.WriteString(" AND ") 108 } 109 sqlString, err := ValueAsSqlString(col.TypeInfo, val) 110 if err != nil { 111 return true, err 112 } 113 b.WriteString(QuoteIdentifier(col.Name)) 114 b.WriteRune('=') 115 b.WriteString(sqlString) 116 seenOne = true 117 } 118 return false, nil 119 }) 120 121 if err != nil { 122 return "", err 123 } 124 125 b.WriteString(");") 126 return b.String(), nil 127 } 128 129 func RowAsUpdateStmt(r row.Row, tableName string, tableSch schema.Schema, colsToUpdate *set.StrSet) (string, error) { 130 var b strings.Builder 131 b.WriteString("UPDATE ") 132 b.WriteString(QuoteIdentifier(tableName)) 133 b.WriteString(" ") 134 135 b.WriteString("SET ") 136 seenOne := false 137 _, err := r.IterSchema(tableSch, func(tag uint64, val types.Value) (stop bool, err error) { 138 col, _ := tableSch.GetAllCols().GetByTag(tag) 139 exists := colsToUpdate.Contains(col.Name) 140 if !col.IsPartOfPK && exists { 141 if seenOne { 142 b.WriteRune(',') 143 } 144 sqlString, err := ValueAsSqlString(col.TypeInfo, val) 145 if err != nil { 146 return true, err 147 } 148 b.WriteString(QuoteIdentifier(col.Name)) 149 b.WriteRune('=') 150 b.WriteString(sqlString) 151 seenOne = true 152 } 153 return false, nil 154 }) 155 156 if err != nil { 157 return "", err 158 } 159 160 b.WriteString(" WHERE (") 161 seenOne = false 162 _, err = r.IterSchema(tableSch, func(tag uint64, val types.Value) (stop bool, err error) { 163 col, _ := tableSch.GetAllCols().GetByTag(tag) 164 if col.IsPartOfPK { 165 if seenOne { 166 b.WriteString(" AND ") 167 } 168 sqlString, err := ValueAsSqlString(col.TypeInfo, val) 169 if err != nil { 170 return true, err 171 } 172 b.WriteString(QuoteIdentifier(col.Name)) 173 b.WriteRune('=') 174 b.WriteString(sqlString) 175 seenOne = true 176 } 177 return false, nil 178 }) 179 180 if err != nil { 181 return "", err 182 } 183 184 b.WriteString(");") 185 return b.String(), nil 186 } 187 188 // RowAsTupleString converts a row into it's tuple string representation for SQL insert statements. 189 func RowAsTupleString(r row.Row, tableSch schema.Schema) (string, error) { 190 var b strings.Builder 191 192 b.WriteString("(") 193 seenOne := false 194 _, err := r.IterSchema(tableSch, func(tag uint64, val types.Value) (stop bool, err error) { 195 if seenOne { 196 b.WriteRune(',') 197 } 198 col, _ := tableSch.GetAllCols().GetByTag(tag) 199 sqlString, err := ValueAsSqlString(col.TypeInfo, val) 200 if err != nil { 201 return true, err 202 } 203 204 b.WriteString(sqlString) 205 seenOne = true 206 return false, err 207 }) 208 209 if err != nil { 210 return "", err 211 } 212 213 b.WriteString(")") 214 215 return b.String(), nil 216 } 217 218 // InsertStatementPrefix returns the first part of an SQL insert query for a given table 219 func InsertStatementPrefix(tableName string, tableSch schema.Schema) (string, error) { 220 var b strings.Builder 221 222 b.WriteString("INSERT INTO ") 223 b.WriteString(QuoteIdentifier(tableName)) 224 b.WriteString(" (") 225 226 seenOne := false 227 err := tableSch.GetAllCols().Iter(func(tag uint64, col schema.Column) (stop bool, err error) { 228 if seenOne { 229 b.WriteRune(',') 230 } 231 b.WriteString(QuoteIdentifier(col.Name)) 232 seenOne = true 233 return false, nil 234 }) 235 236 if err != nil { 237 return "", err 238 } 239 240 b.WriteString(") VALUES ") 241 return b.String(), nil 242 } 243 244 // SqlRowAsCreateProcStmt Converts a Row into either a CREATE PROCEDURE statement 245 // This function expects a row from the dolt_procedures table. 246 func SqlRowAsCreateProcStmt(r sql.Row) (string, error) { 247 var b strings.Builder 248 249 // Write create procedure 250 prefix := "CREATE PROCEDURE " 251 b.WriteString(prefix) 252 253 // Write procedure name 254 nameStr := r[0].(string) 255 b.WriteString(QuoteIdentifier(nameStr)) 256 b.WriteString(" ") // add a space 257 258 // Write definition 259 defStmt, err := sqlparser.Parse(r[1].(string)) 260 if err != nil { 261 return "", err 262 } 263 defStr := sqlparser.String(defStmt) 264 defStr = defStr[len(prefix)+len(nameStr)+1:] 265 b.WriteString(defStr) 266 267 b.WriteString(";") 268 return b.String(), nil 269 } 270 271 // SqlRowAsCreateFragStmt Converts a Row into either a CREATE TRIGGER or CREATE VIEW statement 272 // This function expects a row from the dolt_schemas table 273 func SqlRowAsCreateFragStmt(r sql.Row) (string, error) { 274 var b strings.Builder 275 276 // If type is view, add DROP VIEW IF EXISTS statement before CREATE VIEW STATEMENT 277 typeStr := strings.ToUpper(r[0].(string)) 278 if typeStr == "VIEW" { 279 nameStr := r[1].(string) 280 dropStmt := fmt.Sprintf("DROP VIEW IF EXISTS `%s`", nameStr) 281 b.WriteString(dropStmt) 282 b.WriteString(";\n") 283 } 284 285 // Parse statement to extract definition (and remove any weird whitespace issues) 286 defStmt, err := sqlparser.Parse(r[2].(string)) 287 if err != nil { 288 return "", err 289 } 290 291 defStr := sqlparser.String(defStmt) 292 293 // TODO: this is temporary fix for create statements 294 if typeStr == "TRIGGER" { 295 nameStr := r[1].(string) 296 defStr = fmt.Sprintf("CREATE TRIGGER `%s` %s", nameStr, defStr[len("CREATE TRIGGER ")+len(nameStr)+1:]) 297 } else { 298 defStr = strings.Replace(defStr, "create ", "CREATE ", -1) 299 defStr = strings.Replace(defStr, " view ", " VIEW ", -1) 300 defStr = strings.Replace(defStr, " as ", " AS ", -1) 301 } 302 303 b.WriteString(defStr) 304 305 b.WriteString(";") 306 return b.String(), nil 307 } 308 309 func SqlRowAsInsertStmt(r sql.Row, tableName string, tableSch schema.Schema) (string, error) { 310 var b strings.Builder 311 312 // Write insert prefix 313 prefix, err := InsertStatementPrefix(tableName, tableSch) 314 if err != nil { 315 return "", err 316 } 317 b.WriteString(prefix) 318 319 // Write single insert 320 str, err := SqlRowAsTupleString(r, tableSch) 321 if err != nil { 322 return "", err 323 } 324 b.WriteString(str) 325 326 b.WriteString(";") 327 return b.String(), nil 328 } 329 330 // SqlRowAsTupleString converts a sql row into it's tuple string representation for SQL insert statements. 331 func SqlRowAsTupleString(r sql.Row, tableSch schema.Schema) (string, error) { 332 var b strings.Builder 333 var err error 334 335 b.WriteString("(") 336 seenOne := false 337 for i, val := range r { 338 if seenOne { 339 b.WriteRune(',') 340 } 341 col := tableSch.GetAllCols().GetByIndex(i) 342 str := "NULL" 343 if val != nil { 344 str, err = interfaceValueAsSqlString(col.TypeInfo, val) 345 if err != nil { 346 return "", err 347 } 348 } 349 350 b.WriteString(str) 351 seenOne = true 352 } 353 b.WriteString(")") 354 355 return b.String(), nil 356 } 357 358 // SqlRowAsStrings returns the string representation for each column of |r| 359 // which should have schema |sch|. 360 func SqlRowAsStrings(r sql.Row, sch sql.Schema) ([]string, error) { 361 out := make([]string, len(r)) 362 for i := range out { 363 v := r[i] 364 sqlType := sch[i].Type 365 s, err := sqlutil.SqlColToStr(sqlType, v) 366 if err != nil { 367 return nil, err 368 } 369 out[i] = s 370 } 371 return out, nil 372 } 373 374 // SqlRowAsDeleteStmt generates a sql statement. Non-zero |limit| adds a limit clause. 375 func SqlRowAsDeleteStmt(r sql.Row, tableName string, tableSch schema.Schema, limit uint64) (string, error) { 376 var b strings.Builder 377 b.WriteString("DELETE FROM ") 378 b.WriteString(QuoteIdentifier(tableName)) 379 380 b.WriteString(" WHERE ") 381 seenOne := false 382 i := 0 383 isKeyless := schema.IsKeyless(tableSch) 384 385 err := tableSch.GetAllCols().Iter(func(_ uint64, col schema.Column) (stop bool, err error) { 386 if col.IsPartOfPK || isKeyless { 387 if seenOne { 388 b.WriteString(" AND ") 389 } 390 sqlString, err := interfaceValueAsSqlString(col.TypeInfo, r[i]) 391 if err != nil { 392 return true, err 393 } 394 b.WriteString(QuoteIdentifier(col.Name)) 395 b.WriteRune('=') 396 b.WriteString(sqlString) 397 seenOne = true 398 } 399 i++ 400 return false, nil 401 }) 402 403 if err != nil { 404 return "", err 405 } 406 407 if limit != 0 { 408 b.WriteString(" LIMIT ") 409 s, err := interfaceValueAsSqlString(typeinfo.FromKind(types.UintKind), limit) 410 if err != nil { 411 return "", err 412 } 413 b.WriteString(s) 414 } 415 416 b.WriteString(";") 417 return b.String(), nil 418 } 419 420 func SqlRowAsUpdateStmt(r sql.Row, tableName string, tableSch schema.Schema, colsToUpdate *set.StrSet) (string, error) { 421 var b strings.Builder 422 b.WriteString("UPDATE ") 423 b.WriteString(QuoteIdentifier(tableName)) 424 b.WriteString(" ") 425 426 b.WriteString("SET ") 427 428 i := 0 429 seenOne := false 430 err := tableSch.GetAllCols().Iter(func(_ uint64, col schema.Column) (stop bool, err error) { 431 if colsToUpdate.Contains(col.Name) { 432 if seenOne { 433 b.WriteRune(',') 434 } 435 seenOne = true 436 437 sqlString, err := interfaceValueAsSqlString(col.TypeInfo, r[i]) 438 if err != nil { 439 return true, err 440 } 441 b.WriteString(QuoteIdentifier(col.Name)) 442 b.WriteRune('=') 443 b.WriteString(sqlString) 444 } 445 i++ 446 return false, nil 447 }) 448 449 if err != nil { 450 return "", err 451 } 452 453 b.WriteString(" WHERE ") 454 455 i = 0 456 seenOne = false 457 err = tableSch.GetAllCols().Iter(func(_ uint64, col schema.Column) (stop bool, err error) { 458 if col.IsPartOfPK { 459 if seenOne { 460 b.WriteString(" AND ") 461 } 462 seenOne = true 463 464 sqlString, err := interfaceValueAsSqlString(col.TypeInfo, r[i]) 465 if err != nil { 466 return true, err 467 } 468 b.WriteString(QuoteIdentifier(col.Name)) 469 b.WriteRune('=') 470 b.WriteString(sqlString) 471 } 472 i++ 473 return false, nil 474 }) 475 476 if err != nil { 477 return "", err 478 } 479 480 b.WriteString(";") 481 return b.String(), nil 482 } 483 484 func ValueAsSqlString(ti typeinfo.TypeInfo, value types.Value) (string, error) { 485 if types.IsNull(value) { 486 return "NULL", nil 487 } 488 489 str, err := ti.FormatValue(value) 490 491 if err != nil { 492 return "", err 493 } 494 495 switch ti.GetTypeIdentifier() { 496 case typeinfo.BoolTypeIdentifier: 497 // todo: unclear if we want this to output with "TRUE/FALSE" or 1/0 498 if value.(types.Bool) { 499 return "TRUE", nil 500 } 501 return "FALSE", nil 502 case typeinfo.UuidTypeIdentifier, typeinfo.TimeTypeIdentifier, typeinfo.YearTypeIdentifier, typeinfo.DatetimeTypeIdentifier: 503 return singleQuote + *str + singleQuote, nil 504 case typeinfo.BlobStringTypeIdentifier, typeinfo.VarBinaryTypeIdentifier, typeinfo.InlineBlobTypeIdentifier, typeinfo.JSONTypeIdentifier, typeinfo.EnumTypeIdentifier, typeinfo.SetTypeIdentifier: 505 return quoteAndEscapeString(*str), nil 506 case typeinfo.VarStringTypeIdentifier: 507 s, ok := value.(types.String) 508 if !ok { 509 return "", fmt.Errorf("typeinfo.VarStringTypeIdentifier is not types.String") 510 } 511 return quoteAndEscapeString(string(s)), nil 512 default: 513 return *str, nil 514 } 515 } 516 517 func interfaceValueAsSqlString(ti typeinfo.TypeInfo, value interface{}) (string, error) { 518 if value == nil { 519 return "NULL", nil 520 } 521 522 str, err := sqlutil.SqlColToStr(ti.ToSqlType(), value) 523 if err != nil { 524 return "", err 525 } 526 527 switch ti.GetTypeIdentifier() { 528 case typeinfo.BoolTypeIdentifier: 529 if value.(bool) { 530 return "1", nil 531 } 532 return "0", nil 533 case typeinfo.UuidTypeIdentifier, typeinfo.TimeTypeIdentifier, typeinfo.YearTypeIdentifier: 534 return singleQuote + str + singleQuote, nil 535 case typeinfo.DatetimeTypeIdentifier: 536 return singleQuote + str + singleQuote, nil 537 case typeinfo.InlineBlobTypeIdentifier, typeinfo.VarBinaryTypeIdentifier: 538 switch v := value.(type) { 539 case []byte: 540 return hexEncodeBytes(v), nil 541 case string: 542 return hexEncodeBytes([]byte(v)), nil 543 default: 544 return "", fmt.Errorf("unexpected type for binary value: %T (SQL type info: %v)", value, ti) 545 } 546 case typeinfo.JSONTypeIdentifier, typeinfo.EnumTypeIdentifier, typeinfo.SetTypeIdentifier, typeinfo.BlobStringTypeIdentifier: 547 return quoteAndEscapeString(str), nil 548 case typeinfo.VarStringTypeIdentifier: 549 s, ok := value.(string) 550 if !ok { 551 return "", fmt.Errorf("typeinfo.VarStringTypeIdentifier is not types.String") 552 } 553 return quoteAndEscapeString(s), nil 554 case typeinfo.GeometryTypeIdentifier, 555 typeinfo.PointTypeIdentifier, 556 typeinfo.LineStringTypeIdentifier, 557 typeinfo.PolygonTypeIdentifier, 558 typeinfo.MultiPointTypeIdentifier, 559 typeinfo.MultiLineStringTypeIdentifier, 560 typeinfo.MultiPolygonTypeIdentifier, 561 typeinfo.GeometryCollectionTypeIdentifier: 562 return singleQuote + str + singleQuote, nil 563 default: 564 return str, nil 565 } 566 } 567 568 func quoteAndEscapeString(s string) string { 569 buf := &bytes.Buffer{} 570 v, err := sqltypes.NewValue(sqltypes.VarChar, []byte(s)) 571 if err != nil { 572 panic(err) 573 } 574 v.EncodeSQL(buf) 575 return buf.String() 576 } 577 578 func hexEncodeBytes(bytes []byte) string { 579 return "0x" + hex.EncodeToString(bytes) 580 }