go.charczuk.com@v0.0.0-20240327042549-bc490516bd1a/sdk/db/dbgen/table_from.go (about)

     1  /*
     2  
     3  Copyright (c) 2023 - Present. Will Charczuk. All rights reserved.
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file at the root of the repository.
     5  
     6  */
     7  
     8  package dbgen
     9  
    10  import (
    11  	"bytes"
    12  	"fmt"
    13  	"reflect"
    14  	"strings"
    15  	"text/template"
    16  
    17  	"go.charczuk.com/sdk/db"
    18  	"go.charczuk.com/sdk/db/migration"
    19  )
    20  
    21  // TableFrom returns a migration step for a given type to initialize
    22  // a database with a table for that type.
    23  //
    24  // Note it does _not_ cover extra stuff like indices and constraints
    25  // and other table features; for those you'll want to generate those with
    26  // dedicated helpers and pass them in the `extra ...string` varaidic argument.
    27  func TableFrom(obj any, extra ...string) *migration.Step {
    28  	tableName := db.TableName(obj)
    29  
    30  	var columns []string
    31  	var primaryKeys []string
    32  	var columnDefinition string
    33  	var leadingComma string
    34  	for index, column := range db.TypeMetaFor(obj).Columns() {
    35  		if index > 0 {
    36  			leadingComma = ", "
    37  		}
    38  
    39  		if column.IsPrimaryKey {
    40  			primaryKeys = append(primaryKeys, column.ColumnName)
    41  			if column.IsAuto {
    42  				columnDefinition = fmt.Sprintf("%s%s %s NOT NULL DEFAULT %s", leadingComma, column.ColumnName, dbTypeForFieldType(column.FieldType), dbAutoForFieldType(column.FieldType))
    43  			} else {
    44  				columnDefinition = fmt.Sprintf("%s%s %s NOT NULL", leadingComma, column.ColumnName, dbTypeForFieldType(column.FieldType))
    45  			}
    46  		} else if column.IsAuto {
    47  			columnDefinition = fmt.Sprintf("%s%s %s NOT NULL DEFAULT %s", leadingComma, column.ColumnName, dbTypeForFieldType(column.FieldType), dbAutoForFieldType(column.FieldType))
    48  		} else if column.IsJSON {
    49  			columnDefinition = fmt.Sprintf("%s%s %s", leadingComma, column.ColumnName, "JSONB")
    50  		} else {
    51  			if column.FieldType.Kind() == reflect.Ptr {
    52  				columnDefinition = fmt.Sprintf("%s%s %s", leadingComma, column.ColumnName, dbTypeForFieldType(column.FieldType))
    53  			} else {
    54  				columnDefinition = fmt.Sprintf("%s%s %s NOT NULL", leadingComma, column.ColumnName, dbTypeForFieldType(column.FieldType))
    55  			}
    56  		}
    57  		columns = append(columns, columnDefinition)
    58  	}
    59  
    60  	var statements []string
    61  	statements = append(statements, f(`CREATE TABLE {{ .TableName }} (
    62  {{- range $index, $column := .Columns }}
    63  	{{ $column }}
    64  {{- end }}
    65  )`, v{"TableName": tableName, "Columns": columns}),
    66  	)
    67  
    68  	if len(primaryKeys) > 0 {
    69  		statements = append(statements, f(`ALTER TABLE {{ .TableName }} ADD CONSTRAINT {{ .ConstraintName }} PRIMARY KEY ({{.Columns}})`, v{
    70  			"TableName":      tableName,
    71  			"ConstraintName": fmt.Sprintf("pk_%s_%s", tableName, strings.Join(primaryKeys, "_")),
    72  			"Columns":        strings.Join(primaryKeys, ","),
    73  		}))
    74  	}
    75  	return migration.NewStep(
    76  		migration.TableNotExists(tableName),
    77  		migration.Statements(
    78  			append(statements, extra...)...,
    79  		),
    80  	)
    81  }
    82  
    83  func dbTypeForFieldType(t reflect.Type) string {
    84  	for t.Kind() == reflect.Ptr {
    85  		t = t.Elem()
    86  	}
    87  	switch t.Kind() {
    88  	case reflect.Bool:
    89  		return "BOOLEAN"
    90  	case reflect.String:
    91  		return "TEXT"
    92  	case reflect.Int32:
    93  		return "INT"
    94  	case reflect.Int, reflect.Int64:
    95  		return "BIGINT"
    96  	case reflect.Uint32:
    97  		return "INT"
    98  	case reflect.Uint, reflect.Uint64:
    99  		return "BIGINT"
   100  	case reflect.Float32:
   101  		return "REAL"
   102  	case reflect.Float64:
   103  		return "DOUBLE PRECISION"
   104  	default:
   105  	}
   106  	switch t.Name() {
   107  	case "Time":
   108  		if t.PkgPath() == "time" {
   109  			return "TIMESTAMP"
   110  		}
   111  	case "Duration":
   112  		if t.PkgPath() == "time" {
   113  			return "INTERVAL"
   114  		}
   115  	case "UUID":
   116  		if t.PkgPath() == "go.charczuk.com/sdk/uuid" {
   117  			return "UUID"
   118  		}
   119  	}
   120  	panic(fmt.Sprintf("unknown field type for db: %s/%s", t.PkgPath(), t.Name()))
   121  }
   122  
   123  func dbAutoForFieldType(t reflect.Type) string {
   124  	for t.Kind() == reflect.Ptr {
   125  		t = t.Elem()
   126  	}
   127  	switch t.Name() {
   128  	case "Time":
   129  		if t.PkgPath() == "time" {
   130  			return "current_timestamp"
   131  		}
   132  	case "UUID":
   133  		if t.PkgPath() == "go.charczuk.com/sdk/uuid" {
   134  			return "gen_random_uuid()"
   135  		}
   136  	default:
   137  	}
   138  	panic(fmt.Sprintf("unknown auto field type for db: %s/%s", t.PkgPath(), t.Name()))
   139  }
   140  
   141  // v is a map between string and anything.
   142  //
   143  // It is typically used as an argument to `f(...)` and
   144  // lets you associate values with keys.
   145  type v map[string]any
   146  
   147  // f uses text/template to format a given string
   148  func f(format string, args any) string {
   149  	t, err := template.New("").Parse(format)
   150  	if err != nil {
   151  		return ""
   152  	}
   153  
   154  	buf := new(bytes.Buffer)
   155  	_ = t.Execute(buf, args)
   156  	return buf.String()
   157  }