github.com/systematiccaos/gorm@v1.22.6/statement.go (about)

     1  package gorm
     2  
     3  import (
     4  	"context"
     5  	"database/sql"
     6  	"database/sql/driver"
     7  	"fmt"
     8  	"reflect"
     9  	"regexp"
    10  	"sort"
    11  	"strconv"
    12  	"strings"
    13  	"sync"
    14  
    15  	"github.com/systematiccaos/gorm/clause"
    16  	"github.com/systematiccaos/gorm/logger"
    17  	"github.com/systematiccaos/gorm/schema"
    18  	"github.com/systematiccaos/gorm/utils"
    19  )
    20  
// Statement holds the state used to build one SQL statement: the target
// table/model, the accumulated clauses, the SQL text being written and its
// bind variables.
type Statement struct {
	// Embedded DB supplies the Dialector, ClauseBuilders, AddError and the
	// schema cache/naming strategy used by the methods in this file.
	*DB
	// TableExpr, when non-nil, is built in place of Table wherever the current
	// table is quoted as a clause.Table; ParseWithSpecialTableName sets it for
	// two-part "x.table" names.
	TableExpr            *clause.Expr
	// Table is the name substituted for clause.CurrentTable.
	Table                string
	// Model is not read in this file (presumably the value given to DB.Model).
	Model                interface{}
	// Unscoped is not read in this file beyond cloning (presumably disables
	// default scoping such as soft delete — confirm).
	Unscoped             bool
	// Dest is the destination value; SetColumn writes into it and Changed
	// compares against it (map[string]interface{}, []map[string]interface{}
	// or a struct, possibly behind pointers).
	Dest                 interface{}
	// ReflectValue is the reflected model value; when it is a slice or array,
	// SetColumn and Changed address the element at CurDestIndex.
	ReflectValue         reflect.Value
	// Clauses are keyed by clause name; AddClause merges into them and Build
	// renders them.
	Clauses              map[string]clause.Clause
	// BuildClauses is not read in this file (presumably the ordered clause
	// names handed to Build by callbacks).
	BuildClauses         []string
	// Distinct is not read in this file beyond cloning.
	Distinct             bool
	Selects              []string // selected columns (interpreted by SelectAndOmitColumns)
	Omits                []string // omit columns (interpreted by SelectAndOmitColumns)
	// Joins are pending joins; clone copies the slice.
	Joins                []join
	// Preloads is copied key-by-key by clone (presumably keyed by association
	// name with its conditions as values).
	Preloads             map[string][]interface{}
	// Settings holds arbitrary per-statement values; clone copies every entry.
	Settings             sync.Map
	// ConnPool is not read in this file beyond cloning (presumably the
	// connection/transaction used to execute the statement).
	ConnPool             ConnPool
	// Schema is the parsed model schema, set by Parse/ParseWithSpecialTableName.
	Schema               *schema.Schema
	// Context is passed to Valuer.GormValue by AddVar.
	Context              context.Context
	// RaiseErrorOnNotFound is not read in this file beyond cloning.
	RaiseErrorOnNotFound bool
	// SkipHooks is not read in this file beyond cloning.
	SkipHooks            bool
	// SQL is the statement text; WriteString, WriteByte and WriteQuoted append to it.
	SQL                  strings.Builder
	// Vars are the bind values appended by AddVar for the placeholders in SQL.
	Vars                 []interface{}
	// CurDestIndex selects the element of a slice/array ReflectValue that
	// SetColumn and Changed operate on.
	CurDestIndex         int
	// attrs and assigns are not read in this file and are not copied by clone
	// (presumably the Attrs/Assign values for FirstOrInit/FirstOrCreate).
	attrs                []interface{}
	assigns              []interface{}
	// scopes are pending scope functions; clone copies the slice.
	scopes               []func(*DB) *DB
}
    50  
// join is one entry of Statement.Joins; clone copies these by value.
// NOTE(review): the fields are not interpreted in this file. Presumably Name
// is a relationship name or raw JOIN SQL, Conds its arguments, and On an
// explicit ON condition — confirm against the join callbacks.
type join struct {
	Name  string
	Conds []interface{}
	On    *clause.Where
}
    56  
