gitlab.com/beacon-software/gadget@v0.0.0-20181217202115-54565ea1ed5e/database/qb/insert.go (about)

     1  package qb
     2  
     3  import (
     4  	"fmt"
     5  	"strings"
     6  
     7  	"gitlab.com/beacon-software/gadget/errors"
     8  )
     9  
// InsertQuery for inserting a row into the database.
//
// It is used as a builder: Values adds rows, OnDuplicate adds an
// ON DUPLICATE KEY UPDATE clause, and SQL or ParameterizedSQL renders the
// statement. Values does not fail immediately; a field/value count mismatch
// is recorded in err and returned when the statement is rendered.
type InsertQuery struct {
	// columns are the target columns, in insert order. SQL and
	// ParameterizedSQL require them all to belong to the same table.
	// NOTE(review): no method in this file sets columns — presumably it is
	// populated by a constructor elsewhere in the package; confirm.
	columns []TableField
	// values holds one row per Values call whose length matched columns.
	values [][]interface{}
	// onDuplicate lists the fields assigned in the ON DUPLICATE KEY UPDATE
	// clause.
	onDuplicate []TableField
	// onDuplicateValues are the arguments SQL appends after the row values
	// for the onDuplicate "?" placeholders; ParameterizedSQL does not read
	// them.
	onDuplicateValues []interface{}
	// err records a deferred builder error (a Values count mismatch) and is
	// returned by SQL and ParameterizedSQL.
	err error
}
    18  
    19  // Values to be inserted. Call multiple times to insert multiple rows.
    20  func (q *InsertQuery) Values(values ...interface{}) *InsertQuery {
    21  	if len(values) != len(q.columns) {
    22  		q.err = errors.New("insert field/value count mismatch")
    23  	} else {
    24  		q.values = append(q.values, values)
    25  	}
    26  	return q
    27  }
    28  
    29  // OnDuplicate update these fields / values
    30  func (q *InsertQuery) OnDuplicate(fields []TableField, values ...interface{}) *InsertQuery {
    31  	q.onDuplicate = append(q.onDuplicate, fields...)
    32  	q.onDuplicateValues = values
    33  	return q
    34  }
    35  
    36  // GetAlias of the passed table name in this query.
    37  func (q *InsertQuery) GetAlias(tableName string) string {
    38  	return tableName
    39  }
    40  
    41  // SQL that represents this insert query.
    42  func (q *InsertQuery) SQL() (string, []interface{}, error) {
    43  	if len(q.columns) == 0 {
    44  		return "", nil, errors.New("no columns specified for insert")
    45  	}
    46  	colExp := make([]string, len(q.columns))
    47  	qms := make([]string, len(q.columns))
    48  	for i, col := range q.columns {
    49  		colExp[i] = col.SQL()
    50  		if col.Table != q.columns[0].Table {
    51  			return "", nil, errors.New("insert columns must be from the same table")
    52  		}
    53  		qms[i] = "?"
    54  	}
    55  	valExp := fmt.Sprintf("(%s)", strings.Join(qms, ", "))
    56  	valExps := make([]string, len(q.values))
    57  	values := []interface{}{}
    58  	for i, valGrp := range q.values {
    59  		valExps[i] = valExp
    60  		values = append(values, valGrp...)
    61  	}
    62  	onDuplicate := ""
    63  	if len(q.onDuplicate) > 0 {
    64  		if len(q.values) > 1 {
    65  			return "", nil, errors.New("cannot use on duplicate with multi-insert")
    66  		}
    67  		updateFields := make([]string, len(q.onDuplicate))
    68  		for _, col := range q.onDuplicate {
    69  			if col.Table != q.columns[0].Table {
    70  				return "", nil, errors.New("insert columns must be from the same table")
    71  			}
    72  			for i, col := range q.onDuplicate {
    73  				updateFields[i] = fmt.Sprintf("%s = ?", col.SQL())
    74  			}
    75  		}
    76  		values = append(values, q.onDuplicateValues...)
    77  		onDuplicate = " ON DUPLICATE KEY UPDATE " + strings.Join(updateFields, ", ")
    78  	}
    79  	return fmt.Sprintf("INSERT INTO `%s` (%s) VALUES %s%s", q.columns[0].Table, strings.Join(colExp, ", "),
    80  		strings.Join(valExps, ", "), onDuplicate), values, q.err
    81  }
    82  
    83  // ParameterizedSQL that represents this insert query.
    84  func (q *InsertQuery) ParameterizedSQL() (string, error) {
    85  	if len(q.columns) == 0 {
    86  		return "", errors.New("no columns specified for insert")
    87  	}
    88  	colExp := make([]string, len(q.columns))
    89  	qms := make([]string, len(q.columns))
    90  	for i, col := range q.columns {
    91  		colExp[i] = col.SQL()
    92  		if col.Table != q.columns[0].Table {
    93  			return "", errors.New("insert columns must be from the same table")
    94  		}
    95  		qms[i] = ":" + col.GetName()
    96  	}
    97  	onDuplicate := ""
    98  	if len(q.onDuplicate) > 0 {
    99  		if len(q.values) > 1 {
   100  			return "", errors.New("cannot use on duplicate with multi-insert")
   101  		}
   102  		updateFields := make([]string, len(q.onDuplicate))
   103  		for i, field := range q.onDuplicate {
   104  			updateFields[i] = fmt.Sprintf("%s = :%s", field.SQL(), field.GetName())
   105  		}
   106  		onDuplicate = " ON DUPLICATE KEY UPDATE " + strings.Join(updateFields, ", ")
   107  	}
   108  	return fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)%s", q.columns[0].Table, strings.Join(colExp, ", "),
   109  		strings.Join(qms, ", "), onDuplicate), q.err
   110  }