github.com/RevenueMonster/sqlike@v1.0.6/sql/dialect/mysql/insert.go (about)

     1  package mysql
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  
     7  	"github.com/RevenueMonster/sqlike/reflext"
     8  	"github.com/RevenueMonster/sqlike/spatial"
     9  	"github.com/RevenueMonster/sqlike/sql/codec"
    10  	sqlstmt "github.com/RevenueMonster/sqlike/sql/stmt"
    11  	"github.com/RevenueMonster/sqlike/sqlike/options"
    12  )
    13  
    14  // InsertInto :
    15  func (ms MySQL) InsertInto(stmt sqlstmt.Stmt, db, table, pk string, cache reflext.StructMapper, cdc codec.Codecer, fields []reflext.StructFielder, v reflect.Value, opt *options.InsertOptions) (err error) {
    16  	records := v.Len()
    17  
    18  	stmt.WriteString("INSERT")
    19  	if opt.Mode == options.InsertIgnore {
    20  		stmt.WriteString(" IGNORE")
    21  	}
    22  	stmt.WriteString(" INTO " + ms.TableName(db, table) + " (")
    23  
    24  	omitField := make(map[string]bool)
    25  	noOfOmit := len(opt.Omits)
    26  	for i := 0; i < len(fields); {
    27  		// omit all the field provided by user
    28  		if noOfOmit > 0 && opt.Omits.IndexOf(fields[i].Name()) > -1 {
    29  			if opt.Mode != options.InsertOnDuplicate {
    30  				fields = append(fields[:i], fields[i+1:]...)
    31  				continue
    32  			} else {
    33  				omitField[fields[i].Name()] = true
    34  			}
    35  		}
    36  
    37  		// omit all the struct field with `generated_column` tag, it shouldn't include when inserting to the db
    38  		if _, ok := fields[i].Tag().LookUp("generated_column"); ok {
    39  			fields = append(fields[:i], fields[i+1:]...)
    40  			continue
    41  		}
    42  
    43  		stmt.WriteString(ms.Quote(fields[i].Name()))
    44  		if i < len(fields)-1 {
    45  			stmt.WriteByte(',')
    46  		}
    47  
    48  		i++
    49  	}
    50  	stmt.WriteString(") VALUES ")
    51  
    52  	length := len(fields)
    53  	encoders := make([]codec.ValueEncoder, length)
    54  	for i := 0; i < records; i++ {
    55  		if i > 0 {
    56  			stmt.WriteByte(',')
    57  		}
    58  		stmt.WriteByte('(')
    59  		vi := reflext.Indirect(v.Index(i))
    60  
    61  		for j := range fields {
    62  			if j > 0 {
    63  				stmt.WriteByte(',')
    64  			}
    65  
    66  			// first record only find encoders
    67  			fv := cache.FieldByIndexesReadOnly(vi, fields[j].Index())
    68  			if i == 0 {
    69  				encoders[j], err = findEncoder(cdc, fields[j], fv)
    70  				if err != nil {
    71  					return err
    72  				}
    73  			}
    74  
    75  			val, err := encoders[j](fields[j], fv)
    76  			if err != nil {
    77  				return err
    78  			}
    79  
    80  			convertSpatial(stmt, val)
    81  		}
    82  		stmt.WriteByte(')')
    83  	}
    84  
    85  	var (
    86  		column string
    87  		name   string
    88  	)
    89  	if opt.Mode == options.InsertOnDuplicate {
    90  		stmt.WriteString(" ON DUPLICATE KEY UPDATE ")
    91  		next := false
    92  		for _, f := range fields {
    93  			name = f.Name()
    94  			// skip primary key on duplicate update
    95  			if name == pk {
    96  				continue
    97  			}
    98  
    99  			// skip primary key on duplicate update
   100  			if _, ok := f.Tag().LookUp("primary_key"); ok {
   101  				continue
   102  			}
   103  
   104  			if _, ok := f.Tag().LookUp("auto_increment"); ok {
   105  				continue
   106  			}
   107  
   108  			// skip omit fields on update
   109  			if _, ok := omitField[name]; ok {
   110  				continue
   111  			}
   112  
   113  			if next {
   114  				stmt.WriteByte(',')
   115  			}
   116  
   117  			column = ms.Quote(name)
   118  			stmt.WriteString(column + "=VALUES(" + column + ")")
   119  			next = true
   120  		}
   121  	}
   122  	stmt.WriteByte(';')
   123  	return
   124  }
   125  
   126  func findEncoder(c codec.Codecer, sf reflext.StructFielder, v reflect.Value) (codec.ValueEncoder, error) {
   127  	// auto_increment field should pass nil if it's empty
   128  	if _, ok := sf.Tag().LookUp("auto_increment"); ok && reflext.IsZero(v) {
   129  		return codec.NilEncoder, nil
   130  	}
   131  	encoder, err := c.LookupEncoder(v)
   132  	if err != nil {
   133  		return nil, err
   134  	}
   135  	return encoder, nil
   136  }
   137  
   138  func convertSpatial(stmt sqlstmt.Stmt, val interface{}) {
   139  	switch vi := val.(type) {
   140  	case spatial.Geometry:
   141  		switch vi.Type {
   142  		case spatial.Point:
   143  			stmt.WriteString("ST_PointFromText")
   144  		case spatial.LineString:
   145  			stmt.WriteString("ST_LineStringFromText")
   146  		case spatial.Polygon:
   147  			stmt.WriteString("ST_PolygonFromText")
   148  		case spatial.MultiPoint:
   149  			stmt.WriteString("ST_MultiPointFromText")
   150  		case spatial.MultiLineString:
   151  			stmt.WriteString("ST_MultiLineStringFromText")
   152  		case spatial.MultiPolygon:
   153  			stmt.WriteString("ST_MultiPolygonFromText")
   154  		default:
   155  		}
   156  
   157  		stmt.WriteString("(?")
   158  		if vi.SRID > 0 {
   159  			stmt.WriteString(fmt.Sprintf(",%d", vi.SRID))
   160  		}
   161  		stmt.WriteByte(')')
   162  		stmt.AppendArgs(vi.WKT)
   163  
   164  	default:
   165  		stmt.WriteByte('?')
   166  		stmt.AppendArgs(val)
   167  	}
   168  }