github.com/wangyougui/gf/v2@v2.6.5/database/gdb/gdb_model_insert.go (about)

     1  // Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the MIT License.
     4  // If a copy of the MIT was not distributed with this file,
     5  // You can obtain one at https://github.com/wangyougui/gf.
     6  
     7  package gdb
     8  
     9  import (
    10  	"context"
    11  	"database/sql"
    12  	"reflect"
    13  
    14  	"github.com/wangyougui/gf/v2/container/gset"
    15  	"github.com/wangyougui/gf/v2/errors/gcode"
    16  	"github.com/wangyougui/gf/v2/errors/gerror"
    17  	"github.com/wangyougui/gf/v2/internal/reflection"
    18  	"github.com/wangyougui/gf/v2/text/gstr"
    19  	"github.com/wangyougui/gf/v2/util/gconv"
    20  	"github.com/wangyougui/gf/v2/util/gutil"
    21  )
    22  
    23  // Batch sets the batch operation number for the model.
    24  func (m *Model) Batch(batch int) *Model {
    25  	model := m.getModel()
    26  	model.batch = batch
    27  	return model
    28  }
    29  
    30  // Data sets the operation data for the model.
    31  // The parameter `data` can be type of string/map/gmap/slice/struct/*struct, etc.
    32  // Note that, it uses shallow value copying for `data` if `data` is type of map/slice
    33  // to avoid changing it inside function.
    34  // Eg:
    35  // Data("uid=10000")
    36  // Data("uid", 10000)
    37  // Data("uid=? AND name=?", 10000, "john")
    38  // Data(g.Map{"uid": 10000, "name":"john"})
    39  // Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}).
    40  func (m *Model) Data(data ...interface{}) *Model {
    41  	var model = m.getModel()
    42  	if len(data) > 1 {
    43  		if s := gconv.String(data[0]); gstr.Contains(s, "?") {
    44  			model.data = s
    45  			model.extraArgs = data[1:]
    46  		} else {
    47  			m := make(map[string]interface{})
    48  			for i := 0; i < len(data); i += 2 {
    49  				m[gconv.String(data[i])] = data[i+1]
    50  			}
    51  			model.data = m
    52  		}
    53  	} else if len(data) == 1 {
    54  		switch value := data[0].(type) {
    55  		case Result:
    56  			model.data = value.List()
    57  
    58  		case Record:
    59  			model.data = value.Map()
    60  
    61  		case List:
    62  			list := make(List, len(value))
    63  			for k, v := range value {
    64  				list[k] = gutil.MapCopy(v)
    65  			}
    66  			model.data = list
    67  
    68  		case Map:
    69  			model.data = gutil.MapCopy(value)
    70  
    71  		default:
    72  			reflectInfo := reflection.OriginValueAndKind(value)
    73  			switch reflectInfo.OriginKind {
    74  			case reflect.Slice, reflect.Array:
    75  				if reflectInfo.OriginValue.Len() > 0 {
    76  					// If the `data` parameter is a DO struct,
    77  					// it then adds `OmitNilData` option for this condition,
    78  					// which will filter all nil parameters in `data`.
    79  					if isDoStruct(reflectInfo.OriginValue.Index(0).Interface()) {
    80  						model = model.OmitNilData()
    81  						model.option |= optionOmitNilDataInternal
    82  					}
    83  				}
    84  				list := make(List, reflectInfo.OriginValue.Len())
    85  				for i := 0; i < reflectInfo.OriginValue.Len(); i++ {
    86  					list[i] = anyValueToMapBeforeToRecord(reflectInfo.OriginValue.Index(i).Interface())
    87  				}
    88  				model.data = list
    89  
    90  			case reflect.Struct:
    91  				// If the `data` parameter is a DO struct,
    92  				// it then adds `OmitNilData` option for this condition,
    93  				// which will filter all nil parameters in `data`.
    94  				if isDoStruct(value) {
    95  					model = model.OmitNilData()
    96  				}
    97  				if v, ok := data[0].(iInterfaces); ok {
    98  					var (
    99  						array = v.Interfaces()
   100  						list  = make(List, len(array))
   101  					)
   102  					for i := 0; i < len(array); i++ {
   103  						list[i] = anyValueToMapBeforeToRecord(array[i])
   104  					}
   105  					model.data = list
   106  				} else {
   107  					model.data = anyValueToMapBeforeToRecord(data[0])
   108  				}
   109  
   110  			case reflect.Map:
   111  				model.data = anyValueToMapBeforeToRecord(data[0])
   112  
   113  			default:
   114  				model.data = data[0]
   115  			}
   116  		}
   117  	}
   118  	return model
   119  }
   120  
   121  // OnConflict sets the primary key or index when columns conflicts occurs.
   122  // It's not necessary for MySQL driver.
   123  func (m *Model) OnConflict(onConflict ...interface{}) *Model {
   124  	if len(onConflict) == 0 {
   125  		return m
   126  	}
   127  	model := m.getModel()
   128  	if len(onConflict) > 1 {
   129  		model.onConflict = onConflict
   130  	} else if len(onConflict) == 1 {
   131  		model.onConflict = onConflict[0]
   132  	}
   133  	return model
   134  }
   135  
   136  // OnDuplicate sets the operations when columns conflicts occurs.
   137  // In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement.
   138  // In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement.
   139  // The parameter `onDuplicate` can be type of string/Raw/*Raw/map/slice.
   140  // Example:
   141  //
   142  // OnDuplicate("nickname, age")
   143  // OnDuplicate("nickname", "age")
   144  //
   145  //	OnDuplicate(g.Map{
   146  //		  "nickname": gdb.Raw("CONCAT('name_', VALUES(`nickname`))"),
   147  //	})
   148  //
   149  //	OnDuplicate(g.Map{
   150  //		  "nickname": "passport",
   151  //	}).
   152  func (m *Model) OnDuplicate(onDuplicate ...interface{}) *Model {
   153  	if len(onDuplicate) == 0 {
   154  		return m
   155  	}
   156  	model := m.getModel()
   157  	if len(onDuplicate) > 1 {
   158  		model.onDuplicate = onDuplicate
   159  	} else if len(onDuplicate) == 1 {
   160  		model.onDuplicate = onDuplicate[0]
   161  	}
   162  	return model
   163  }
   164  
   165  // OnDuplicateEx sets the excluding columns for operations when columns conflict occurs.
   166  // In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement.
   167  // In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement.
   168  // The parameter `onDuplicateEx` can be type of string/map/slice.
   169  // Example:
   170  //
   171  // OnDuplicateEx("passport, password")
   172  // OnDuplicateEx("passport", "password")
   173  //
   174  //	OnDuplicateEx(g.Map{
   175  //		  "passport": "",
   176  //		  "password": "",
   177  //	}).
   178  func (m *Model) OnDuplicateEx(onDuplicateEx ...interface{}) *Model {
   179  	if len(onDuplicateEx) == 0 {
   180  		return m
   181  	}
   182  	model := m.getModel()
   183  	if len(onDuplicateEx) > 1 {
   184  		model.onDuplicateEx = onDuplicateEx
   185  	} else if len(onDuplicateEx) == 1 {
   186  		model.onDuplicateEx = onDuplicateEx[0]
   187  	}
   188  	return model
   189  }
   190  
   191  // Insert does "INSERT INTO ..." statement for the model.
   192  // The optional parameter `data` is the same as the parameter of Model.Data function,
   193  // see Model.Data.
   194  func (m *Model) Insert(data ...interface{}) (result sql.Result, err error) {
   195  	var ctx = m.GetCtx()
   196  	if len(data) > 0 {
   197  		return m.Data(data...).Insert()
   198  	}
   199  	return m.doInsertWithOption(ctx, InsertOptionDefault)
   200  }
   201  
   202  // InsertAndGetId performs action Insert and returns the last insert id that automatically generated.
   203  func (m *Model) InsertAndGetId(data ...interface{}) (lastInsertId int64, err error) {
   204  	var ctx = m.GetCtx()
   205  	if len(data) > 0 {
   206  		return m.Data(data...).InsertAndGetId()
   207  	}
   208  	result, err := m.doInsertWithOption(ctx, InsertOptionDefault)
   209  	if err != nil {
   210  		return 0, err
   211  	}
   212  	return result.LastInsertId()
   213  }
   214  
   215  // InsertIgnore does "INSERT IGNORE INTO ..." statement for the model.
   216  // The optional parameter `data` is the same as the parameter of Model.Data function,
   217  // see Model.Data.
   218  func (m *Model) InsertIgnore(data ...interface{}) (result sql.Result, err error) {
   219  	var ctx = m.GetCtx()
   220  	if len(data) > 0 {
   221  		return m.Data(data...).InsertIgnore()
   222  	}
   223  	return m.doInsertWithOption(ctx, InsertOptionIgnore)
   224  }
   225  
   226  // Replace does "REPLACE INTO ..." statement for the model.
   227  // The optional parameter `data` is the same as the parameter of Model.Data function,
   228  // see Model.Data.
   229  func (m *Model) Replace(data ...interface{}) (result sql.Result, err error) {
   230  	var ctx = m.GetCtx()
   231  	if len(data) > 0 {
   232  		return m.Data(data...).Replace()
   233  	}
   234  	return m.doInsertWithOption(ctx, InsertOptionReplace)
   235  }
   236  
   237  // Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the model.
   238  // The optional parameter `data` is the same as the parameter of Model.Data function,
   239  // see Model.Data.
   240  //
   241  // It updates the record if there's primary or unique index in the saving data,
   242  // or else it inserts a new record into the table.
   243  func (m *Model) Save(data ...interface{}) (result sql.Result, err error) {
   244  	var ctx = m.GetCtx()
   245  	if len(data) > 0 {
   246  		return m.Data(data...).Save()
   247  	}
   248  	return m.doInsertWithOption(ctx, InsertOptionSave)
   249  }
   250  
   251  // doInsertWithOption inserts data with option parameter.
   252  func (m *Model) doInsertWithOption(ctx context.Context, insertOption InsertOption) (result sql.Result, err error) {
   253  	defer func() {
   254  		if err == nil {
   255  			m.checkAndRemoveSelectCache(ctx)
   256  		}
   257  	}()
   258  	if m.data == nil {
   259  		return nil, gerror.NewCode(gcode.CodeMissingParameter, "inserting into table with empty data")
   260  	}
   261  	var (
   262  		list                             List
   263  		stm                              = m.softTimeMaintainer()
   264  		fieldNameCreate, fieldTypeCreate = stm.GetFieldNameAndTypeForCreate(ctx, "", m.tablesInit)
   265  		fieldNameUpdate, fieldTypeUpdate = stm.GetFieldNameAndTypeForUpdate(ctx, "", m.tablesInit)
   266  		fieldNameDelete, fieldTypeDelete = stm.GetFieldNameAndTypeForDelete(ctx, "", m.tablesInit)
   267  	)
   268  	// m.data was already converted to type List/Map by function Data
   269  	newData, err := m.filterDataForInsertOrUpdate(m.data)
   270  	if err != nil {
   271  		return nil, err
   272  	}
   273  	// It converts any data to List type for inserting.
   274  	switch value := newData.(type) {
   275  	case List:
   276  		list = value
   277  
   278  	case Map:
   279  		list = List{value}
   280  	}
   281  
   282  	if len(list) < 1 {
   283  		return result, gerror.NewCode(gcode.CodeMissingParameter, "data list cannot be empty")
   284  	}
   285  
   286  	// Automatic handling for creating/updating time.
   287  	if !m.unscoped && (fieldNameCreate != "" || fieldNameUpdate != "") {
   288  		for k, v := range list {
   289  			if fieldNameCreate != "" {
   290  				fieldCreateValue := stm.GetValueByFieldTypeForCreateOrUpdate(ctx, fieldTypeCreate, false)
   291  				if fieldCreateValue != nil {
   292  					v[fieldNameCreate] = fieldCreateValue
   293  				}
   294  			}
   295  			if fieldNameUpdate != "" {
   296  				fieldUpdateValue := stm.GetValueByFieldTypeForCreateOrUpdate(ctx, fieldTypeUpdate, false)
   297  				if fieldUpdateValue != nil {
   298  					v[fieldNameUpdate] = fieldUpdateValue
   299  				}
   300  			}
   301  			if fieldNameDelete != "" {
   302  				fieldDeleteValue := stm.GetValueByFieldTypeForCreateOrUpdate(ctx, fieldTypeDelete, true)
   303  				if fieldDeleteValue != nil {
   304  					v[fieldNameDelete] = fieldDeleteValue
   305  				}
   306  			}
   307  			list[k] = v
   308  		}
   309  	}
   310  	// Format DoInsertOption, especially for "ON DUPLICATE KEY UPDATE" statement.
   311  	columnNames := make([]string, 0, len(list[0]))
   312  	for k := range list[0] {
   313  		columnNames = append(columnNames, k)
   314  	}
   315  	doInsertOption, err := m.formatDoInsertOption(insertOption, columnNames)
   316  	if err != nil {
   317  		return result, err
   318  	}
   319  
   320  	in := &HookInsertInput{
   321  		internalParamHookInsert: internalParamHookInsert{
   322  			internalParamHook: internalParamHook{
   323  				link: m.getLink(true),
   324  			},
   325  			handler: m.hookHandler.Insert,
   326  		},
   327  		Model:  m,
   328  		Table:  m.tables,
   329  		Data:   list,
   330  		Option: doInsertOption,
   331  	}
   332  	return in.Next(ctx)
   333  }
   334  
   335  func (m *Model) formatDoInsertOption(insertOption InsertOption, columnNames []string) (option DoInsertOption, err error) {
   336  	option = DoInsertOption{
   337  		InsertOption: insertOption,
   338  		BatchCount:   m.getBatch(),
   339  	}
   340  	if insertOption != InsertOptionSave {
   341  		return
   342  	}
   343  
   344  	onConflictKeys, err := m.formatOnConflictKeys(m.onConflict)
   345  	if err != nil {
   346  		return option, err
   347  	}
   348  	option.OnConflict = onConflictKeys
   349  
   350  	onDuplicateExKeys, err := m.formatOnDuplicateExKeys(m.onDuplicateEx)
   351  	if err != nil {
   352  		return option, err
   353  	}
   354  	onDuplicateExKeySet := gset.NewStrSetFrom(onDuplicateExKeys)
   355  	if m.onDuplicate != nil {
   356  		switch m.onDuplicate.(type) {
   357  		case Raw, *Raw:
   358  			option.OnDuplicateStr = gconv.String(m.onDuplicate)
   359  
   360  		default:
   361  			reflectInfo := reflection.OriginValueAndKind(m.onDuplicate)
   362  			switch reflectInfo.OriginKind {
   363  			case reflect.String:
   364  				option.OnDuplicateMap = make(map[string]interface{})
   365  				for _, v := range gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ",") {
   366  					if onDuplicateExKeySet.Contains(v) {
   367  						continue
   368  					}
   369  					option.OnDuplicateMap[v] = v
   370  				}
   371  
   372  			case reflect.Map:
   373  				option.OnDuplicateMap = make(map[string]interface{})
   374  				for k, v := range gconv.Map(m.onDuplicate) {
   375  					if onDuplicateExKeySet.Contains(k) {
   376  						continue
   377  					}
   378  					option.OnDuplicateMap[k] = v
   379  				}
   380  
   381  			case reflect.Slice, reflect.Array:
   382  				option.OnDuplicateMap = make(map[string]interface{})
   383  				for _, v := range gconv.Strings(m.onDuplicate) {
   384  					if onDuplicateExKeySet.Contains(v) {
   385  						continue
   386  					}
   387  					option.OnDuplicateMap[v] = v
   388  				}
   389  
   390  			default:
   391  				return option, gerror.NewCodef(
   392  					gcode.CodeInvalidParameter,
   393  					`unsupported OnDuplicate parameter type "%s"`,
   394  					reflect.TypeOf(m.onDuplicate),
   395  				)
   396  			}
   397  		}
   398  	} else if onDuplicateExKeySet.Size() > 0 {
   399  		option.OnDuplicateMap = make(map[string]interface{})
   400  		for _, v := range columnNames {
   401  			if onDuplicateExKeySet.Contains(v) {
   402  				continue
   403  			}
   404  			option.OnDuplicateMap[v] = v
   405  		}
   406  	}
   407  	return
   408  }
   409  
   410  func (m *Model) formatOnDuplicateExKeys(onDuplicateEx interface{}) ([]string, error) {
   411  	if onDuplicateEx == nil {
   412  		return nil, nil
   413  	}
   414  
   415  	reflectInfo := reflection.OriginValueAndKind(onDuplicateEx)
   416  	switch reflectInfo.OriginKind {
   417  	case reflect.String:
   418  		return gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ","), nil
   419  
   420  	case reflect.Map:
   421  		return gutil.Keys(onDuplicateEx), nil
   422  
   423  	case reflect.Slice, reflect.Array:
   424  		return gconv.Strings(onDuplicateEx), nil
   425  
   426  	default:
   427  		return nil, gerror.NewCodef(
   428  			gcode.CodeInvalidParameter,
   429  			`unsupported OnDuplicateEx parameter type "%s"`,
   430  			reflect.TypeOf(onDuplicateEx),
   431  		)
   432  	}
   433  }
   434  
   435  func (m *Model) formatOnConflictKeys(onConflict interface{}) ([]string, error) {
   436  	if onConflict == nil {
   437  		return nil, nil
   438  	}
   439  
   440  	reflectInfo := reflection.OriginValueAndKind(onConflict)
   441  	switch reflectInfo.OriginKind {
   442  	case reflect.String:
   443  		return gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ","), nil
   444  
   445  	case reflect.Slice, reflect.Array:
   446  		return gconv.Strings(onConflict), nil
   447  
   448  	default:
   449  		return nil, gerror.NewCodef(
   450  			gcode.CodeInvalidParameter,
   451  			`unsupported onConflict parameter type "%s"`,
   452  			reflect.TypeOf(onConflict),
   453  		)
   454  	}
   455  }
   456  
   457  func (m *Model) getBatch() int {
   458  	return m.batch
   459  }