github.com/dolthub/go-mysql-server@v0.18.0/sql/sqlfmt.go (about) 1 // Copyright 2023 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package sql 16 17 import ( 18 "fmt" 19 "strings" 20 ) 21 22 // All functions here are used together to generate 'CREATE TABLE' statement. Each function takes what it requires 23 // to build the definition, which are mostly exact names or values (e.g. columns, indexes names, types, etc.) 24 // These functions allow creating the compatible 'CREATE TABLE' statement from both GMS and Dolt, which use different 25 // implementations of schema, column and other objects. 
26 27 // GenerateCreateTableStatement returns 'CREATE TABLE' statement with given table names 28 // and column definition statements in order and the collation and character set names for the table 29 func GenerateCreateTableStatement(tblName string, colStmts []string, tblCharsetName, tblCollName string, comment string) string { 30 if comment != "" { 31 // Escape any single quotes in the comment and add the COMMENT keyword 32 comment = strings.ReplaceAll(comment, "'", "''") 33 comment = fmt.Sprintf(" COMMENT='%s'", comment) 34 } 35 36 return fmt.Sprintf( 37 "CREATE TABLE %s (\n%s\n) ENGINE=InnoDB DEFAULT CHARSET=%s COLLATE=%s%s", 38 QuoteIdentifier(tblName), 39 strings.Join(colStmts, ",\n"), 40 tblCharsetName, 41 tblCollName, 42 comment, 43 ) 44 } 45 46 // GenerateCreateTableColumnDefinition returns column definition string for 'CREATE TABLE' statement for given column. 47 // This part comes first in the 'CREATE TABLE' statement. 48 func GenerateCreateTableColumnDefinition(col *Column, colDefault, onUpdate string, tableCollation CollationID) string { 49 var colTypeString string 50 if collationType, ok := col.Type.(TypeWithCollation); ok { 51 colTypeString = collationType.StringWithTableCollation(tableCollation) 52 } else { 53 colTypeString = col.Type.String() 54 } 55 stmt := fmt.Sprintf(" %s %s", QuoteIdentifier(col.Name), colTypeString) 56 if !col.Nullable { 57 stmt = fmt.Sprintf("%s NOT NULL", stmt) 58 } 59 60 if col.AutoIncrement { 61 stmt = fmt.Sprintf("%s AUTO_INCREMENT", stmt) 62 } 63 64 if c, ok := col.Type.(SpatialColumnType); ok { 65 if s, d := c.GetSpatialTypeSRID(); d { 66 stmt = fmt.Sprintf("%s /*!80003 SRID %v */", stmt, s) 67 } 68 } 69 70 if col.Generated != nil { 71 storedStr := "" 72 if !col.Virtual { 73 storedStr = " STORED" 74 } 75 stmt = fmt.Sprintf("%s GENERATED ALWAYS AS %s%s", stmt, col.Generated.String(), storedStr) 76 } 77 78 if col.Default != nil && col.Generated == nil { 79 stmt = fmt.Sprintf("%s DEFAULT %s", stmt, colDefault) 80 } 81 82 
if col.OnUpdate != nil { 83 stmt = fmt.Sprintf("%s ON UPDATE %s", stmt, onUpdate) 84 } 85 86 if col.Comment != "" { 87 stmt = fmt.Sprintf("%s COMMENT '%s'", stmt, col.Comment) 88 } 89 return stmt 90 } 91 92 // GenerateCreateTablePrimaryKeyDefinition returns primary key definition string for 'CREATE TABLE' statement 93 // for given column(s). This part comes after each column definitions. 94 func GenerateCreateTablePrimaryKeyDefinition(pkCols []string) string { 95 return fmt.Sprintf(" PRIMARY KEY (%s)", strings.Join(QuoteIdentifiers(pkCols), ",")) 96 } 97 98 // GenerateCreateTableIndexDefinition returns index definition string for 'CREATE TABLE' statement 99 // for given index. This part comes after primary key definition if there is any. 100 func GenerateCreateTableIndexDefinition(isUnique, isSpatial, isFullText bool, indexID string, indexCols []string, comment string) string { 101 unique := "" 102 if isUnique { 103 unique = "UNIQUE " 104 } 105 106 spatial := "" 107 if isSpatial { 108 unique = "SPATIAL " 109 } 110 111 fulltext := "" 112 if isFullText { 113 fulltext = "FULLTEXT " 114 } 115 key := fmt.Sprintf(" %s%s%sKEY %s (%s)", unique, spatial, fulltext, QuoteIdentifier(indexID), strings.Join(indexCols, ",")) 116 if comment != "" { 117 key = fmt.Sprintf("%s COMMENT '%s'", key, comment) 118 } 119 return key 120 } 121 122 // GenerateCreateTableForiegnKeyDefinition returns foreign key constraint definition string for 'CREATE TABLE' statement 123 // for given foreign key. This part comes after index definitions if there are any. 
func GenerateCreateTableForiegnKeyDefinition(fkName string, fkCols []string, parentTbl string, parentCols []string, onDelete, onUpdate string) string {
	var sb strings.Builder
	sb.WriteString("  CONSTRAINT ")
	sb.WriteString(QuoteIdentifier(fkName))
	sb.WriteString(" FOREIGN KEY (")
	sb.WriteString(strings.Join(QuoteIdentifiers(fkCols), ","))
	sb.WriteString(") REFERENCES ")
	sb.WriteString(QuoteIdentifier(parentTbl))
	sb.WriteString(" (")
	sb.WriteString(strings.Join(QuoteIdentifiers(parentCols), ","))
	sb.WriteString(")")

	// Referential actions are optional; each is written verbatim only when provided.
	if onDelete != "" {
		sb.WriteString(" ON DELETE ")
		sb.WriteString(onDelete)
	}
	if onUpdate != "" {
		sb.WriteString(" ON UPDATE ")
		sb.WriteString(onUpdate)
	}
	return sb.String()
}

// GenerateCreateTableCheckConstraintClause returns the check constraint clause string for a 'CREATE TABLE'
// statement for the given check constraint. The constraint name is backtick-quoted and checkExpr is inserted
// verbatim. An unenforced constraint gets a versioned "NOT ENFORCED" comment appended. This part comes last, after
// the foreign key definitions, if there are any.
func GenerateCreateTableCheckConstraintClause(checkName, checkExpr string, enforced bool) string {
	enforcement := ""
	if !enforced {
		enforcement = " /*!80016 NOT ENFORCED */"
	}
	return fmt.Sprintf("  CONSTRAINT %s CHECK (%s)%s", QuoteIdentifier(checkName), checkExpr, enforcement)
}

// QuoteIdentifier returns id wrapped in backticks. Every backtick already inside id is escaped by doubling it.
func QuoteIdentifier(id string) string {
	return "`" + strings.ReplaceAll(id, "`", "``") + "`"
}

// QuoteIdentifiers applies QuoteIdentifier to each element of ids. It returns a new slice of the same length
// (non-nil, even for nil input).
func QuoteIdentifiers(ids []string) []string {
	out := make([]string, 0, len(ids))
	for _, id := range ids {
		out = append(out, QuoteIdentifier(id))
	}
	return out
}