github.com/gocaveman/caveman@v0.0.0-20191211162744-0ddf99dbdf6e/ddl/formatter-mysql.go (about)

     1  package ddl
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"strings"
     7  )
     8  
     9  type MySQLFormatter struct {
    10  	Template bool // set to true to enable template output (supports prefixes)
    11  }
    12  
    13  // NewMySQLFormatter returns a new MySQLFormatter. If the template argument
    14  // is true then table prefixes (and any other templatable features)
    15  // will be output in Go template form, for use with migrations.  Passing false
    16  // will produce raw SQL that can be executed directly.
    17  func NewMySQLFormatter(template bool) *MySQLFormatter {
    18  	return &MySQLFormatter{Template: template}
    19  }
    20  
    21  func (f *MySQLFormatter) tmplPrefix() string {
    22  	if f.Template {
    23  		return "{{.TablePrefix}}"
    24  	}
    25  	return ""
    26  }
    27  
    28  func (f *MySQLFormatter) DriverName() string {
    29  	return "mysql"
    30  }
    31  
    32  func (f *MySQLFormatter) Format(stmt Stmt) ([]string, error) {
    33  
    34  	var buf bytes.Buffer
    35  
    36  	switch st := stmt.(type) {
    37  
    38  	case *CreateTableStmt:
    39  		ifNotExistsStr := ""
    40  		if st.IfNotExistsValue {
    41  			ifNotExistsStr = "IF NOT EXISTS "
    42  		}
    43  		fmt.Fprintf(&buf, `CREATE TABLE %s%s (`+"\n", ifNotExistsStr, mysqlQuoteIdent(f.tmplPrefix()+st.NameValue))
    44  
    45  		for _, col := range st.Columns {
    46  
    47  			colstr, err := mysqlColStr(col)
    48  			if err != nil {
    49  				return nil, err
    50  			}
    51  			fmt.Fprintf(&buf, "    %s,\n", colstr)
    52  		}
    53  
    54  		if len(st.PrimaryKeys) > 0 {
    55  			fmt.Fprintf(&buf, "    PRIMARY KEY(")
    56  			for idx, pk := range st.PrimaryKeys {
    57  				fmt.Fprintf(&buf, "%s", mysqlQuoteIdent(pk))
    58  				if idx < len(st.PrimaryKeys)-1 {
    59  					fmt.Fprintf(&buf, ",")
    60  				}
    61  			}
    62  			fmt.Fprintf(&buf, "),\n")
    63  		}
    64  
    65  		for _, fk := range st.ForeignKeys {
    66  			fmt.Fprintf(&buf, "    FOREIGN KEY(%s) REFERENCES %s(%s),",
    67  				mysqlQuoteIdent(fk.ColumnValue),
    68  				mysqlQuoteIdent(f.tmplPrefix()+fk.OtherTableValue),
    69  				mysqlQuoteIdent(fk.OtherColumnValue),
    70  			)
    71  		}
    72  
    73  		// Use utf8mb4 as the character set for everything not explicitly
    74  		// set on a column. Let the db choose the default collation, since
    75  		// it will use the most recent case-insensitive unicode comparision,
    76  		// which is usually exactly what you want.
    77  		tableSuffixStr := " /*!50508 CHARSET=utf8mb4 */"
    78  
    79  		// remove any trailing comma and close table definition
    80  		fullStr := strings.TrimSuffix(strings.TrimSpace(buf.String()), ",") + "\n)" +
    81  			tableSuffixStr
    82  		return []string{fullStr}, nil
    83  
    84  	case *DropTableStmt:
    85  		fmt.Fprintf(&buf, `DROP TABLE %s`, mysqlQuoteIdent(f.tmplPrefix()+st.NameValue))
    86  		return []string{buf.String()}, nil
    87  
    88  	case *AlterTableRenameStmt:
    89  		fmt.Fprintf(&buf, `ALTER TABLE %s RENAME TO %s`,
    90  			mysqlQuoteIdent(f.tmplPrefix()+st.OldNameValue),
    91  			mysqlQuoteIdent(f.tmplPrefix()+st.NewNameValue),
    92  		)
    93  		return []string{buf.String()}, nil
    94  
    95  	case *AlterTableAddStmt:
    96  		colStr, err := mysqlColStr(&st.DataTypeDef)
    97  		if err != nil {
    98  			return nil, err
    99  		}
   100  		fmt.Fprintf(&buf, `ALTER TABLE %s ADD COLUMN %s`,
   101  			mysqlQuoteIdent(f.tmplPrefix()+st.NameValue),
   102  			colStr,
   103  		)
   104  		return []string{buf.String()}, nil
   105  
   106  	case *CreateIndexStmt:
   107  		uniqueStr := ""
   108  		if st.UniqueValue {
   109  			uniqueStr = " UNIQUE"
   110  		}
   111  		// for, just error on this, rather than screwing around with the MariaDB/MySQL fiasco,
   112  		// the whole point of migrations is to avoid this crap anyway - maybe if not exists
   113  		// functionality should just be removed...
   114  		if st.IfNotExistsValue {
   115  			return nil, fmt.Errorf("CREATE INDEX IF NOT EXISTS is not supported by MySQL")
   116  		}
   117  		ifNotExistsStr := ""
   118  		// if st.IfNotExistsValue {
   119  		// 	ifNotExistsStr = " IF NOT EXISTS"
   120  		// }
   121  		colStr := ""
   122  		for _, colName := range st.ColumnNames {
   123  			colStr += mysqlQuoteIdent(colName) + ","
   124  		}
   125  		colStr = strings.TrimRight(colStr, ",")
   126  		fmt.Fprintf(&buf, `CREATE%s INDEX%s %s ON %s(%s)`,
   127  			uniqueStr,
   128  			ifNotExistsStr,
   129  			mysqlQuoteIdent(f.tmplPrefix()+st.NameValue),
   130  			mysqlQuoteIdent(f.tmplPrefix()+st.TableNameValue),
   131  			colStr,
   132  		)
   133  		return []string{buf.String()}, nil
   134  
   135  	case *DropIndexStmt:
   136  		fmt.Fprintf(&buf, `DROP INDEX %s ON %s`,
   137  			mysqlQuoteIdent(f.tmplPrefix()+st.NameValue),
   138  			mysqlQuoteIdent(f.tmplPrefix()+st.TableNameValue),
   139  		)
   140  		return []string{buf.String()}, nil
   141  
   142  	}
   143  
   144  	return nil, fmt.Errorf("unknown statement type %T", stmt)
   145  }
   146  
   147  func mysqlQuoteIdent(ident string) string {
   148  	return quoteIdent(ident, "`")
   149  }
   150  
   151  // https://dev.mysql.com/doc/refman/8.0/en/string-literals.html
   152  var mysqlEncodeStringReplacer = strings.NewReplacer(`'`, `\'`, `\`, `\\`)
   153  
   154  func mysqlEncodeString(s string) string {
   155  	return `'` + mysqlEncodeStringReplacer.Replace(s) + `'`
   156  }
   157  
   158  func mysqlColStr(col *DataTypeDef) (string, error) {
   159  
   160  	defaultStr := ""
   161  	if col.DefaultValue != nil {
   162  		if s, ok := col.DefaultValue.(string); ok {
   163  			defaultStr = fmt.Sprintf(" DEFAULT %s", mysqlEncodeString(s))
   164  		} else {
   165  			// FIXME: we should be more careful about what escaping and formatting is used here
   166  			// and the various possible data types
   167  			defaultStr = fmt.Sprintf(" DEFAULT %v", col.DefaultValue)
   168  		}
   169  	}
   170  	nullStr := " NOT NULL"
   171  	if col.NullValue {
   172  		nullStr = " NULL"
   173  	}
   174  	caseSensitiveStr := "" // by default it will be case insensitive, no need to specify anything
   175  	if col.CaseSensitiveValue {
   176  		caseSensitiveStr = " /*!50508 COLLATE utf8mb4_bin */" // case sensitive is done by changing collation to binary
   177  	}
   178  
   179  	// pk/fk columns are different, make them ascii and case sensitive
   180  	keyColSuffix := " CHARACTER SET ascii COLLATE ascii_bin"
   181  
   182  	switch col.DataTypeValue {
   183  	case Custom:
   184  		return fmt.Sprintf("%s %s", mysqlQuoteIdent(col.NameValue), col.CustomSQLValue), nil
   185  	case VarCharPK:
   186  		// always case sensitive
   187  		return fmt.Sprintf("%s VARCHAR(64)%s%s%s", mysqlQuoteIdent(col.NameValue), keyColSuffix, nullStr, defaultStr), nil
   188  	case BigIntAutoPK:
   189  		return fmt.Sprintf("%s BIGINT NOT NULL AUTO_INCREMENT", mysqlQuoteIdent(col.NameValue)), nil
   190  	case VarCharFK:
   191  		// always case sensitive
   192  		return fmt.Sprintf("%s VARCHAR(64)%s%s%s", mysqlQuoteIdent(col.NameValue), keyColSuffix, nullStr, defaultStr), nil
   193  	case BigIntFK:
   194  		return fmt.Sprintf("%s BIGINT%s%s", mysqlQuoteIdent(col.NameValue), nullStr, defaultStr), nil
   195  	case Int:
   196  		return fmt.Sprintf("%s INTEGER%s%s", mysqlQuoteIdent(col.NameValue), nullStr, defaultStr), nil
   197  	case IntU:
   198  		return fmt.Sprintf("%s INTEGER UNSIGNED %s%s", mysqlQuoteIdent(col.NameValue), nullStr, defaultStr), nil
   199  	case BigInt:
   200  		return fmt.Sprintf("%s BIGINT%s%s", mysqlQuoteIdent(col.NameValue), nullStr, defaultStr), nil
   201  	case BigIntU:
   202  		return fmt.Sprintf("%s BIGINT UNSIGNED%s%s", mysqlQuoteIdent(col.NameValue), nullStr, defaultStr), nil
   203  	case Double:
   204  		return fmt.Sprintf("%s REAL%s%s", mysqlQuoteIdent(col.NameValue), nullStr, defaultStr), nil
   205  	case DateTime:
   206  		// use native DATETIME type and add the extra sub-second precision if supported
   207  		return fmt.Sprintf("%s DATETIME/*!50604 (6) */%s%s", mysqlQuoteIdent(col.NameValue), nullStr, defaultStr), nil
   208  	case VarChar:
   209  		// default string length - big enough to handle most human-entered values, but
   210  		// small enough to fit under the InnoDB 767 index max byte length (i.e. if you
   211  		// set this to 255 with utf8mb4 and try to index it the index creation will fail,
   212  		// it's pain in the ass - 128 is a better default)
   213  		lengthStr := "(128)"
   214  		if col.LengthValue > 0 {
   215  			lengthStr = fmt.Sprintf("(%d)", col.LengthValue)
   216  		}
   217  		return fmt.Sprintf("%s VARCHAR%s%s%s%s", mysqlQuoteIdent(col.NameValue), lengthStr, nullStr, caseSensitiveStr, defaultStr), nil
   218  	case Text:
   219  		lengthStr := "" // no explicit length unless specified
   220  		if col.LengthValue > 0 {
   221  			lengthStr = fmt.Sprintf("(%d)", col.LengthValue)
   222  		}
   223  		return fmt.Sprintf("%s TEXT%s%s%s%s", mysqlQuoteIdent(col.NameValue), lengthStr, nullStr, caseSensitiveStr, defaultStr), nil
   224  	case Bool:
   225  		return fmt.Sprintf("%s BOOLEAN%s%s", mysqlQuoteIdent(col.NameValue), nullStr, defaultStr), nil
   226  	case Blob:
   227  		lengthStr := "" // no explicit length unless specified
   228  		if col.LengthValue > 0 {
   229  			lengthStr = fmt.Sprintf("(%d)", col.LengthValue)
   230  		}
   231  		return fmt.Sprintf("%s BLOB%s%s%s", mysqlQuoteIdent(col.NameValue), lengthStr, nullStr, defaultStr), nil
   232  
   233  	}
   234  
   235  	return "", fmt.Errorf("unknown DataType: %v", col.DataTypeValue)
   236  }