go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/db/dbgen/table_from.go (about) 1 /* 2 3 Copyright (c) 2023 - Present. Will Charczuk. All rights reserved. 4 Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository. 5 6 */ 7 8 package dbgen 9 10 import ( 11 "bytes" 12 "fmt" 13 "reflect" 14 "strings" 15 "text/template" 16 17 "go.charczuk.com/sdk/db" 18 "go.charczuk.com/sdk/db/migration" 19 ) 20 21 // TableFrom returns a migration step for a given type to initialize 22 // a database with a table for that type. 23 // 24 // Note it does _not_ cover extra stuff like indices and constraints 25 // and other table features; for those you'll want to generate those with 26 // dedicated helpers and pass them in the `extra ...string` varaidic argument. 27 func TableFrom(obj any, extra ...string) *migration.Step { 28 tableName := db.TableName(obj) 29 30 var columns []string 31 var primaryKeys []string 32 var columnDefinition string 33 var leadingComma string 34 for index, column := range db.TypeMetaFor(obj).Columns() { 35 if index > 0 { 36 leadingComma = ", " 37 } 38 39 if column.IsPrimaryKey { 40 primaryKeys = append(primaryKeys, column.ColumnName) 41 if column.IsAuto { 42 columnDefinition = fmt.Sprintf("%s%s %s NOT NULL DEFAULT %s", leadingComma, column.ColumnName, dbTypeForFieldType(column.FieldType), dbAutoForFieldType(column.FieldType)) 43 } else { 44 columnDefinition = fmt.Sprintf("%s%s %s NOT NULL", leadingComma, column.ColumnName, dbTypeForFieldType(column.FieldType)) 45 } 46 } else if column.IsAuto { 47 columnDefinition = fmt.Sprintf("%s%s %s NOT NULL DEFAULT %s", leadingComma, column.ColumnName, dbTypeForFieldType(column.FieldType), dbAutoForFieldType(column.FieldType)) 48 } else if column.IsJSON { 49 columnDefinition = fmt.Sprintf("%s%s %s", leadingComma, column.ColumnName, "JSONB") 50 } else { 51 if column.FieldType.Kind() == reflect.Ptr { 52 columnDefinition = fmt.Sprintf("%s%s %s", leadingComma, column.ColumnName, dbTypeForFieldType(column.FieldType)) 53 } else { 54 columnDefinition = fmt.Sprintf("%s%s %s NOT NULL", leadingComma, column.ColumnName, dbTypeForFieldType(column.FieldType)) 55 } 56 } 57 columns = append(columns, columnDefinition) 58 } 59 60 var statements []string 61 statements = append(statements, f(`CREATE TABLE {{ .TableName }} ( 62 {{- range $index, $column := .Columns }} 63 {{ $column }} 64 {{- end }} 65 )`, v{"TableName": tableName, "Columns": columns}), 66 ) 67 68 if len(primaryKeys) > 0 { 69 statements = append(statements, f(`ALTER TABLE {{ .TableName }} ADD CONSTRAINT {{ .ConstraintName }} PRIMARY KEY ({{.Columns}})`, v{ 70 "TableName": tableName, 71 "ConstraintName": fmt.Sprintf("pk_%s_%s", tableName, strings.Join(primaryKeys, "_")), 72 "Columns": strings.Join(primaryKeys, ","), 73 })) 74 } 75 return migration.NewStep( 76 migration.TableNotExists(tableName), 77 migration.Statements( 78 append(statements, extra...)..., 79 ), 80 ) 81 } 82 83 func dbTypeForFieldType(t reflect.Type) string { 84 for t.Kind() == reflect.Ptr { 85 t = t.Elem() 86 } 87 switch t.Kind() { 88 case reflect.Bool: 89 return "BOOLEAN" 90 case reflect.String: 91 return "TEXT" 92 case reflect.Int32: 93 return "INT" 94 case reflect.Int, reflect.Int64: 95 return "BIGINT" 96 case reflect.Uint32: 97 return "INT" 98 case reflect.Uint, reflect.Uint64: 99 return "BIGINT" 100 case reflect.Float32: 101 return "REAL" 102 case reflect.Float64: 103 return "DOUBLE PRECISION" 104 default: 105 } 106 switch t.Name() { 107 case "Time": 108 if t.PkgPath() == "time" { 109 return "TIMESTAMP" 110 } 111 case "Duration": 112 if t.PkgPath() == "time" { 113 return "INTERVAL" 114 } 115 case "UUID": 116 if t.PkgPath() == "go.charczuk.com/sdk/uuid" { 117 return "UUID" 118 } 119 } 120 panic(fmt.Sprintf("unknown field type for db: %s/%s", t.PkgPath(), t.Name())) 121 } 122 123 func dbAutoForFieldType(t reflect.Type) string { 124 for t.Kind() == reflect.Ptr { 125 t = t.Elem() 126 } 127 switch t.Name() { 128 case "Time": 129 if t.PkgPath() == "time" { 130 return "current_timestamp" 131 } 132 case "UUID": 133 if t.PkgPath() == "go.charczuk.com/sdk/uuid" { 134 return "gen_random_uuid()" 135 } 136 default: 137 } 138 panic(fmt.Sprintf("unknown auto field type for db: %s/%s", t.PkgPath(), t.Name())) 139 } 140 141 // v is a map between string and anything. 142 // 143 // It is typically used as an argument to `f(...)` and 144 // lets you associate values with keys. 145 type v map[string]any 146 147 // f uses text/template to format a given string 148 func f(format string, args any) string { 149 t, err := template.New("").Parse(format) 150 if err != nil { 151 return "" 152 } 153 154 buf := new(bytes.Buffer) 155 _ = t.Execute(buf, args) 156 return buf.String() 157 }