gitlab.com/beacon-software/gadget@v0.0.0-20181217202115-54565ea1ed5e/database/qb/insert.go (about) 1 package qb 2 3 import ( 4 "fmt" 5 "strings" 6 7 "gitlab.com/beacon-software/gadget/errors" 8 ) 9 10 // InsertQuery for inserting a row into the database. 11 type InsertQuery struct { 12 columns []TableField 13 values [][]interface{} 14 onDuplicate []TableField 15 onDuplicateValues []interface{} 16 err error 17 } 18 19 // Values to be inserted. Call multiple times to insert multiple rows. 20 func (q *InsertQuery) Values(values ...interface{}) *InsertQuery { 21 if len(values) != len(q.columns) { 22 q.err = errors.New("insert field/value count mismatch") 23 } else { 24 q.values = append(q.values, values) 25 } 26 return q 27 } 28 29 // OnDuplicate update these fields / values 30 func (q *InsertQuery) OnDuplicate(fields []TableField, values ...interface{}) *InsertQuery { 31 q.onDuplicate = append(q.onDuplicate, fields...) 32 q.onDuplicateValues = values 33 return q 34 } 35 36 // GetAlias of the passed table name in this query. 37 func (q *InsertQuery) GetAlias(tableName string) string { 38 return tableName 39 } 40 41 // SQL that represents this insert query. 42 func (q *InsertQuery) SQL() (string, []interface{}, error) { 43 if len(q.columns) == 0 { 44 return "", nil, errors.New("no columns specified for insert") 45 } 46 colExp := make([]string, len(q.columns)) 47 qms := make([]string, len(q.columns)) 48 for i, col := range q.columns { 49 colExp[i] = col.SQL() 50 if col.Table != q.columns[0].Table { 51 return "", nil, errors.New("insert columns must be from the same table") 52 } 53 qms[i] = "?" 54 } 55 valExp := fmt.Sprintf("(%s)", strings.Join(qms, ", ")) 56 valExps := make([]string, len(q.values)) 57 values := []interface{}{} 58 for i, valGrp := range q.values { 59 valExps[i] = valExp 60 values = append(values, valGrp...) 61 } 62 onDuplicate := "" 63 if len(q.onDuplicate) > 0 { 64 if len(q.values) > 1 { 65 return "", nil, errors.New("cannot use on duplicate with multi-insert") 66 } 67 updateFields := make([]string, len(q.onDuplicate)) 68 for _, col := range q.onDuplicate { 69 if col.Table != q.columns[0].Table { 70 return "", nil, errors.New("insert columns must be from the same table") 71 } 72 for i, col := range q.onDuplicate { 73 updateFields[i] = fmt.Sprintf("%s = ?", col.SQL()) 74 } 75 } 76 values = append(values, q.onDuplicateValues...) 77 onDuplicate = " ON DUPLICATE KEY UPDATE " + strings.Join(updateFields, ", ") 78 } 79 return fmt.Sprintf("INSERT INTO `%s` (%s) VALUES %s%s", q.columns[0].Table, strings.Join(colExp, ", "), 80 strings.Join(valExps, ", "), onDuplicate), values, q.err 81 } 82 83 // ParameterizedSQL that represents this insert query. 84 func (q *InsertQuery) ParameterizedSQL() (string, error) { 85 if len(q.columns) == 0 { 86 return "", errors.New("no columns specified for insert") 87 } 88 colExp := make([]string, len(q.columns)) 89 qms := make([]string, len(q.columns)) 90 for i, col := range q.columns { 91 colExp[i] = col.SQL() 92 if col.Table != q.columns[0].Table { 93 return "", errors.New("insert columns must be from the same table") 94 } 95 qms[i] = ":" + col.GetName() 96 } 97 onDuplicate := "" 98 if len(q.onDuplicate) > 0 { 99 if len(q.values) > 1 { 100 return "", errors.New("cannot use on duplicate with multi-insert") 101 } 102 updateFields := make([]string, len(q.onDuplicate)) 103 for i, field := range q.onDuplicate { 104 updateFields[i] = fmt.Sprintf("%s = :%s", field.SQL(), field.GetName()) 105 } 106 onDuplicate = " ON DUPLICATE KEY UPDATE " + strings.Join(updateFields, ", ") 107 } 108 return fmt.Sprintf("INSERT INTO `%s` (%s) VALUES (%s)%s", q.columns[0].Table, strings.Join(colExp, ", "), 109 strings.Join(qms, ", "), onDuplicate), q.err 110 }