github.com/gogf/gf/v2@v2.7.4/database/gdb/gdb_model_insert.go (about)

     1  // Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the MIT License.
     4  // If a copy of the MIT was not distributed with this file,
     5  // You can obtain one at https://github.com/gogf/gf.
     6  
     7  package gdb
     8  
     9  import (
    10  	"context"
    11  	"database/sql"
    12  	"reflect"
    13  
    14  	"github.com/gogf/gf/v2/container/gset"
    15  	"github.com/gogf/gf/v2/errors/gcode"
    16  	"github.com/gogf/gf/v2/errors/gerror"
    17  	"github.com/gogf/gf/v2/internal/empty"
    18  	"github.com/gogf/gf/v2/internal/reflection"
    19  	"github.com/gogf/gf/v2/text/gstr"
    20  	"github.com/gogf/gf/v2/util/gconv"
    21  	"github.com/gogf/gf/v2/util/gutil"
    22  )
    23  
    24  // Batch sets the batch operation number for the model.
    25  func (m *Model) Batch(batch int) *Model {
    26  	model := m.getModel()
    27  	model.batch = batch
    28  	return model
    29  }
    30  
    31  // Data sets the operation data for the model.
    32  // The parameter `data` can be type of string/map/gmap/slice/struct/*struct, etc.
    33  // Note that, it uses shallow value copying for `data` if `data` is type of map/slice
    34  // to avoid changing it inside function.
    35  // Eg:
    36  // Data("uid=10000")
    37  // Data("uid", 10000)
    38  // Data("uid=? AND name=?", 10000, "john")
    39  // Data(g.Map{"uid": 10000, "name":"john"})
    40  // Data(g.Slice{g.Map{"uid": 10000, "name":"john"}, g.Map{"uid": 20000, "name":"smith"}).
    41  func (m *Model) Data(data ...interface{}) *Model {
    42  	var model = m.getModel()
    43  	if len(data) > 1 {
    44  		if s := gconv.String(data[0]); gstr.Contains(s, "?") {
    45  			model.data = s
    46  			model.extraArgs = data[1:]
    47  		} else {
    48  			m := make(map[string]interface{})
    49  			for i := 0; i < len(data); i += 2 {
    50  				m[gconv.String(data[i])] = data[i+1]
    51  			}
    52  			model.data = m
    53  		}
    54  	} else if len(data) == 1 {
    55  		switch value := data[0].(type) {
    56  		case Result:
    57  			model.data = value.List()
    58  
    59  		case Record:
    60  			model.data = value.Map()
    61  
    62  		case List:
    63  			list := make(List, len(value))
    64  			for k, v := range value {
    65  				list[k] = gutil.MapCopy(v)
    66  			}
    67  			model.data = list
    68  
    69  		case Map:
    70  			model.data = gutil.MapCopy(value)
    71  
    72  		default:
    73  			reflectInfo := reflection.OriginValueAndKind(value)
    74  			switch reflectInfo.OriginKind {
    75  			case reflect.Slice, reflect.Array:
    76  				if reflectInfo.OriginValue.Len() > 0 {
    77  					// If the `data` parameter is a DO struct,
    78  					// it then adds `OmitNilData` option for this condition,
    79  					// which will filter all nil parameters in `data`.
    80  					if isDoStruct(reflectInfo.OriginValue.Index(0).Interface()) {
    81  						model = model.OmitNilData()
    82  						model.option |= optionOmitNilDataInternal
    83  					}
    84  				}
    85  				list := make(List, reflectInfo.OriginValue.Len())
    86  				for i := 0; i < reflectInfo.OriginValue.Len(); i++ {
    87  					list[i] = anyValueToMapBeforeToRecord(reflectInfo.OriginValue.Index(i).Interface())
    88  				}
    89  				model.data = list
    90  
    91  			case reflect.Struct:
    92  				// If the `data` parameter is a DO struct,
    93  				// it then adds `OmitNilData` option for this condition,
    94  				// which will filter all nil parameters in `data`.
    95  				if isDoStruct(value) {
    96  					model = model.OmitNilData()
    97  				}
    98  				if v, ok := data[0].(iInterfaces); ok {
    99  					var (
   100  						array = v.Interfaces()
   101  						list  = make(List, len(array))
   102  					)
   103  					for i := 0; i < len(array); i++ {
   104  						list[i] = anyValueToMapBeforeToRecord(array[i])
   105  					}
   106  					model.data = list
   107  				} else {
   108  					model.data = anyValueToMapBeforeToRecord(data[0])
   109  				}
   110  
   111  			case reflect.Map:
   112  				model.data = anyValueToMapBeforeToRecord(data[0])
   113  
   114  			default:
   115  				model.data = data[0]
   116  			}
   117  		}
   118  	}
   119  	return model
   120  }
   121  
   122  // OnConflict sets the primary key or index when columns conflicts occurs.
   123  // It's not necessary for MySQL driver.
   124  func (m *Model) OnConflict(onConflict ...interface{}) *Model {
   125  	if len(onConflict) == 0 {
   126  		return m
   127  	}
   128  	model := m.getModel()
   129  	if len(onConflict) > 1 {
   130  		model.onConflict = onConflict
   131  	} else if len(onConflict) == 1 {
   132  		model.onConflict = onConflict[0]
   133  	}
   134  	return model
   135  }
   136  
   137  // OnDuplicate sets the operations when columns conflicts occurs.
   138  // In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement.
   139  // In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement.
   140  // The parameter `onDuplicate` can be type of string/Raw/*Raw/map/slice.
   141  // Example:
   142  //
   143  // OnDuplicate("nickname, age")
   144  // OnDuplicate("nickname", "age")
   145  //
   146  //	OnDuplicate(g.Map{
   147  //		  "nickname": gdb.Raw("CONCAT('name_', VALUES(`nickname`))"),
   148  //	})
   149  //
   150  //	OnDuplicate(g.Map{
   151  //		  "nickname": "passport",
   152  //	}).
   153  func (m *Model) OnDuplicate(onDuplicate ...interface{}) *Model {
   154  	if len(onDuplicate) == 0 {
   155  		return m
   156  	}
   157  	model := m.getModel()
   158  	if len(onDuplicate) > 1 {
   159  		model.onDuplicate = onDuplicate
   160  	} else if len(onDuplicate) == 1 {
   161  		model.onDuplicate = onDuplicate[0]
   162  	}
   163  	return model
   164  }
   165  
   166  // OnDuplicateEx sets the excluding columns for operations when columns conflict occurs.
   167  // In MySQL, this is used for "ON DUPLICATE KEY UPDATE" statement.
   168  // In PgSQL, this is used for "ON CONFLICT (id) DO UPDATE SET" statement.
   169  // The parameter `onDuplicateEx` can be type of string/map/slice.
   170  // Example:
   171  //
   172  // OnDuplicateEx("passport, password")
   173  // OnDuplicateEx("passport", "password")
   174  //
   175  //	OnDuplicateEx(g.Map{
   176  //		  "passport": "",
   177  //		  "password": "",
   178  //	}).
   179  func (m *Model) OnDuplicateEx(onDuplicateEx ...interface{}) *Model {
   180  	if len(onDuplicateEx) == 0 {
   181  		return m
   182  	}
   183  	model := m.getModel()
   184  	if len(onDuplicateEx) > 1 {
   185  		model.onDuplicateEx = onDuplicateEx
   186  	} else if len(onDuplicateEx) == 1 {
   187  		model.onDuplicateEx = onDuplicateEx[0]
   188  	}
   189  	return model
   190  }
   191  
   192  // Insert does "INSERT INTO ..." statement for the model.
   193  // The optional parameter `data` is the same as the parameter of Model.Data function,
   194  // see Model.Data.
   195  func (m *Model) Insert(data ...interface{}) (result sql.Result, err error) {
   196  	var ctx = m.GetCtx()
   197  	if len(data) > 0 {
   198  		return m.Data(data...).Insert()
   199  	}
   200  	return m.doInsertWithOption(ctx, InsertOptionDefault)
   201  }
   202  
   203  // InsertAndGetId performs action Insert and returns the last insert id that automatically generated.
   204  func (m *Model) InsertAndGetId(data ...interface{}) (lastInsertId int64, err error) {
   205  	var ctx = m.GetCtx()
   206  	if len(data) > 0 {
   207  		return m.Data(data...).InsertAndGetId()
   208  	}
   209  	result, err := m.doInsertWithOption(ctx, InsertOptionDefault)
   210  	if err != nil {
   211  		return 0, err
   212  	}
   213  	return result.LastInsertId()
   214  }
   215  
   216  // InsertIgnore does "INSERT IGNORE INTO ..." statement for the model.
   217  // The optional parameter `data` is the same as the parameter of Model.Data function,
   218  // see Model.Data.
   219  func (m *Model) InsertIgnore(data ...interface{}) (result sql.Result, err error) {
   220  	var ctx = m.GetCtx()
   221  	if len(data) > 0 {
   222  		return m.Data(data...).InsertIgnore()
   223  	}
   224  	return m.doInsertWithOption(ctx, InsertOptionIgnore)
   225  }
   226  
   227  // Replace does "REPLACE INTO ..." statement for the model.
   228  // The optional parameter `data` is the same as the parameter of Model.Data function,
   229  // see Model.Data.
   230  func (m *Model) Replace(data ...interface{}) (result sql.Result, err error) {
   231  	var ctx = m.GetCtx()
   232  	if len(data) > 0 {
   233  		return m.Data(data...).Replace()
   234  	}
   235  	return m.doInsertWithOption(ctx, InsertOptionReplace)
   236  }
   237  
   238  // Save does "INSERT INTO ... ON DUPLICATE KEY UPDATE..." statement for the model.
   239  // The optional parameter `data` is the same as the parameter of Model.Data function,
   240  // see Model.Data.
   241  //
   242  // It updates the record if there's primary or unique index in the saving data,
   243  // or else it inserts a new record into the table.
   244  func (m *Model) Save(data ...interface{}) (result sql.Result, err error) {
   245  	var ctx = m.GetCtx()
   246  	if len(data) > 0 {
   247  		return m.Data(data...).Save()
   248  	}
   249  	return m.doInsertWithOption(ctx, InsertOptionSave)
   250  }
   251  
   252  // doInsertWithOption inserts data with option parameter.
   253  func (m *Model) doInsertWithOption(ctx context.Context, insertOption InsertOption) (result sql.Result, err error) {
   254  	defer func() {
   255  		if err == nil {
   256  			m.checkAndRemoveSelectCache(ctx)
   257  		}
   258  	}()
   259  	if m.data == nil {
   260  		return nil, gerror.NewCode(gcode.CodeMissingParameter, "inserting into table with empty data")
   261  	}
   262  	var (
   263  		list                             List
   264  		stm                              = m.softTimeMaintainer()
   265  		fieldNameCreate, fieldTypeCreate = stm.GetFieldNameAndTypeForCreate(ctx, "", m.tablesInit)
   266  		fieldNameUpdate, fieldTypeUpdate = stm.GetFieldNameAndTypeForUpdate(ctx, "", m.tablesInit)
   267  		fieldNameDelete, fieldTypeDelete = stm.GetFieldNameAndTypeForDelete(ctx, "", m.tablesInit)
   268  	)
   269  	// m.data was already converted to type List/Map by function Data
   270  	newData, err := m.filterDataForInsertOrUpdate(m.data)
   271  	if err != nil {
   272  		return nil, err
   273  	}
   274  	// It converts any data to List type for inserting.
   275  	switch value := newData.(type) {
   276  	case List:
   277  		list = value
   278  
   279  	case Map:
   280  		list = List{value}
   281  	}
   282  
   283  	if len(list) < 1 {
   284  		return result, gerror.NewCode(gcode.CodeMissingParameter, "data list cannot be empty")
   285  	}
   286  
   287  	// Automatic handling for creating/updating time.
   288  	if fieldNameCreate != "" && m.isFieldInFieldsEx(fieldNameCreate) {
   289  		fieldNameCreate = ""
   290  	}
   291  	if fieldNameUpdate != "" && m.isFieldInFieldsEx(fieldNameUpdate) {
   292  		fieldNameUpdate = ""
   293  	}
   294  	var isSoftTimeFeatureEnabled = fieldNameCreate != "" || fieldNameUpdate != ""
   295  	if !m.unscoped && isSoftTimeFeatureEnabled {
   296  		for k, v := range list {
   297  			if fieldNameCreate != "" && empty.IsNil(v[fieldNameCreate]) {
   298  				fieldCreateValue := stm.GetValueByFieldTypeForCreateOrUpdate(ctx, fieldTypeCreate, false)
   299  				if fieldCreateValue != nil {
   300  					v[fieldNameCreate] = fieldCreateValue
   301  				}
   302  			}
   303  			if fieldNameUpdate != "" && empty.IsNil(v[fieldNameUpdate]) {
   304  				fieldUpdateValue := stm.GetValueByFieldTypeForCreateOrUpdate(ctx, fieldTypeUpdate, false)
   305  				if fieldUpdateValue != nil {
   306  					v[fieldNameUpdate] = fieldUpdateValue
   307  				}
   308  			}
   309  			// for timestamp field that should initialize the delete_at field with value, for example 0.
   310  			if fieldNameDelete != "" && empty.IsNil(v[fieldNameDelete]) {
   311  				fieldDeleteValue := stm.GetValueByFieldTypeForCreateOrUpdate(ctx, fieldTypeDelete, true)
   312  				if fieldDeleteValue != nil {
   313  					v[fieldNameDelete] = fieldDeleteValue
   314  				}
   315  			}
   316  			list[k] = v
   317  		}
   318  	}
   319  	// Format DoInsertOption, especially for "ON DUPLICATE KEY UPDATE" statement.
   320  	columnNames := make([]string, 0, len(list[0]))
   321  	for k := range list[0] {
   322  		columnNames = append(columnNames, k)
   323  	}
   324  	doInsertOption, err := m.formatDoInsertOption(insertOption, columnNames)
   325  	if err != nil {
   326  		return result, err
   327  	}
   328  
   329  	in := &HookInsertInput{
   330  		internalParamHookInsert: internalParamHookInsert{
   331  			internalParamHook: internalParamHook{
   332  				link: m.getLink(true),
   333  			},
   334  			handler: m.hookHandler.Insert,
   335  		},
   336  		Model:  m,
   337  		Table:  m.tables,
   338  		Data:   list,
   339  		Option: doInsertOption,
   340  	}
   341  	return in.Next(ctx)
   342  }
   343  
   344  func (m *Model) formatDoInsertOption(insertOption InsertOption, columnNames []string) (option DoInsertOption, err error) {
   345  	option = DoInsertOption{
   346  		InsertOption: insertOption,
   347  		BatchCount:   m.getBatch(),
   348  	}
   349  	if insertOption != InsertOptionSave {
   350  		return
   351  	}
   352  
   353  	onConflictKeys, err := m.formatOnConflictKeys(m.onConflict)
   354  	if err != nil {
   355  		return option, err
   356  	}
   357  	option.OnConflict = onConflictKeys
   358  
   359  	onDuplicateExKeys, err := m.formatOnDuplicateExKeys(m.onDuplicateEx)
   360  	if err != nil {
   361  		return option, err
   362  	}
   363  	onDuplicateExKeySet := gset.NewStrSetFrom(onDuplicateExKeys)
   364  	if m.onDuplicate != nil {
   365  		switch m.onDuplicate.(type) {
   366  		case Raw, *Raw:
   367  			option.OnDuplicateStr = gconv.String(m.onDuplicate)
   368  
   369  		default:
   370  			reflectInfo := reflection.OriginValueAndKind(m.onDuplicate)
   371  			switch reflectInfo.OriginKind {
   372  			case reflect.String:
   373  				option.OnDuplicateMap = make(map[string]interface{})
   374  				for _, v := range gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ",") {
   375  					if onDuplicateExKeySet.Contains(v) {
   376  						continue
   377  					}
   378  					option.OnDuplicateMap[v] = v
   379  				}
   380  
   381  			case reflect.Map:
   382  				option.OnDuplicateMap = make(map[string]interface{})
   383  				for k, v := range gconv.Map(m.onDuplicate) {
   384  					if onDuplicateExKeySet.Contains(k) {
   385  						continue
   386  					}
   387  					option.OnDuplicateMap[k] = v
   388  				}
   389  
   390  			case reflect.Slice, reflect.Array:
   391  				option.OnDuplicateMap = make(map[string]interface{})
   392  				for _, v := range gconv.Strings(m.onDuplicate) {
   393  					if onDuplicateExKeySet.Contains(v) {
   394  						continue
   395  					}
   396  					option.OnDuplicateMap[v] = v
   397  				}
   398  
   399  			default:
   400  				return option, gerror.NewCodef(
   401  					gcode.CodeInvalidParameter,
   402  					`unsupported OnDuplicate parameter type "%s"`,
   403  					reflect.TypeOf(m.onDuplicate),
   404  				)
   405  			}
   406  		}
   407  	} else if onDuplicateExKeySet.Size() > 0 {
   408  		option.OnDuplicateMap = make(map[string]interface{})
   409  		for _, v := range columnNames {
   410  			if onDuplicateExKeySet.Contains(v) {
   411  				continue
   412  			}
   413  			option.OnDuplicateMap[v] = v
   414  		}
   415  	}
   416  	return
   417  }
   418  
   419  func (m *Model) formatOnDuplicateExKeys(onDuplicateEx interface{}) ([]string, error) {
   420  	if onDuplicateEx == nil {
   421  		return nil, nil
   422  	}
   423  
   424  	reflectInfo := reflection.OriginValueAndKind(onDuplicateEx)
   425  	switch reflectInfo.OriginKind {
   426  	case reflect.String:
   427  		return gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ","), nil
   428  
   429  	case reflect.Map:
   430  		return gutil.Keys(onDuplicateEx), nil
   431  
   432  	case reflect.Slice, reflect.Array:
   433  		return gconv.Strings(onDuplicateEx), nil
   434  
   435  	default:
   436  		return nil, gerror.NewCodef(
   437  			gcode.CodeInvalidParameter,
   438  			`unsupported OnDuplicateEx parameter type "%s"`,
   439  			reflect.TypeOf(onDuplicateEx),
   440  		)
   441  	}
   442  }
   443  
   444  func (m *Model) formatOnConflictKeys(onConflict interface{}) ([]string, error) {
   445  	if onConflict == nil {
   446  		return nil, nil
   447  	}
   448  
   449  	reflectInfo := reflection.OriginValueAndKind(onConflict)
   450  	switch reflectInfo.OriginKind {
   451  	case reflect.String:
   452  		return gstr.SplitAndTrim(reflectInfo.OriginValue.String(), ","), nil
   453  
   454  	case reflect.Slice, reflect.Array:
   455  		return gconv.Strings(onConflict), nil
   456  
   457  	default:
   458  		return nil, gerror.NewCodef(
   459  			gcode.CodeInvalidParameter,
   460  			`unsupported onConflict parameter type "%s"`,
   461  			reflect.TypeOf(onConflict),
   462  		)
   463  	}
   464  }
   465  
   466  func (m *Model) getBatch() int {
   467  	return m.batch
   468  }