github.com/RevenueMonster/sqlike@v1.0.6/sql/dialect/mysql/insert.go (about) 1 package mysql 2 3 import ( 4 "fmt" 5 "reflect" 6 7 "github.com/RevenueMonster/sqlike/reflext" 8 "github.com/RevenueMonster/sqlike/spatial" 9 "github.com/RevenueMonster/sqlike/sql/codec" 10 sqlstmt "github.com/RevenueMonster/sqlike/sql/stmt" 11 "github.com/RevenueMonster/sqlike/sqlike/options" 12 ) 13 14 // InsertInto : 15 func (ms MySQL) InsertInto(stmt sqlstmt.Stmt, db, table, pk string, cache reflext.StructMapper, cdc codec.Codecer, fields []reflext.StructFielder, v reflect.Value, opt *options.InsertOptions) (err error) { 16 records := v.Len() 17 18 stmt.WriteString("INSERT") 19 if opt.Mode == options.InsertIgnore { 20 stmt.WriteString(" IGNORE") 21 } 22 stmt.WriteString(" INTO " + ms.TableName(db, table) + " (") 23 24 omitField := make(map[string]bool) 25 noOfOmit := len(opt.Omits) 26 for i := 0; i < len(fields); { 27 // omit all the field provided by user 28 if noOfOmit > 0 && opt.Omits.IndexOf(fields[i].Name()) > -1 { 29 if opt.Mode != options.InsertOnDuplicate { 30 fields = append(fields[:i], fields[i+1:]...) 31 continue 32 } else { 33 omitField[fields[i].Name()] = true 34 } 35 } 36 37 // omit all the struct field with `generated_column` tag, it shouldn't include when inserting to the db 38 if _, ok := fields[i].Tag().LookUp("generated_column"); ok { 39 fields = append(fields[:i], fields[i+1:]...) 40 continue 41 } 42 43 stmt.WriteString(ms.Quote(fields[i].Name())) 44 if i < len(fields)-1 { 45 stmt.WriteByte(',') 46 } 47 48 i++ 49 } 50 stmt.WriteString(") VALUES ") 51 52 length := len(fields) 53 encoders := make([]codec.ValueEncoder, length) 54 for i := 0; i < records; i++ { 55 if i > 0 { 56 stmt.WriteByte(',') 57 } 58 stmt.WriteByte('(') 59 vi := reflext.Indirect(v.Index(i)) 60 61 for j := range fields { 62 if j > 0 { 63 stmt.WriteByte(',') 64 } 65 66 // first record only find encoders 67 fv := cache.FieldByIndexesReadOnly(vi, fields[j].Index()) 68 if i == 0 { 69 encoders[j], err = findEncoder(cdc, fields[j], fv) 70 if err != nil { 71 return err 72 } 73 } 74 75 val, err := encoders[j](fields[j], fv) 76 if err != nil { 77 return err 78 } 79 80 convertSpatial(stmt, val) 81 } 82 stmt.WriteByte(')') 83 } 84 85 var ( 86 column string 87 name string 88 ) 89 if opt.Mode == options.InsertOnDuplicate { 90 stmt.WriteString(" ON DUPLICATE KEY UPDATE ") 91 next := false 92 for _, f := range fields { 93 name = f.Name() 94 // skip primary key on duplicate update 95 if name == pk { 96 continue 97 } 98 99 // skip primary key on duplicate update 100 if _, ok := f.Tag().LookUp("primary_key"); ok { 101 continue 102 } 103 104 if _, ok := f.Tag().LookUp("auto_increment"); ok { 105 continue 106 } 107 108 // skip omit fields on update 109 if _, ok := omitField[name]; ok { 110 continue 111 } 112 113 if next { 114 stmt.WriteByte(',') 115 } 116 117 column = ms.Quote(name) 118 stmt.WriteString(column + "=VALUES(" + column + ")") 119 next = true 120 } 121 } 122 stmt.WriteByte(';') 123 return 124 } 125 126 func findEncoder(c codec.Codecer, sf reflext.StructFielder, v reflect.Value) (codec.ValueEncoder, error) { 127 // auto_increment field should pass nil if it's empty 128 if _, ok := sf.Tag().LookUp("auto_increment"); ok && reflext.IsZero(v) { 129 return codec.NilEncoder, nil 130 } 131 encoder, err := c.LookupEncoder(v) 132 if err != nil { 133 return nil, err 134 } 135 return encoder, nil 136 } 137 138 func convertSpatial(stmt sqlstmt.Stmt, val interface{}) { 139 switch vi := val.(type) { 140 case spatial.Geometry: 141 switch vi.Type { 142 case spatial.Point: 143 stmt.WriteString("ST_PointFromText") 144 case spatial.LineString: 145 stmt.WriteString("ST_LineStringFromText") 146 case spatial.Polygon: 147 stmt.WriteString("ST_PolygonFromText") 148 case spatial.MultiPoint: 149 stmt.WriteString("ST_MultiPointFromText") 150 case spatial.MultiLineString: 151 stmt.WriteString("ST_MultiLineStringFromText") 152 case spatial.MultiPolygon: 153 stmt.WriteString("ST_MultiPolygonFromText") 154 default: 155 } 156 157 stmt.WriteString("(?") 158 if vi.SRID > 0 { 159 stmt.WriteString(fmt.Sprintf(",%d", vi.SRID)) 160 } 161 stmt.WriteByte(')') 162 stmt.AppendArgs(vi.WKT) 163 164 default: 165 stmt.WriteByte('?') 166 stmt.AppendArgs(val) 167 } 168 }