github.com/Azareal/Gosora@v0.0.0-20210729070923-553e66b59003/query_gen/pgsql.go (about)

     1  /* WIP Under Really Heavy Construction */
     2  package qgen
     3  
     4  import (
     5  	"database/sql"
     6  	"errors"
     7  	"strconv"
     8  	"strings"
     9  )
    10  
    11  func init() {
    12  	Registry = append(Registry,
    13  		&PgsqlAdapter{Name: "pgsql", Buffer: make(map[string]DBStmt)},
    14  	)
    15  }
    16  
// PgsqlAdapter is the PostgreSQL adapter of the query generator. Named
// statements built by its methods are recorded in Buffer and listed in
// BufferOrder; Write renders them into a generated Go source file.
type PgsqlAdapter struct {
	Name        string            // ? - Do we really need this? Can't we hard-code this?
	Buffer      map[string]DBStmt // Buffered statements, keyed by statement name
	BufferOrder []string          // Map iteration order is random, so we need this to track the order, so we don't get huge diffs every commit
}
    22  
    23  // GetName gives you the name of the database adapter. In this case, it's pgsql
    24  func (a *PgsqlAdapter) GetName() string {
    25  	return a.Name
    26  }
    27  
    28  func (a *PgsqlAdapter) GetStmt(name string) DBStmt {
    29  	return a.Buffer[name]
    30  }
    31  
    32  func (a *PgsqlAdapter) GetStmts() map[string]DBStmt {
    33  	return a.Buffer
    34  }
    35  
    36  // TODO: Implement this
    37  func (a *PgsqlAdapter) BuildConn(config map[string]string) (*sql.DB, error) {
    38  	return nil, nil
    39  }
    40  
    41  func (a *PgsqlAdapter) DbVersion() string {
    42  	return "SELECT version()"
    43  }
    44  
    45  func (a *PgsqlAdapter) DropTable(name, table string) (string, error) {
    46  	if table == "" {
    47  		return "", errors.New("You need a name for this table")
    48  	}
    49  	q := "DROP TABLE IF EXISTS \"" + table + "\";"
    50  	a.pushStatement(name, "drop-table", q)
    51  	return q, nil
    52  }
    53  
    54  // TODO: Implement this
    55  // We may need to change the CreateTable API to better suit PGSQL and the other database drivers which are coming up
    56  func (a *PgsqlAdapter) CreateTable(name, table, charset, collation string, cols []DBTableColumn, keys []DBTableKey) (string, error) {
    57  	if table == "" {
    58  		return "", errors.New("You need a name for this table")
    59  	}
    60  	if len(cols) == 0 {
    61  		return "", errors.New("You can't have a table with no columns")
    62  	}
    63  
    64  	q := "CREATE TABLE \"" + table + "\" ("
    65  	for _, col := range cols {
    66  		if col.AutoIncrement {
    67  			col.Type = "serial"
    68  		} else if col.Type == "createdAt" {
    69  			col.Type = "timestamp"
    70  		} else if col.Type == "datetime" {
    71  			col.Type = "timestamp"
    72  		}
    73  
    74  		var size string
    75  		if col.Size > 0 {
    76  			size = " (" + strconv.Itoa(col.Size) + ")"
    77  		}
    78  
    79  		var end string
    80  		if col.Default != "" {
    81  			end = " DEFAULT "
    82  			if a.stringyType(col.Type) && col.Default != "''" {
    83  				end += "'" + col.Default + "'"
    84  			} else {
    85  				end += col.Default
    86  			}
    87  		}
    88  		if !col.Null {
    89  			end += " not null"
    90  		}
    91  
    92  		q += "\n\t`" + col.Name + "` " + col.Type + size + end + ","
    93  	}
    94  
    95  	if len(keys) > 0 {
    96  		for _, key := range keys {
    97  			q += "\n\t" + key.Type
    98  			if key.Type != "unique" {
    99  				q += " key"
   100  			}
   101  			q += "("
   102  			for _, column := range strings.Split(key.Columns, ",") {
   103  				q += "`" + column + "`,"
   104  			}
   105  			q = q[0:len(q)-1] + "),"
   106  		}
   107  	}
   108  
   109  	q = q[0:len(q)-1] + "\n);"
   110  	a.pushStatement(name, "create-table", q)
   111  	return q, nil
   112  }
   113  
   114  // TODO: Implement this
   115  func (a *PgsqlAdapter) AddColumn(name, table string, column DBTableColumn, key *DBTableKey) (string, error) {
   116  	if table == "" {
   117  		return "", errors.New("You need a name for this table")
   118  	}
   119  	return "", nil
   120  }
   121  
   122  // TODO: Implement this
   123  func (a *PgsqlAdapter) DropColumn(name, table, colName string) (string, error) {
   124  	return "", errors.New("not implemented")
   125  }
   126  
   127  // TODO: Implement this
   128  func (a *PgsqlAdapter) RenameColumn(name, table, oldName, newName string) (string, error) {
   129  	return "", errors.New("not implemented")
   130  }
   131  
   132  // TODO: Implement this
   133  func (a *PgsqlAdapter) ChangeColumn(name, table, colName string, col DBTableColumn) (string, error) {
   134  	return "", errors.New("not implemented")
   135  }
   136  
   137  // TODO: Implement this
   138  func (a *PgsqlAdapter) SetDefaultColumn(name, table, colName, colType, defaultStr string) (string, error) {
   139  	if colType == "text" {
   140  		return "", errors.New("text fields cannot have default values")
   141  	}
   142  	return "", errors.New("not implemented")
   143  }
   144  
   145  // TODO: Implement this
   146  // TODO: Test to make sure everything works here
   147  func (a *PgsqlAdapter) AddIndex(name, table, iname, colname string) (string, error) {
   148  	if table == "" {
   149  		return "", errors.New("You need a name for this table")
   150  	}
   151  	if iname == "" {
   152  		return "", errors.New("You need a name for the index")
   153  	}
   154  	if colname == "" {
   155  		return "", errors.New("You need a name for the column")
   156  	}
   157  	return "", errors.New("not implemented")
   158  }
   159  
   160  // TODO: Implement this
   161  // TODO: Test to make sure everything works here
   162  func (a *PgsqlAdapter) AddKey(name, table, column string, key DBTableKey) (string, error) {
   163  	if table == "" {
   164  		return "", errors.New("You need a name for this table")
   165  	}
   166  	if column == "" {
   167  		return "", errors.New("You need a name for the column")
   168  	}
   169  	return "", errors.New("not implemented")
   170  }
   171  
   172  // TODO: Implement this
   173  // TODO: Test to make sure everything works here
   174  func (a *PgsqlAdapter) RemoveIndex(name, table, iname string) (string, error) {
   175  	if table == "" {
   176  		return "", errors.New("You need a name for this table")
   177  	}
   178  	if iname == "" {
   179  		return "", errors.New("You need a name for the index")
   180  	}
   181  	return "", errors.New("not implemented")
   182  }
   183  
   184  // TODO: Implement this
   185  // TODO: Test to make sure everything works here
   186  func (a *PgsqlAdapter) AddForeignKey(name, table, column, ftable, fcolumn string, cascade bool) (out string, e error) {
   187  	var c = func(str string, val bool) {
   188  		if e != nil || !val {
   189  			return
   190  		}
   191  		e = errors.New("You need a " + str + " for this table")
   192  	}
   193  	c("name", table == "")
   194  	c("column", column == "")
   195  	c("ftable", ftable == "")
   196  	c("fcolumn", fcolumn == "")
   197  	if e != nil {
   198  		return "", e
   199  	}
   200  	return "", errors.New("not implemented")
   201  }
   202  
   203  // TODO: Test this
   204  // ! We need to get the last ID out of this somehow, maybe add returning to every query? Might require some sort of wrapper over the sql statements
   205  func (a *PgsqlAdapter) SimpleInsert(name, table, columns, fields string) (string, error) {
   206  	if table == "" {
   207  		return "", errors.New("You need a name for this table")
   208  	}
   209  
   210  	q := "INSERT INTO \"" + table + "\"("
   211  	if columns != "" {
   212  		q += a.buildColumns(columns) + ") VALUES ("
   213  		for _, field := range processFields(fields) {
   214  			nameLen := len(field.Name)
   215  			if field.Name[0] == '"' && field.Name[nameLen-1] == '"' && nameLen >= 3 {
   216  				field.Name = "'" + field.Name[1:nameLen-1] + "'"
   217  			}
   218  			if field.Name[0] == '\'' && field.Name[nameLen-1] == '\'' && nameLen >= 3 {
   219  				field.Name = "'" + strings.Replace(field.Name[1:nameLen-1], "'", "''", -1) + "'"
   220  			}
   221  			q += field.Name + ","
   222  		}
   223  		q = q[0 : len(q)-1]
   224  	} else {
   225  		q += ") VALUES ("
   226  	}
   227  	q += ")"
   228  
   229  	a.pushStatement(name, "insert", q)
   230  	return q, nil
   231  }
   232  
   233  // TODO: Implement this
   234  func (a *PgsqlAdapter) SimpleBulkInsert(name, table, columns string, fieldSet []string) (string, error) {
   235  	return "", nil
   236  }
   237  
   238  func (a *PgsqlAdapter) buildColumns(cols string) (q string) {
   239  	if cols == "" {
   240  		return ""
   241  	}
   242  	// Escape the column names, just in case we've used a reserved keyword
   243  	for _, col := range processColumns(cols) {
   244  		if col.Type == TokenFunc {
   245  			q += col.Left + ","
   246  		} else {
   247  			q += "\"" + col.Left + "\","
   248  		}
   249  	}
   250  	return q[0 : len(q)-1]
   251  }
   252  
   253  // TODO: Implement this
   254  func (a *PgsqlAdapter) SimpleReplace(name, table, columns, fields string) (string, error) {
   255  	if table == "" {
   256  		return "", errors.New("You need a name for this table")
   257  	}
   258  	if len(columns) == 0 {
   259  		return "", errors.New("No columns found for SimpleInsert")
   260  	}
   261  	if len(fields) == 0 {
   262  		return "", errors.New("No input data found for SimpleInsert")
   263  	}
   264  	return "", nil
   265  }
   266  
   267  // TODO: Implement this
   268  func (a *PgsqlAdapter) SimpleUpsert(name, table, columns, fields, where string) (string, error) {
   269  	if table == "" {
   270  		return "", errors.New("You need a name for this table")
   271  	}
   272  	if len(columns) == 0 {
   273  		return "", errors.New("No columns found for SimpleInsert")
   274  	}
   275  	if len(fields) == 0 {
   276  		return "", errors.New("No input data found for SimpleInsert")
   277  	}
   278  	return "", nil
   279  }
   280  
   281  // TODO: Implemented, but we need CreateTable and a better installer to *test* it
   282  func (a *PgsqlAdapter) SimpleUpdate(up *updatePrebuilder) (string, error) {
   283  	if up.table == "" {
   284  		return "", errors.New("You need a name for this table")
   285  	}
   286  	if up.set == "" {
   287  		return "", errors.New("You need to set data in this update statement")
   288  	}
   289  
   290  	q := "UPDATE \"" + up.table + "\" SET "
   291  	for _, item := range processSet(up.set) {
   292  		q += "`" + item.Column + "`="
   293  		for _, token := range item.Expr {
   294  			switch token.Type {
   295  			case TokenFunc:
   296  				// TODO: Write a more sophisticated function parser on the utils side.
   297  				if strings.ToUpper(token.Contents) == "UTC_TIMESTAMP()" {
   298  					token.Contents = "LOCALTIMESTAMP()"
   299  				}
   300  				q += " " + token.Contents
   301  			case TokenOp, TokenNumber, TokenSub, TokenOr:
   302  				q += " " + token.Contents
   303  			case TokenColumn:
   304  				q += " `" + token.Contents + "`"
   305  			case TokenString:
   306  				q += " '" + token.Contents + "'"
   307  			}
   308  		}
   309  		q += ","
   310  	}
   311  	q = q[0 : len(q)-1]
   312  
   313  	// Add support for BETWEEN x.x
   314  	if len(up.where) != 0 {
   315  		q += " WHERE"
   316  		for _, loc := range processWhere(up.where) {
   317  			for _, token := range loc.Expr {
   318  				switch token.Type {
   319  				case TokenFunc:
   320  					// TODO: Write a more sophisticated function parser on the utils side. What's the situation in regards to case sensitivity?
   321  					if strings.ToUpper(token.Contents) == "UTC_TIMESTAMP()" {
   322  						token.Contents = "LOCALTIMESTAMP()"
   323  					}
   324  					q += " " + token.Contents
   325  				case TokenOp, TokenNumber, TokenSub, TokenOr, TokenNot, TokenLike:
   326  					q += " " + token.Contents
   327  				case TokenColumn:
   328  					q += " `" + token.Contents + "`"
   329  				case TokenString:
   330  					q += " '" + token.Contents + "'"
   331  				default:
   332  					panic("This token doesn't exist o_o")
   333  				}
   334  			}
   335  			q += " AND"
   336  		}
   337  		q = q[0 : len(q)-4]
   338  	}
   339  
   340  	a.pushStatement(up.name, "update", q)
   341  	return q, nil
   342  }
   343  
   344  // TODO: Implement this
   345  func (a *PgsqlAdapter) SimpleUpdateSelect(up *updatePrebuilder) (string, error) {
   346  	return "", errors.New("not implemented")
   347  }
   348  
   349  // TODO: Implement this
   350  func (a *PgsqlAdapter) SimpleDelete(name, table, where string) (string, error) {
   351  	if table == "" {
   352  		return "", errors.New("You need a name for this table")
   353  	}
   354  	if where == "" {
   355  		return "", errors.New("You need to specify what data you want to delete")
   356  	}
   357  	return "", nil
   358  }
   359  
   360  // TODO: Implement this
   361  func (a *PgsqlAdapter) ComplexDelete(b *deletePrebuilder) (string, error) {
   362  	if b.table == "" {
   363  		return "", errors.New("You need a name for this table")
   364  	}
   365  	if b.where == "" {
   366  		return "", errors.New("You need to specify what data you want to delete")
   367  	}
   368  	return "", nil
   369  }
   370  
   371  // TODO: Implement this
   372  // We don't want to accidentally wipe tables, so we'll have a separate method for purging tables instead
   373  func (a *PgsqlAdapter) Purge(name, table string) (string, error) {
   374  	if table == "" {
   375  		return "", errors.New("You need a name for this table")
   376  	}
   377  	return "", nil
   378  }
   379  
   380  // TODO: Implement this
   381  func (a *PgsqlAdapter) SimpleSelect(name, table, columns, where, orderby, limit string) (string, error) {
   382  	if table == "" {
   383  		return "", errors.New("You need a name for this table")
   384  	}
   385  	if len(columns) == 0 {
   386  		return "", errors.New("No columns found for SimpleSelect")
   387  	}
   388  	return "", nil
   389  }
   390  
   391  // TODO: Implement this
   392  func (a *PgsqlAdapter) ComplexSelect(prebuilder *selectPrebuilder) (string, error) {
   393  	if prebuilder.table == "" {
   394  		return "", errors.New("You need a name for this table")
   395  	}
   396  	if len(prebuilder.columns) == 0 {
   397  		return "", errors.New("No columns found for ComplexSelect")
   398  	}
   399  	return "", nil
   400  }
   401  
   402  // TODO: Implement this
   403  func (a *PgsqlAdapter) SimpleLeftJoin(name, table1, table2, columns, joiners, where, orderby, limit string) (string, error) {
   404  	if table1 == "" {
   405  		return "", errors.New("You need a name for the left table")
   406  	}
   407  	if table2 == "" {
   408  		return "", errors.New("You need a name for the right table")
   409  	}
   410  	if len(columns) == 0 {
   411  		return "", errors.New("No columns found for SimpleLeftJoin")
   412  	}
   413  	if len(joiners) == 0 {
   414  		return "", errors.New("No joiners found for SimpleLeftJoin")
   415  	}
   416  	return "", nil
   417  }
   418  
   419  // TODO: Implement this
   420  func (a *PgsqlAdapter) SimpleInnerJoin(name, table1, table2, columns, joiners, where, orderby, limit string) (string, error) {
   421  	if table1 == "" {
   422  		return "", errors.New("You need a name for the left table")
   423  	}
   424  	if table2 == "" {
   425  		return "", errors.New("You need a name for the right table")
   426  	}
   427  	if len(columns) == 0 {
   428  		return "", errors.New("No columns found for SimpleInnerJoin")
   429  	}
   430  	if len(joiners) == 0 {
   431  		return "", errors.New("No joiners found for SimpleInnerJoin")
   432  	}
   433  	return "", nil
   434  }
   435  
   436  // TODO: Implement this
   437  func (a *PgsqlAdapter) SimpleInsertSelect(name string, ins DBInsert, sel DBSelect) (string, error) {
   438  	return "", nil
   439  }
   440  
   441  // TODO: Implement this
   442  func (a *PgsqlAdapter) SimpleInsertLeftJoin(name string, ins DBInsert, sel DBJoin) (string, error) {
   443  	return "", nil
   444  }
   445  
   446  // TODO: Implement this
   447  func (a *PgsqlAdapter) SimpleInsertInnerJoin(name string, ins DBInsert, sel DBJoin) (string, error) {
   448  	return "", nil
   449  }
   450  
   451  // TODO: Implement this
   452  func (a *PgsqlAdapter) SimpleCount(name, table, where, limit string) (string, error) {
   453  	if table == "" {
   454  		return "", errors.New("You need a name for this table")
   455  	}
   456  	return "", nil
   457  }
   458  
   459  func (a *PgsqlAdapter) Builder() *prebuilder {
   460  	return &prebuilder{a}
   461  }
   462  
   463  func (a *PgsqlAdapter) Write() error {
   464  	var stmts, body string
   465  	for _, name := range a.BufferOrder {
   466  		if name[0] == '_' {
   467  			continue
   468  		}
   469  		stmt := a.Buffer[name]
   470  		// TODO: Add support for create-table? Table creation might be a little complex for Go to do outside a SQL file :(
   471  		if stmt.Type != "create-table" {
   472  			stmts += "\t" + name + " *sql.Stmt\n"
   473  			body += `	
   474  	common.DebugLog("Preparing ` + name + ` statement.")
   475  	stmts.` + name + `, err = db.Prepare("` + strings.Replace(stmt.Contents, "\"", "\\\"", -1) + `")
   476  	if err != nil {
   477  		log.Print("Error in ` + name + ` statement.")
   478  		return err
   479  	}
   480  	`
   481  		}
   482  	}
   483  
   484  	// TODO: Move these custom queries out of this file
   485  	out := `// +build pgsql
   486  
   487  // This file was generated by Gosora's Query Generator. Please try to avoid modifying this file, as it might change at any time.
   488  package main
   489  
   490  import "log"
   491  import "database/sql"
   492  import "github.com/Azareal/Gosora/common"
   493  
   494  // nolint
   495  type Stmts struct {
   496  ` + stmts + `
   497  	getActivityFeedByWatcher *sql.Stmt
   498  	getActivityCountByWatcher *sql.Stmt
   499  
   500  	Mocks bool
   501  }
   502  
   503  // nolint
   504  func _gen_pgsql() (err error) {
   505  	common.DebugLog("Building the generated statements")
   506  ` + body + `
   507  	return nil
   508  }
   509  `
   510  	return writeFile("./gen_pgsql.go", out)
   511  }
   512  
   513  // Internal methods, not exposed in the interface
   514  func (a *PgsqlAdapter) pushStatement(name, stype, q string) {
   515  	if name == "" {
   516  		return
   517  	}
   518  	a.Buffer[name] = DBStmt{q, stype}
   519  	a.BufferOrder = append(a.BufferOrder, name)
   520  }
   521  
   522  func (a *PgsqlAdapter) stringyType(ctype string) bool {
   523  	ctype = strings.ToLower(ctype)
   524  	return ctype == "char" || ctype == "varchar" || ctype == "timestamp" || ctype == "text"
   525  }