// StatementModifier is implemented by clauses that change the Statement
// directly. AddClause calls ModifyStatement on such a value instead of
// merging it into stmt.Clauses.
type StatementModifier interface {
	// ModifyStatement applies the modifier to the given statement.
	ModifyStatement(*Statement)
}
    61  
    62  // WriteString write string
    63  func (stmt *Statement) WriteString(str string) (int, error) {
    64  	return stmt.SQL.WriteString(str)
    65  }
    66  
    67  // WriteByte write byte
    68  func (stmt *Statement) WriteByte(c byte) error {
    69  	return stmt.SQL.WriteByte(c)
    70  }
    71  
    72  // WriteQuoted write quoted value
    73  func (stmt *Statement) WriteQuoted(value interface{}) {
    74  	stmt.QuoteTo(&stmt.SQL, value)
    75  }
    76  
    77  // QuoteTo write quoted value to writer
    78  func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
    79  	write := func(raw bool, str string) {
    80  		if raw {
    81  			writer.WriteString(str)
    82  		} else {
    83  			stmt.DB.Dialector.QuoteTo(writer, str)
    84  		}
    85  	}
    86  
    87  	switch v := field.(type) {
    88  	case clause.Table:
    89  		if v.Name == clause.CurrentTable {
    90  			if stmt.TableExpr != nil {
    91  				stmt.TableExpr.Build(stmt)
    92  			} else {
    93  				write(v.Raw, stmt.Table)
    94  			}
    95  		} else {
    96  			write(v.Raw, v.Name)
    97  		}
    98  
    99  		if v.Alias != "" {
   100  			writer.WriteByte(' ')
   101  			write(v.Raw, v.Alias)
   102  		}
   103  	case clause.Column:
   104  		if v.Table != "" {
   105  			if v.Table == clause.CurrentTable {
   106  				write(v.Raw, stmt.Table)
   107  			} else {
   108  				write(v.Raw, v.Table)
   109  			}
   110  			writer.WriteByte('.')
   111  		}
   112  
   113  		if v.Name == clause.PrimaryKey {
   114  			if stmt.Schema == nil {
   115  				stmt.DB.AddError(ErrModelValueRequired)
   116  			} else if stmt.Schema.PrioritizedPrimaryField != nil {
   117  				write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName)
   118  			} else if len(stmt.Schema.DBNames) > 0 {
   119  				write(v.Raw, stmt.Schema.DBNames[0])
   120  			}
   121  		} else {
   122  			write(v.Raw, v.Name)
   123  		}
   124  
   125  		if v.Alias != "" {
   126  			writer.WriteString(" AS ")
   127  			write(v.Raw, v.Alias)
   128  		}
   129  	case []clause.Column:
   130  		writer.WriteByte('(')
   131  		for idx, d := range v {
   132  			if idx > 0 {
   133  				writer.WriteString(",")
   134  			}
   135  			stmt.QuoteTo(writer, d)
   136  		}
   137  		writer.WriteByte(')')
   138  	case clause.Expr:
   139  		v.Build(stmt)
   140  	case string:
   141  		stmt.DB.Dialector.QuoteTo(writer, v)
   142  	case []string:
   143  		writer.WriteByte('(')
   144  		for idx, d := range v {
   145  			if idx > 0 {
   146  				writer.WriteString(",")
   147  			}
   148  			stmt.DB.Dialector.QuoteTo(writer, d)
   149  		}
   150  		writer.WriteByte(')')
   151  	default:
   152  		stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field))
   153  	}
   154  }
   155  
   156  // Quote returns quoted value
   157  func (stmt *Statement) Quote(field interface{}) string {
   158  	var builder strings.Builder
   159  	stmt.QuoteTo(&builder, field)
   160  	return builder.String()
   161  }
   162  
