github.com/wanlay/gorm-dm8@v1.0.5/create.go (about) 1 package dm 2 3 import ( 4 "bytes" 5 "database/sql" 6 "reflect" 7 8 "github.com/thoas/go-funk" 9 "github.com/wanlay/gorm-dm8/clauses" 10 "gorm.io/gorm" 11 "gorm.io/gorm/callbacks" 12 "gorm.io/gorm/clause" 13 gormSchema "gorm.io/gorm/schema" 14 ) 15 16 func Create(db *gorm.DB) { 17 stmt := db.Statement 18 schema := stmt.Schema 19 boundVars := make(map[string]int) 20 21 if stmt == nil || schema == nil { 22 return 23 } 24 25 hasDefaultValues := len(schema.FieldsWithDefaultDBValue) > 0 26 27 if !stmt.Unscoped { 28 for _, c := range schema.CreateClauses { 29 stmt.AddClause(c) 30 } 31 } 32 33 if stmt.SQL.String() == "" { 34 values := callbacks.ConvertToCreateValues(stmt) 35 onConflict, hasConflict := stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict) 36 if hasConflict { 37 stmt.AddClauseIfNotExists(clauses.Merge{ 38 Using: []clause.Interface{ 39 clause.Select{ 40 Columns: funk.Map(values.Columns, func(column clause.Column) clause.Column { 41 // HACK: I can not come up with a better alternative for now 42 // I want to add a value to the list of variable and then capture the bind variable position as well 43 buf := bytes.NewBufferString("") 44 stmt.Vars = append(stmt.Vars, values.Values[0][funk.IndexOf(values.Columns, column)]) 45 stmt.BindVarTo(buf, stmt, nil) 46 47 column.Alias = column.Name 48 // then the captured bind var will be the name 49 column.Name = buf.String() 50 return column 51 }).([]clause.Column), 52 }, 53 clause.From{ 54 Tables: []clause.Table{{Name: "DUAL"}}, 55 }, 56 }, 57 On: funk.Map(onConflict.Columns, func(field clause.Column) clause.Expression { 58 return clause.Eq{ 59 Column: clause.Column{Table: stmt.Table, Name: field.Name}, 60 Value: clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.Name}, 61 } 62 }).([]clause.Expression), 63 }) 64 65 stmt.AddClauseIfNotExists(clauses.WhenMatched{Set: onConflict.DoUpdates}) 66 stmt.AddClauseIfNotExists(clauses.WhenNotMatched{Values: values}) 67 68 stmt.Build("MERGE", "WHEN MATCHED", "WHEN NOT MATCHED") 69 } else { 70 stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Table}}) 71 stmt.AddClause(clause.Values{Columns: values.Columns, Values: [][]interface{}{values.Values[0]}}) 72 if hasDefaultValues { 73 stmt.AddClauseIfNotExists(clause.Returning{ 74 Columns: funk.Map(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) clause.Column { 75 return clause.Column{Name: field.DBName} 76 }).([]clause.Column), 77 }) 78 } 79 stmt.Build("INSERT", "VALUES") 80 // 返回自增主键 81 // stmt.Build("INSERT", "VALUES", "RETURNING") 82 // if hasDefaultValues { 83 // stmt.WriteString(" INTO ") 84 // for idx, field := range schema.FieldsWithDefaultDBValue { 85 // if idx > 0 { 86 // stmt.WriteByte(',') 87 // } 88 // boundVars[field.Name] = len(stmt.Vars) 89 // stmt.AddVar(stmt, sql.Out{Dest: reflect.New(field.FieldType).Interface()}) 90 // } 91 // } 92 } 93 94 if !db.DryRun { 95 for idx, vals := range values.Values { 96 // HACK HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned 97 for idx, val := range vals { 98 switch v := val.(type) { 99 case bool: 100 if v { 101 val = 1 102 } else { 103 val = 0 104 } 105 } 106 107 stmt.Vars[idx] = val 108 } 109 // and then we insert each row one by one then put the returning values back (i.e. last return id => smart insert) 110 // we keep track of the index so that the sub-reflected value is also correct 111 112 // BIG BUG: what if any of the transactions failed? some result might already be inserted that dm is so 113 // sneaky that some transaction inserts will exceed the buffer and so will be pushed at unknown point, 114 // resulting in dangling row entries, so we might need to delete them if an error happens 115 116 sqlStr := stmt.Explain(stmt.SQL.String(), stmt.Vars...) 117 switch result, err := stmt.ConnPool.ExecContext(stmt.Context, sqlStr, stmt.Vars...); err { 118 case nil: // success 119 db.RowsAffected, _ = result.RowsAffected() 120 121 insertTo := stmt.ReflectValue 122 switch insertTo.Kind() { 123 case reflect.Slice, reflect.Array: 124 insertTo = insertTo.Index(idx) 125 } 126 127 if hasDefaultValues { 128 // bind returning value back to reflected value in the respective fields 129 funk.ForEach( 130 funk.Filter(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) bool { 131 return funk.Contains(boundVars, field.Name) 132 }), 133 func(field *gormSchema.Field) { 134 switch insertTo.Kind() { 135 case reflect.Struct: 136 if err = field.Set(stmt.Context, insertTo, stmt.Vars[boundVars[field.Name]].(sql.Out).Dest); err != nil { 137 db.AddError(err) 138 } 139 case reflect.Map: 140 // todo 设置id的值 141 } 142 }, 143 ) 144 } 145 default: // failure 146 db.AddError(err) 147 } 148 } 149 } 150 } 151 }