github.com/wanlay/gorm-dm8@v1.0.5/create.go (about)

     1  package dm
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql"
     6  	"reflect"
     7  
     8  	"github.com/thoas/go-funk"
     9  	"github.com/wanlay/gorm-dm8/clauses"
    10  	"gorm.io/gorm"
    11  	"gorm.io/gorm/callbacks"
    12  	"gorm.io/gorm/clause"
    13  	gormSchema "gorm.io/gorm/schema"
    14  )
    15  
    16  func Create(db *gorm.DB) {
    17  	stmt := db.Statement
    18  	schema := stmt.Schema
    19  	boundVars := make(map[string]int)
    20  
    21  	if stmt == nil || schema == nil {
    22  		return
    23  	}
    24  
    25  	hasDefaultValues := len(schema.FieldsWithDefaultDBValue) > 0
    26  
    27  	if !stmt.Unscoped {
    28  		for _, c := range schema.CreateClauses {
    29  			stmt.AddClause(c)
    30  		}
    31  	}
    32  
    33  	if stmt.SQL.String() == "" {
    34  		values := callbacks.ConvertToCreateValues(stmt)
    35  		onConflict, hasConflict := stmt.Clauses["ON CONFLICT"].Expression.(clause.OnConflict)
    36  		if hasConflict {
    37  			stmt.AddClauseIfNotExists(clauses.Merge{
    38  				Using: []clause.Interface{
    39  					clause.Select{
    40  						Columns: funk.Map(values.Columns, func(column clause.Column) clause.Column {
    41  							// HACK: I can not come up with a better alternative for now
    42  							// I want to add a value to the list of variable and then capture the bind variable position as well
    43  							buf := bytes.NewBufferString("")
    44  							stmt.Vars = append(stmt.Vars, values.Values[0][funk.IndexOf(values.Columns, column)])
    45  							stmt.BindVarTo(buf, stmt, nil)
    46  
    47  							column.Alias = column.Name
    48  							// then the captured bind var will be the name
    49  							column.Name = buf.String()
    50  							return column
    51  						}).([]clause.Column),
    52  					},
    53  					clause.From{
    54  						Tables: []clause.Table{{Name: "DUAL"}},
    55  					},
    56  				},
    57  				On: funk.Map(onConflict.Columns, func(field clause.Column) clause.Expression {
    58  					return clause.Eq{
    59  						Column: clause.Column{Table: stmt.Table, Name: field.Name},
    60  						Value:  clause.Column{Table: clauses.MergeDefaultExcludeName(), Name: field.Name},
    61  					}
    62  				}).([]clause.Expression),
    63  			})
    64  
    65  			stmt.AddClauseIfNotExists(clauses.WhenMatched{Set: onConflict.DoUpdates})
    66  			stmt.AddClauseIfNotExists(clauses.WhenNotMatched{Values: values})
    67  
    68  			stmt.Build("MERGE", "WHEN MATCHED", "WHEN NOT MATCHED")
    69  		} else {
    70  			stmt.AddClauseIfNotExists(clause.Insert{Table: clause.Table{Name: stmt.Table}})
    71  			stmt.AddClause(clause.Values{Columns: values.Columns, Values: [][]interface{}{values.Values[0]}})
    72  			if hasDefaultValues {
    73  				stmt.AddClauseIfNotExists(clause.Returning{
    74  					Columns: funk.Map(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) clause.Column {
    75  						return clause.Column{Name: field.DBName}
    76  					}).([]clause.Column),
    77  				})
    78  			}
    79  			stmt.Build("INSERT", "VALUES")
    80  			// 返回自增主键
    81  			// stmt.Build("INSERT", "VALUES", "RETURNING")
    82  			// if hasDefaultValues {
    83  			//	stmt.WriteString(" INTO ")
    84  			//	for idx, field := range schema.FieldsWithDefaultDBValue {
    85  			//		if idx > 0 {
    86  			//			stmt.WriteByte(',')
    87  			//		}
    88  			//		boundVars[field.Name] = len(stmt.Vars)
    89  			//		stmt.AddVar(stmt, sql.Out{Dest: reflect.New(field.FieldType).Interface()})
    90  			//	}
    91  			// }
    92  		}
    93  
    94  		if !db.DryRun {
    95  			for idx, vals := range values.Values {
    96  				// HACK HACK: replace values one by one, assuming its value layout will be the same all the time, i.e. aligned
    97  				for idx, val := range vals {
    98  					switch v := val.(type) {
    99  					case bool:
   100  						if v {
   101  							val = 1
   102  						} else {
   103  							val = 0
   104  						}
   105  					}
   106  
   107  					stmt.Vars[idx] = val
   108  				}
   109  				// and then we insert each row one by one then put the returning values back (i.e. last return id => smart insert)
   110  				// we keep track of the index so that the sub-reflected value is also correct
   111  
   112  				// BIG BUG: what if any of the transactions failed? some result might already be inserted that dm is so
   113  				// sneaky that some transaction inserts will exceed the buffer and so will be pushed at unknown point,
   114  				// resulting in dangling row entries, so we might need to delete them if an error happens
   115  
   116  				sqlStr := stmt.Explain(stmt.SQL.String(), stmt.Vars...)
   117  				switch result, err := stmt.ConnPool.ExecContext(stmt.Context, sqlStr, stmt.Vars...); err {
   118  				case nil: // success
   119  					db.RowsAffected, _ = result.RowsAffected()
   120  
   121  					insertTo := stmt.ReflectValue
   122  					switch insertTo.Kind() {
   123  					case reflect.Slice, reflect.Array:
   124  						insertTo = insertTo.Index(idx)
   125  					}
   126  
   127  					if hasDefaultValues {
   128  						// bind returning value back to reflected value in the respective fields
   129  						funk.ForEach(
   130  							funk.Filter(schema.FieldsWithDefaultDBValue, func(field *gormSchema.Field) bool {
   131  								return funk.Contains(boundVars, field.Name)
   132  							}),
   133  							func(field *gormSchema.Field) {
   134  								switch insertTo.Kind() {
   135  								case reflect.Struct:
   136  									if err = field.Set(stmt.Context, insertTo, stmt.Vars[boundVars[field.Name]].(sql.Out).Dest); err != nil {
   137  										db.AddError(err)
   138  									}
   139  								case reflect.Map:
   140  									// todo 设置id的值
   141  								}
   142  							},
   143  						)
   144  					}
   145  				default: // failure
   146  					db.AddError(err)
   147  				}
   148  			}
   149  		}
   150  	}
   151  }