gitlab.com/beacon-software/gadget@v0.0.0-20181217202115-54565ea1ed5e/database/qb/select.go (about)

     1  package qb
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  )
     7  
// SelectExpression for use in identifying the fields desired in a select query.
// Within this file it is the element type of both the select list and the
// group by list of a SelectQuery.
type SelectExpression interface {
	// GetName returns the name that can be used to reference this expression.
	GetName() string
	// GetTables returns the tables that are used in this expression.
	// SelectQuery.Validate checks each of them against the aliases of the
	// from table and the joined tables.
	GetTables() []string
	// SQL returns the SQL text that represents this SelectExpression.
	SQL() string
}
    17  
    18  type alias struct {
    19  	field TableField
    20  	alias string
    21  }
    22  
    23  func (a alias) GetName() string {
    24  	return a.alias
    25  }
    26  
    27  func (a alias) GetTables() []string {
    28  	return a.field.GetTables()
    29  }
    30  
    31  func (a alias) SQL() string {
    32  	return fmt.Sprintf("%s AS `%s`", a.field.SQL(), a.alias)
    33  }
    34  
    35  // Alias the passed table field for use in or as a SelectExpression
    36  func Alias(tableField TableField, aliasName string) SelectExpression {
    37  	return alias{field: tableField, alias: aliasName}
    38  }
    39  
    40  type notNull struct {
    41  	field TableField
    42  	alias string
    43  }
    44  
    45  func (nn notNull) GetName() string {
    46  	return nn.alias
    47  }
    48  
    49  func (nn notNull) GetTables() []string {
    50  	return nn.field.GetTables()
    51  }
    52  
    53  func (nn notNull) SQL() string {
    54  	return fmt.Sprintf("(%s IS NOT NULL) AS `%s`", nn.field.SQL(), nn.alias)
    55  }
    56  
    57  // NotNull field for use as a select expression
    58  func NotNull(tableField TableField, alias string) SelectExpression {
    59  	return notNull{field: tableField, alias: alias}
    60  }
    61  
    62  type coalesce struct {
    63  	field TableField
    64  	def   string
    65  	name  string
    66  }
    67  
    68  func (c coalesce) GetName() string {
    69  	return c.name
    70  }
    71  
    72  func (c coalesce) GetTables() []string {
    73  	return c.field.GetTables()
    74  }
    75  
    76  func (c coalesce) SQL() string {
    77  	return fmt.Sprintf("COALESCE(%s, '%s') AS `%s`", c.field.SQL(), c.def, c.name)
    78  }
    79  
    80  // Coalesce creates a SQL coalesce that can be used as a SelectExpression
    81  func Coalesce(column TableField, defaultValue string, alias string) SelectExpression {
    82  	return coalesce{field: column, def: defaultValue, name: alias}
    83  }
    84  
// SelectQuery for retrieving data from a database table.
//
// Fields, as they are used within this file:
//
//	distinct    when true the statement starts with "SELECT DISTINCT" instead
//	            of "SELECT"
//	from        primary table; set by From and required by Validate
//	selectExps  expressions rendered, comma separated, after the SELECT keyword
//	joins       joins added by InnerJoin and OuterJoin, rendered in that order
//	orderBy     order by state fed by OrderBy
//	groupBy     expressions added by GroupBy
//	where       holder of the condition set by Where
//	Seperator   string placed between the clauses when SQL joins them
//	err         reset by Validate and set to the reason validation failed
//
// NOTE(review): where and orderBy are pointers that the methods in this file
// use without nil checks, so they are presumably initialised wherever the
// query is constructed (not visible here) — confirm. The spelling of the
// exported Seperator field is presumably relied on by existing callers.
type SelectQuery struct {
	distinct   bool
	from       Table
	selectExps []SelectExpression
	joins      []*Join
	orderBy    *orderBy
	groupBy    []SelectExpression
	where      *whereCondition
	Seperator  string
	err        error
}
    97  
// GetAlias of the passed table name in this query. The current implementation
// returns tableName unchanged and does not consult the query's state.
func (q *SelectQuery) GetAlias(tableName string) string {
	return tableName
}
   102  
// From sets the primary table the query will get values from, replacing any
// table set by an earlier call. The table supplies the name and alias that SQL
// renders in the FROM clause. The query itself is returned so that calls can
// be chained.
func (q *SelectQuery) From(table Table) *SelectQuery {
	q.from = table
	return q
}
   108  
   109  // InnerJoin with another table in the database.
   110  func (q *SelectQuery) InnerJoin(table Table) *Join {
   111  	join := NewJoin(Inner, Right, table)
   112  	q.joins = append(q.joins, join)
   113  	return join
   114  }
   115  
   116  // OuterJoin with another table in the database.
   117  func (q *SelectQuery) OuterJoin(direction JoinDirection, table Table) *Join {
   118  	join := NewJoin(Outer, direction, table)
   119  	q.joins = append(q.joins, join)
   120  	return join
   121  }
   122  
