github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/databases/orm/db_postgres.go (about)

     1  // This package was migrated from beego and modified; the original can be found at the following link:
     2  //    "github.com/beego/beego/"
     3  //
     4  // Copyright 2023 IAC. All Rights Reserved.
     5  //
     6  // Licensed under the Apache License, Version 2.0 (the "License");
     7  // you may not use this file except in compliance with the License.
     8  // You may obtain a copy of the License at
     9  //
    10  //      http://www.apache.org/licenses/LICENSE-2.0
    11  //
    12  // Unless required by applicable law or agreed to in writing, software
    13  // distributed under the License is distributed on an "AS IS" BASIS,
    14  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  // See the License for the specific language governing permissions and
    16  // limitations under the License.
    17  
    18  package orm
    19  
    20  import (
    21  	"context"
    22  	"fmt"
    23  	"strconv"
    24  )
    25  
    26  // postgresql operators.
    27  var postgresOperators = map[string]string{
    28  	"exact":       "= ?",
    29  	"iexact":      "= UPPER(?)",
    30  	"contains":    "LIKE ?",
    31  	"icontains":   "LIKE UPPER(?)",
    32  	"gt":          "> ?",
    33  	"gte":         ">= ?",
    34  	"lt":          "< ?",
    35  	"lte":         "<= ?",
    36  	"eq":          "= ?",
    37  	"ne":          "!= ?",
    38  	"startswith":  "LIKE ?",
    39  	"endswith":    "LIKE ?",
    40  	"istartswith": "LIKE UPPER(?)",
    41  	"iendswith":   "LIKE UPPER(?)",
    42  }
    43  
    44  // postgresql column field types.
    45  var postgresTypes = map[string]string{
    46  	"auto":                "serial NOT NULL PRIMARY KEY",
    47  	"pk":                  "NOT NULL PRIMARY KEY",
    48  	"bool":                "bool",
    49  	"string":              "varchar(%d)",
    50  	"string-char":         "char(%d)",
    51  	"string-text":         "text",
    52  	"time.Time-date":      "date",
    53  	"time.Time":           "timestamp with time zone",
    54  	"int8":                `smallint CHECK("%COL%" >= -127 AND "%COL%" <= 128)`,
    55  	"int16":               "smallint",
    56  	"int32":               "integer",
    57  	"int64":               "bigint",
    58  	"uint8":               `smallint CHECK("%COL%" >= 0 AND "%COL%" <= 255)`,
    59  	"uint16":              `integer CHECK("%COL%" >= 0)`,
    60  	"uint32":              `bigint CHECK("%COL%" >= 0)`,
    61  	"uint64":              `bigint CHECK("%COL%" >= 0)`,
    62  	"float64":             "double precision",
    63  	"float64-decimal":     "numeric(%d, %d)",
    64  	"json":                "json",
    65  	"jsonb":               "jsonb",
    66  	"time.Time-precision": "timestamp(%d) with time zone",
    67  }
    68  
    69  // postgresql dbBaser.
    70  type dbBasePostgres struct {
    71  	dbBase
    72  }
    73  
    74  var _ dbBaser = new(dbBasePostgres)
    75  
    76  // get postgresql operator.
    77  func (d *dbBasePostgres) OperatorSQL(operator string) string {
    78  	return postgresOperators[operator]
    79  }
    80  
    81  // generate functioned sql string, such as contains(text).
    82  func (d *dbBasePostgres) GenerateOperatorLeftCol(fi *fieldInfo, operator string, leftCol *string) {
    83  	switch operator {
    84  	case "contains", "startswith", "endswith":
    85  		*leftCol = fmt.Sprintf("%s::text", *leftCol)
    86  	case "iexact", "icontains", "istartswith", "iendswith":
    87  		*leftCol = fmt.Sprintf("UPPER(%s::text)", *leftCol)
    88  	}
    89  }
    90  
    91  // postgresql unsupports updating joined record.
    92  func (d *dbBasePostgres) SupportUpdateJoin() bool {
    93  	return false
    94  }
    95  
    96  func (d *dbBasePostgres) MaxLimit() uint64 {
    97  	return 0
    98  }
    99  
   100  // postgresql quote is ".
   101  func (d *dbBasePostgres) TableQuote() string {
   102  	return `"`
   103  }
   104  
   105  // postgresql value placeholder is $n.
   106  // replace default ? to $n.
   107  func (d *dbBasePostgres) ReplaceMarks(query *string) {
   108  	q := *query
   109  	num := 0
   110  	for _, c := range q {
   111  		if c == '?' {
   112  			num++
   113  		}
   114  	}
   115  	if num == 0 {
   116  		return
   117  	}
   118  	data := make([]byte, 0, len(q)+num)
   119  	num = 1
   120  	for i := 0; i < len(q); i++ {
   121  		c := q[i]
   122  		if c == '?' {
   123  			data = append(data, '$')
   124  			data = append(data, []byte(strconv.Itoa(num))...)
   125  			num++
   126  		} else {
   127  			data = append(data, c)
   128  		}
   129  	}
   130  	*query = string(data)
   131  }
   132  
   133  // make returning sql support for postgresql.
   134  func (d *dbBasePostgres) HasReturningID(mi *modelInfo, query *string) bool {
   135  	fi := mi.fields.pk
   136  	if fi.fieldType&IsPositiveIntegerField == 0 && fi.fieldType&IsIntegerField == 0 {
   137  		return false
   138  	}
   139  
   140  	if query != nil {
   141  		*query = fmt.Sprintf(`%s RETURNING "%s"`, *query, fi.column)
   142  	}
   143  	return true
   144  }
   145  
   146  // sync auto key
   147  func (d *dbBasePostgres) setval(ctx context.Context, db dbQuerier, mi *modelInfo, autoFields []string) error {
   148  	if len(autoFields) == 0 {
   149  		return nil
   150  	}
   151  
   152  	Q := d.ins.TableQuote()
   153  	for _, name := range autoFields {
   154  		query := fmt.Sprintf("SELECT setval(pg_get_serial_sequence('%s', '%s'), (SELECT MAX(%s%s%s) FROM %s%s%s));",
   155  			mi.table, name,
   156  			Q, name, Q,
   157  			Q, mi.table, Q)
   158  		if _, err := db.ExecContext(ctx, query); err != nil {
   159  			return err
   160  		}
   161  	}
   162  	return nil
   163  }
   164  
   165  // show table sql for postgresql.
   166  func (d *dbBasePostgres) ShowTablesQuery() string {
   167  	return "SELECT table_name FROM information_schema.tables WHERE table_type = 'BASE TABLE' AND table_schema NOT IN ('pg_catalog', 'information_schema')"
   168  }
   169  
   170  // show table columns sql for postgresql.
   171  func (d *dbBasePostgres) ShowColumnsQuery(table string) string {
   172  	return fmt.Sprintf("SELECT column_name, data_type, is_nullable FROM information_schema.columns where table_schema NOT IN ('pg_catalog', 'information_schema') and table_name = '%s'", table)
   173  }
   174  
   175  // get column types of postgresql.
   176  func (d *dbBasePostgres) DbTypes() map[string]string {
   177  	return postgresTypes
   178  }
   179  
   180  // check index exist in postgresql.
   181  func (d *dbBasePostgres) IndexExists(ctx context.Context, db dbQuerier, table string, name string) bool {
   182  	query := fmt.Sprintf("SELECT COUNT(*) FROM pg_indexes WHERE tablename = '%s' AND indexname = '%s'", table, name)
   183  	row := db.QueryRowContext(ctx, query)
   184  	var cnt int
   185  	row.Scan(&cnt)
   186  	return cnt > 0
   187  }
   188  
   189  // GenerateSpecifyIndex return a specifying index clause
   190  func (d *dbBasePostgres) GenerateSpecifyIndex(tableName string, useIndex int, indexes []string) string {
   191  	DebugLog.Println("[WARN] Not support any specifying index action, so that action is ignored")
   192  	return ``
   193  }
   194  
   195  // create new postgresql dbBaser.
   196  func newdbBasePostgres() dbBaser {
   197  	b := new(dbBasePostgres)
   198  	b.ins = b
   199  	return b
   200  }