// AddVar writes each of vars to writer as SQL, separated by commas, and
// appends the corresponding bind values to stmt.Vars.
//
// Rendering by type:
//   - sql.NamedArg: only its Value is appended to stmt.Vars; nothing is
//     written to writer.
//   - clause.Column / clause.Table: written as a quoted identifier (QuoteTo).
//   - Valuer: a nil pointer is rendered as nil; otherwise the result of
//     GormValue(stmt.Context, stmt.DB) is rendered recursively.
//   - clause.Expr / *clause.Expr: built into stmt.
//   - driver.Valuer and []byte: bound as a single placeholder.
//   - []interface{}: a parenthesised list of its elements, or "(NULL)" if empty.
//   - *DB: rendered inline as a sub-query whose vars are merged into stmt.Vars.
//   - other slices/arrays: a parenthesised list of the elements, or "(NULL)"
//     if empty; any other value (including nil) is bound as one placeholder.
func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) {
	for idx, v := range vars {
		if idx > 0 {
			writer.WriteByte(',')
		}

		switch v := v.(type) {
		case sql.NamedArg:
			stmt.Vars = append(stmt.Vars, v.Value)
		case clause.Column, clause.Table:
			stmt.QuoteTo(writer, v)
		case Valuer:
			reflectValue := reflect.ValueOf(v)
			if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() {
				// Render a typed nil pointer as nil instead of calling GormValue on it.
				stmt.AddVar(writer, nil)
			} else {
				stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB))
			}
		case clause.Expr:
			// NOTE(review): built into stmt (its SQL buffer), not into writer.
			v.Build(stmt)
		case *clause.Expr:
			v.Build(stmt)
		case driver.Valuer:
			stmt.Vars = append(stmt.Vars, v)
			stmt.DB.Dialector.BindVarTo(writer, stmt, v)
		case []byte:
			stmt.Vars = append(stmt.Vars, v)
			stmt.DB.Dialector.BindVarTo(writer, stmt, v)
		case []interface{}:
			if len(v) > 0 {
				writer.WriteByte('(')
				stmt.AddVar(writer, v...)
				writer.WriteByte(')')
			} else {
				writer.WriteString("(NULL)")
			}
		case *DB:
			// Render the sub-query as a dry run with logging discarded.
			subdb := v.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance()
			if v.Statement.SQL.Len() > 0 {
				// The sub-query already carries SQL text: replace the dialect
				// placeholder rendered for each of its vars with '?', then rebuild
				// it against the outer statement's vars (presumably so numbered
				// placeholders continue the outer numbering — confirm).
				var (
					vars = subdb.Statement.Vars
					sql  = v.Statement.SQL.String()
				)

				subdb.Statement.Vars = make([]interface{}, 0, len(vars))
				for _, vv := range vars {
					subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
					bindvar := strings.Builder{}
					v.Dialector.BindVarTo(&bindvar, subdb.Statement, vv)
					sql = strings.Replace(sql, bindvar.String(), "?", 1)
				}

				subdb.Statement.SQL.Reset()
				subdb.Statement.Vars = stmt.Vars
				if strings.Contains(sql, "@") {
					clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
				} else {
					clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
				}
			} else {
				// Build the sub-query through the query callbacks with the outer
				// vars placed in front of its own.
				subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
				subdb.callbacks.Query().Execute(subdb)
			}

			writer.WriteString(subdb.Statement.SQL.String())
			// Adopt the combined var list (outer vars followed by sub-query vars).
			stmt.Vars = subdb.Statement.Vars
		default:
			switch rv := reflect.ValueOf(v); rv.Kind() {
			case reflect.Slice, reflect.Array:
				if rv.Len() == 0 {
					writer.WriteString("(NULL)")
				} else {
					writer.WriteByte('(')
					for i := 0; i < rv.Len(); i++ {
						if i > 0 {
							writer.WriteByte(',')
						}
						stmt.AddVar(writer, rv.Index(i).Interface())
					}
					writer.WriteByte(')')
				}
			default:
				stmt.Vars = append(stmt.Vars, v)
				stmt.DB.Dialector.BindVarTo(writer, stmt, v)
			}
		}
	}
}
   252  
   253  // AddClause add clause
   254  func (stmt *Statement) AddClause(v clause.Interface) {
   255  	if optimizer, ok := v.(StatementModifier); ok {
   256  		optimizer.ModifyStatement(stmt)
   257  	} else {
   258  		name := v.Name()
   259  		c := stmt.Clauses[name]
   260  		c.Name = name
   261  		v.MergeClause(&c)
   262  		stmt.Clauses[name] = c
   263  	}
   264  }
   265  
   266  // AddClauseIfNotExists add clause if not exists
   267  func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) {
   268  	if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil {
   269  		stmt.AddClause(v)
   270  	}
   271  }
   272  
   273  // BuildCondition build condition
   274  func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression {
   275  	if s, ok := query.(string); ok {
   276  		// if it is a number, then treats it as primary key
   277  		if _, err := strconv.Atoi(s); err != nil {
   278  			if s == "" && len(args) == 0 {
   279  				return nil
   280  			}
   281  
   282  			if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) {
   283  				// looks like a where condition
   284  				return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
   285  			}
   286  
   287  			if len(args) > 0 && strings.Contains(s, "@") {
   288  				// looks like a named query
   289  				return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}}
   290  			}
   291  
   292  			if strings.Contains(strings.TrimSpace(s), " ") {
   293  				// looks like a where condition
   294  				return []clause.Expression{clause.Expr{SQL: s, Vars: args}}
   295  			}
   296  
   297  			if len(args) == 1 {
   298  				return []clause.Expression{clause.Eq{Column: s, Value: args[0]}}
   299  			}
   300  		}
   301  	}
   302  
   303  	conds := make([]clause.Expression, 0, 4)
   304  	args = append([]interface{}{query}, args...)
   305  	for idx, arg := range args {
   306  		if valuer, ok := arg.(driver.Valuer); ok {
   307  			arg, _ = valuer.Value()
   308  		}
   309  
   310  		switch v := arg.(type) {
   311  		case clause.Expression:
   312  			conds = append(conds, v)
   313  		case *DB:
   314  			if cs, ok := v.Statement.Clauses["WHERE"]; ok {
   315  				if where, ok := cs.Expression.(clause.Where); ok {
   316  					if len(where.Exprs) == 1 {
   317  						if orConds, ok := where.Exprs[0].(clause.OrConditions); ok {
   318  							where.Exprs[0] = clause.AndConditions(orConds)
   319  						}
   320  					}
   321  					conds = append(conds, clause.And(where.Exprs...))
   322  				} else if cs.Expression != nil {
   323  					conds = append(conds, cs.Expression)
   324  				}
   325  			}
   326  		case map[interface{}]interface{}:
   327  			for i, j := range v {
   328  				conds = append(conds, clause.Eq{Column: i, Value: j})
   329  			}
   330  		case map[string]string:
   331  			var keys = make([]string, 0, len(v))
   332  			for i := range v {
   333  				keys = append(keys, i)
   334  			}
   335  			sort.Strings(keys)
   336  
   337  			for _, key := range keys {
   338  				conds = append(conds, clause.Eq{Column: key, Value: v[key]})
   339  			}
   340  		case map[string]interface{}:
   341  			var keys = make([]string, 0, len(v))
   342  			for i := range v {
   343  				keys = append(keys, i)
   344  			}
   345  			sort.Strings(keys)
   346  
   347  			for _, key := range keys {
   348  				reflectValue := reflect.Indirect(reflect.ValueOf(v[key]))
   349  				switch reflectValue.Kind() {
   350  				case reflect.Slice, reflect.Array:
   351  					if _, ok := v[key].(driver.Valuer); ok {
   352  						conds = append(conds, clause.Eq{Column: key, Value: v[key]})
   353  					} else if _, ok := v[key].(Valuer); ok {
   354  						conds = append(conds, clause.Eq{Column: key, Value: v[key]})
   355  					} else {
   356  						// optimize reflect value length
   357  						valueLen := reflectValue.Len()
   358  						values := make([]interface{}, valueLen)
   359  						for i := 0; i < valueLen; i++ {
   360  							values[i] = reflectValue.Index(i).Interface()
   361  						}
   362  
   363  						conds = append(conds, clause.IN{Column: key, Values: values})
   364  					}
   365  				default:
   366  					conds = append(conds, clause.Eq{Column: key, Value: v[key]})
   367  				}
   368  			}
   369  		default:
   370  			reflectValue := reflect.Indirect(reflect.ValueOf(arg))
   371  			for reflectValue.Kind() == reflect.Ptr {
   372  				reflectValue = reflectValue.Elem()
   373  			}
   374  
   375  			if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil {
   376  				selectedColumns := map[string]bool{}
   377  				if idx == 0 {
   378  					for _, v := range args[1:] {
   379  						if vs, ok := v.(string); ok {
   380  							selectedColumns[vs] = true
   381  						}
   382  					}
   383  				}
   384  				restricted := len(selectedColumns) != 0
   385  
   386  				switch reflectValue.Kind() {
   387  				case reflect.Struct:
   388  					for _, field := range s.Fields {
   389  						selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
   390  						if selected || (!restricted && field.Readable) {
   391  							if v, isZero := field.ValueOf(reflectValue); !isZero || selected {
   392  								if field.DBName != "" {
   393  									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
   394  								} else if field.DataType != "" {
   395  									conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
   396  								}
   397  							}
   398  						}
   399  					}
   400  				case reflect.Slice, reflect.Array:
   401  					for i := 0; i < reflectValue.Len(); i++ {
   402  						for _, field := range s.Fields {
   403  							selected := selectedColumns[field.DBName] || selectedColumns[field.Name]
   404  							if selected || (!restricted && field.Readable) {
   405  								if v, isZero := field.ValueOf(reflectValue.Index(i)); !isZero || selected {
   406  									if field.DBName != "" {
   407  										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.DBName}, Value: v})
   408  									} else if field.DataType != "" {
   409  										conds = append(conds, clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: field.Name}, Value: v})
   410  									}
   411  								}
   412  							}
   413  						}
   414  					}
   415  				}
   416  
   417  				if restricted {
   418  					break
   419  				}
   420  			} else if !reflectValue.IsValid() {
   421  				stmt.AddError(ErrInvalidData)
   422  			} else if len(conds) == 0 {
   423  				if len(args) == 1 {
   424  					switch reflectValue.Kind() {
   425  					case reflect.Slice, reflect.Array:
   426  						// optimize reflect value length
   427  						valueLen := reflectValue.Len()
   428  						values := make([]interface{}, valueLen)
   429  						for i := 0; i < valueLen; i++ {
   430  							values[i] = reflectValue.Index(i).Interface()
   431  						}
   432  
   433  						if len(values) > 0 {
   434  							conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: values})
   435  						}
   436  						return conds
   437  					}
   438  				}
   439  
   440  				conds = append(conds, clause.IN{Column: clause.PrimaryColumn, Values: args})
   441  			}
   442  		}
   443  	}
   444  
   445  	return conds
   446  }
   447  
   448  // Build build sql with clauses names
   449  func (stmt *Statement) Build(clauses ...string) {
   450  	var firstClauseWritten bool
   451  
   452  	for _, name := range clauses {
   453  		if c, ok := stmt.Clauses[name]; ok {
   454  			if firstClauseWritten {
   455  				stmt.WriteByte(' ')
   456  			}
   457  
   458  			firstClauseWritten = true
   459  			if b, ok := stmt.DB.ClauseBuilders[name]; ok {
   460  				b(c, stmt)
   461  			} else {
   462  				c.Build(stmt)
   463  			}
   464  		}
   465  	}
   466  }
   467  
   468  func (stmt *Statement) Parse(value interface{}) (err error) {
   469  	return stmt.ParseWithSpecialTableName(value, "")
   470  }
   471  
   472  func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) {
   473  	if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" {
   474  		if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 {
   475  			stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)}
   476  			stmt.Table = tables[1]
   477  			return
   478  		}
   479  
   480  		stmt.Table = stmt.Schema.Table
   481  	}
   482  	return err
   483  }
   484  
   485  func (stmt *Statement) clone() *Statement {
   486  	newStmt := &Statement{
   487  		TableExpr:            stmt.TableExpr,
   488  		Table:                stmt.Table,
   489  		Model:                stmt.Model,
   490  		Unscoped:             stmt.Unscoped,
   491  		Dest:                 stmt.Dest,
   492  		ReflectValue:         stmt.ReflectValue,
   493  		Clauses:              map[string]clause.Clause{},
   494  		Distinct:             stmt.Distinct,
   495  		Selects:              stmt.Selects,
   496  		Omits:                stmt.Omits,
   497  		Preloads:             map[string][]interface{}{},
   498  		ConnPool:             stmt.ConnPool,
   499  		Schema:               stmt.Schema,
   500  		Context:              stmt.Context,
   501  		RaiseErrorOnNotFound: stmt.RaiseErrorOnNotFound,
   502  		SkipHooks:            stmt.SkipHooks,
   503  	}
   504  
   505  	if stmt.SQL.Len() > 0 {
   506  		newStmt.SQL.WriteString(stmt.SQL.String())
   507  		newStmt.Vars = make([]interface{}, 0, len(stmt.Vars))
   508  		newStmt.Vars = append(newStmt.Vars, stmt.Vars...)
   509  	}
   510  
   511  	for k, c := range stmt.Clauses {
   512  		newStmt.Clauses[k] = c
   513  	}
   514  
   515  	for k, p := range stmt.Preloads {
   516  		newStmt.Preloads[k] = p
   517  	}
   518  
   519  	if len(stmt.Joins) > 0 {
   520  		newStmt.Joins = make([]join, len(stmt.Joins))
   521  		copy(newStmt.Joins, stmt.Joins)
   522  	}
   523  
   524  	if len(stmt.scopes) > 0 {
   525  		newStmt.scopes = make([]func(*DB) *DB, len(stmt.scopes))
   526  		copy(newStmt.scopes, stmt.scopes)
   527  	}
   528  
   529  	stmt.Settings.Range(func(k, v interface{}) bool {
   530  		newStmt.Settings.Store(k, v)
   531  		return true
   532  	})
   533  
   534  	return newStmt
   535  }
   536  