// Where sets the condition expression used for the WHERE clause of this query,
// replacing any condition set by an earlier call, and returns the query so
// that calls can be chained.
//
// NOTE(review): q.where is dereferenced without a nil check, so it has to be
// initialised before this is called — presumably by the constructor, which is
// not visible here; confirm.
func (q *SelectQuery) Where(condition *ConditionExpression) *SelectQuery {
	q.where.expression = condition
	return q
}
   128  
// OrderBy the passed field and direction. The pair is handed to the query's
// order by state through addExpression, and the query is returned so that
// calls can be chained.
//
// NOTE(review): q.orderBy is used without a nil check — presumably it is
// initialised wherever the query is constructed (not visible here); confirm.
func (q *SelectQuery) OrderBy(field TableField, direction OrderDirection) *SelectQuery {
	q.orderBy.addExpression(field, direction)
	return q
}
   134  
// GroupBy the passed expressions. They are appended to the query's group by
// list, after any added by earlier calls, and the query is returned so that
// calls can be chained.
func (q *SelectQuery) GroupBy(expressions ...SelectExpression) *SelectQuery {
	q.groupBy = append(q.groupBy, expressions...)
	return q
}
   140  
   141  func (q *SelectQuery) selectExpressionsSQL() string {
   142  	var prefix string
   143  	if q.distinct {
   144  		prefix = "SELECT DISTINCT"
   145  	} else {
   146  		prefix = "SELECT"
   147  	}
   148  	expressions := make([]string, len(q.selectExps))
   149  	for i, exp := range q.selectExps {
   150  		expressions[i] = exp.SQL()
   151  	}
   152  	return fmt.Sprintf("%s %s", prefix, strings.Join(expressions, ", "))
   153  }
   154  
   155  // Validate that this query can be executed.
   156  func (q *SelectQuery) Validate() bool {
   157  	q.err = nil
   158  	// gather up all the tables that must be present in the from or in a join
   159  	tablesRequired := make(map[string]bool)
   160  	// check the select
   161  	for _, exp := range q.selectExps {
   162  		for _, table := range exp.GetTables() {
   163  			tablesRequired[table] = true
   164  		}
   165  	}
   166  	// check the where expressions
   167  	for _, table := range q.where.tables() {
   168  		tablesRequired[table] = true
   169  	}
   170  	// now get the tables from the order by
   171  	for _, table := range q.orderBy.getTables() {
   172  		tablesRequired[table] = true
   173  	}
   174  	// grab the tables from the group by
   175  	for _, tf := range q.groupBy {
   176  		for _, table := range tf.GetTables() {
   177  			tablesRequired[table] = true
   178  		}
   179  	}
   180  	// check that the from table is set
   181  	if nil == q.from {
   182  		q.err = NewValidationFromNotSetError()
   183  		return false
   184  	}
   185  	delete(tablesRequired, q.from.GetAlias())
   186  
   187  	for _, join := range q.joins {
   188  		delete(tablesRequired, join.table.GetAlias())
   189  		if nil != join.err {
   190  			q.err = join.err
   191  			return false
   192  		}
   193  	}
   194  
   195  	if len(tablesRequired) > 0 {
   196  		missingTables := make([]string, len(tablesRequired))
   197  		i := 0
   198  		for key := range tablesRequired {
   199  			missingTables[i] = key
   200  			i++
   201  		}
   202  		q.err = NewMissingTablesError(missingTables)
   203  		return false
   204  	}
   205  
   206  	return true
   207  }
   208  
   209  // SQL statement corresponding to this query.
   210  func (q *SelectQuery) SQL(limit, offset uint) (string, []interface{}, error) {
   211  	if !q.Validate() {
   212  		return "", []interface{}{}, q.err
   213  	}
   214  	// SELECT
   215  	lines := []string{q.selectExpressionsSQL()}
   216  	var values []interface{}
   217  
   218  	// FROM
   219  	from := fmt.Sprintf("FROM `%s` AS `%s`",
   220  		q.from.GetName(),
   221  		q.from.GetAlias())
   222  	lines = append(lines, from)
   223  
   224  	// JOIN
   225  	for _, join := range q.joins {
   226  		joinSQL, joinValues := join.SQL()
   227  		lines = append(lines, joinSQL)
   228  		values = append(values, joinValues...)
   229  	}
   230  
   231  	// WHERE
   232  	if where, whereValues, ok := q.where.sql(); ok {
   233  		lines = append(lines, "WHERE", where)
   234  		values = append(values, whereValues...)
   235  	}
   236  
   237  	// GROUP BY
   238  	if len(q.groupBy) > 0 {
   239  		groupByLines := []string{}
   240  		for _, tf := range q.groupBy {
   241  			groupByLines = append(groupByLines, tf.SQL())
   242  		}
   243  		groupByStatement := "GROUP BY " + strings.Join(groupByLines, ", ")
   244  		lines = append(lines, groupByStatement)
   245  	}
   246  
   247  	// ORDER BY
   248  	if orderby, ok := q.orderBy.sql(); ok {
   249  		lines = append(lines, orderby)
   250  	}
   251  
   252  	// LIMIT, OFFSET
   253  	if NoLimit != limit {
   254  		lines = append(lines, fmt.Sprintf("LIMIT %d OFFSET %d", limit, offset))
   255  	}
   256  	return strings.Join(lines, q.Seperator), values, q.err
   257  }