// SetColumn sets the column/field name to value on the statement's
// destination and on the model value being processed.
//
//   stmt.SetColumn("Name", "jinzhu")       // Hooks Method
//   stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method
//
// Behaviour:
//   - Dest is a map[string]interface{}: sets Dest[name].
//   - Dest is a []map[string]interface{}: sets name in every map.
//   - Otherwise the field is resolved with Schema.LookUpField:
//     1. If Dest (with pointers removed) is not ReflectValue, the field is set
//        on Dest. A non-addressable Dest is first replaced by a pointer to an
//        addressable copy. A non-struct Dest records ErrInvalidData.
//     2. The field is then set on ReflectValue. For a slice/array it is set on
//        every element when any fromCallbacks argument is passed (its value
//        is not inspected), else only on the element at CurDestIndex. For a
//        struct, a non-addressable ReflectValue records ErrInvalidValue.
//   - A missing schema or unknown field records ErrInvalidField.
func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) {
	if v, ok := stmt.Dest.(map[string]interface{}); ok {
		v[name] = value
	} else if v, ok := stmt.Dest.([]map[string]interface{}); ok {
		for _, m := range v {
			m[name] = value
		}
	} else if stmt.Schema != nil {
		if field := stmt.Schema.LookUpField(name); field != nil {
			destValue := reflect.ValueOf(stmt.Dest)
			for destValue.Kind() == reflect.Ptr {
				destValue = destValue.Elem()
			}

			// NOTE(review): this compares the reflect.Value headers, not the
			// underlying values; go vet's reflectvaluecompare check flags this
			// pattern — confirm it identifies "same value" as intended.
			if stmt.ReflectValue != destValue {
				if !destValue.CanAddr() {
					// Make Dest addressable by swapping in a pointer to a copy.
					destValueCanAddr := reflect.New(destValue.Type())
					destValueCanAddr.Elem().Set(destValue)
					stmt.Dest = destValueCanAddr.Interface()
					destValue = destValueCanAddr.Elem()
				}

				switch destValue.Kind() {
				case reflect.Struct:
					field.Set(destValue, value)
				default:
					stmt.AddError(ErrInvalidData)
				}
			}

			switch stmt.ReflectValue.Kind() {
			case reflect.Slice, reflect.Array:
				if len(fromCallbacks) > 0 {
					for i := 0; i < stmt.ReflectValue.Len(); i++ {
						field.Set(stmt.ReflectValue.Index(i), value)
					}
				} else {
					// NOTE(review): panics if CurDestIndex is out of range.
					field.Set(stmt.ReflectValue.Index(stmt.CurDestIndex), value)
				}
			case reflect.Struct:
				if !stmt.ReflectValue.CanAddr() {
					stmt.AddError(ErrInvalidValue)
					return
				}

				field.Set(stmt.ReflectValue, value)
			}
		} else {
			stmt.AddError(ErrInvalidField)
		}
	} else {
		stmt.AddError(ErrInvalidField)
	}
}
   594  
   595  // Changed check model changed or not when updating
   596  func (stmt *Statement) Changed(fields ...string) bool {
   597  	modelValue := stmt.ReflectValue
   598  	switch modelValue.Kind() {
   599  	case reflect.Slice, reflect.Array:
   600  		modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex)
   601  	}
   602  
   603  	selectColumns, restricted := stmt.SelectAndOmitColumns(false, true)
   604  	changed := func(field *schema.Field) bool {
   605  		fieldValue, _ := field.ValueOf(modelValue)
   606  		if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
   607  			if v, ok := stmt.Dest.(map[string]interface{}); ok {
   608  				if fv, ok := v[field.Name]; ok {
   609  					return !utils.AssertEqual(fv, fieldValue)
   610  				} else if fv, ok := v[field.DBName]; ok {
   611  					return !utils.AssertEqual(fv, fieldValue)
   612  				}
   613  			} else {
   614  				destValue := reflect.ValueOf(stmt.Dest)
   615  				for destValue.Kind() == reflect.Ptr {
   616  					destValue = destValue.Elem()
   617  				}
   618  
   619  				changedValue, zero := field.ValueOf(destValue)
   620  				return !zero && !utils.AssertEqual(changedValue, fieldValue)
   621  			}
   622  		}
   623  		return false
   624  	}
   625  
   626  	if len(fields) == 0 {
   627  		for _, field := range stmt.Schema.FieldsByDBName {
   628  			if changed(field) {
   629  				return true
   630  			}
   631  		}
   632  	} else {
   633  		for _, name := range fields {
   634  			if field := stmt.Schema.LookUpField(name); field != nil {
   635  				if changed(field) {
   636  					return true
   637  				}
   638  			}
   639  		}
   640  	}
   641  
   642  	return false
   643  }
   644  
   645  var nameMatcher = regexp.MustCompile(`^[\W]?(?:[a-z_]+?)[\W]?\.[\W]?([a-z_]+?)[\W]?$`)
   646  
   647  // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false
   648  func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) {
   649  	results := map[string]bool{}
   650  	notRestricted := false
   651  
   652  	// select columns
   653  	for _, column := range stmt.Selects {
   654  		if stmt.Schema == nil {
   655  			results[column] = true
   656  		} else if column == "*" {
   657  			notRestricted = true
   658  			for _, dbName := range stmt.Schema.DBNames {
   659  				results[dbName] = true
   660  			}
   661  		} else if column == clause.Associations {
   662  			for _, rel := range stmt.Schema.Relationships.Relations {
   663  				results[rel.Name] = true
   664  			}
   665  		} else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" {
   666  			results[field.DBName] = true
   667  		} else if matches := nameMatcher.FindStringSubmatch(column); len(matches) == 2 {
   668  			results[matches[1]] = true
   669  		} else {
   670  			results[column] = true
   671  		}
   672  	}
   673  
   674  	// omit columns
   675  	for _, omit := range stmt.Omits {
   676  		if stmt.Schema == nil {
   677  			results[omit] = false
   678  		} else if omit == "*" {
   679  			for _, dbName := range stmt.Schema.DBNames {
   680  				results[dbName] = false
   681  			}
   682  		} else if omit == clause.Associations {
   683  			for _, rel := range stmt.Schema.Relationships.Relations {
   684  				results[rel.Name] = false
   685  			}
   686  		} else if field := stmt.Schema.LookUpField(omit); field != nil && field.DBName != "" {
   687  			results[field.DBName] = false
   688  		} else if matches := nameMatcher.FindStringSubmatch(omit); len(matches) == 2 {
   689  			results[matches[1]] = false
   690  		} else {
   691  			results[omit] = false
   692  		}
   693  	}
   694  
   695  	if stmt.Schema != nil {
   696  		for _, field := range stmt.Schema.FieldsByName {
   697  			name := field.DBName
   698  			if name == "" {
   699  				name = field.Name
   700  			}
   701  
   702  			if requireCreate && !field.Creatable {
   703  				results[name] = false
   704  			} else if requireUpdate && !field.Updatable {
   705  				results[name] = false
   706  			}
   707  		}
   708  	}
   709  
   710  	return results, !notRestricted && len(stmt.Selects) > 0
   711  }