github.com/systematiccaos/gorm@v1.22.6/clause/expression.go (about)

     1  package clause
     2  
     3  import (
     4  	"database/sql"
     5  	"database/sql/driver"
     6  	"go/ast"
     7  	"reflect"
     8  )
     9  
    10  // Expression expression interface
    11  type Expression interface {
    12  	Build(builder Builder)
    13  }
    14  
    15  // NegationExpressionBuilder negation expression builder
    16  type NegationExpressionBuilder interface {
    17  	NegationBuild(builder Builder)
    18  }
    19  
    20  // Expr raw expression
    21  type Expr struct {
    22  	SQL                string
    23  	Vars               []interface{}
    24  	WithoutParentheses bool
    25  }
    26  
    27  // Build build raw expression
    28  func (expr Expr) Build(builder Builder) {
    29  	var (
    30  		afterParenthesis bool
    31  		idx              int
    32  	)
    33  
    34  	for _, v := range []byte(expr.SQL) {
    35  		if v == '?' && len(expr.Vars) > idx {
    36  			if afterParenthesis || expr.WithoutParentheses {
    37  				if _, ok := expr.Vars[idx].(driver.Valuer); ok {
    38  					builder.AddVar(builder, expr.Vars[idx])
    39  				} else {
    40  					switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
    41  					case reflect.Slice, reflect.Array:
    42  						if rv.Len() == 0 {
    43  							builder.AddVar(builder, nil)
    44  						} else {
    45  							for i := 0; i < rv.Len(); i++ {
    46  								if i > 0 {
    47  									builder.WriteByte(',')
    48  								}
    49  								builder.AddVar(builder, rv.Index(i).Interface())
    50  							}
    51  						}
    52  					default:
    53  						builder.AddVar(builder, expr.Vars[idx])
    54  					}
    55  				}
    56  			} else {
    57  				builder.AddVar(builder, expr.Vars[idx])
    58  			}
    59  
    60  			idx++
    61  		} else {
    62  			if v == '(' {
    63  				afterParenthesis = true
    64  			} else {
    65  				afterParenthesis = false
    66  			}
    67  			builder.WriteByte(v)
    68  		}
    69  	}
    70  
    71  	if idx < len(expr.Vars) {
    72  		for _, v := range expr.Vars[idx:] {
    73  			builder.AddVar(builder, sql.NamedArg{Value: v})
    74  		}
    75  	}
    76  }
    77  
    78  // NamedExpr raw expression for named expr
    79  type NamedExpr struct {
    80  	SQL  string
    81  	Vars []interface{}
    82  }
    83  
    84  // Build build raw expression
    85  func (expr NamedExpr) Build(builder Builder) {
    86  	var (
    87  		idx              int
    88  		inName           bool
    89  		afterParenthesis bool
    90  		namedMap         = make(map[string]interface{}, len(expr.Vars))
    91  	)
    92  
    93  	for _, v := range expr.Vars {
    94  		switch value := v.(type) {
    95  		case sql.NamedArg:
    96  			namedMap[value.Name] = value.Value
    97  		case map[string]interface{}:
    98  			for k, v := range value {
    99  				namedMap[k] = v
   100  			}
   101  		default:
   102  			var appendFieldsToMap func(reflect.Value)
   103  			appendFieldsToMap = func(reflectValue reflect.Value) {
   104  				reflectValue = reflect.Indirect(reflectValue)
   105  				switch reflectValue.Kind() {
   106  				case reflect.Struct:
   107  					modelType := reflectValue.Type()
   108  					for i := 0; i < modelType.NumField(); i++ {
   109  						if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) {
   110  							namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface()
   111  
   112  							if fieldStruct.Anonymous {
   113  								appendFieldsToMap(reflectValue.Field(i))
   114  							}
   115  						}
   116  					}
   117  				}
   118  			}
   119  
   120  			appendFieldsToMap(reflect.ValueOf(value))
   121  		}
   122  	}
   123  
   124  	name := make([]byte, 0, 10)
   125  
   126  	for _, v := range []byte(expr.SQL) {
   127  		if v == '@' && !inName {
   128  			inName = true
   129  			name = []byte{}
   130  		} else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\n' || v == ';' {
   131  			if inName {
   132  				if nv, ok := namedMap[string(name)]; ok {
   133  					builder.AddVar(builder, nv)
   134  				} else {
   135  					builder.WriteByte('@')
   136  					builder.WriteString(string(name))
   137  				}
   138  				inName = false
   139  			}
   140  
   141  			afterParenthesis = false
   142  			builder.WriteByte(v)
   143  		} else if v == '?' && len(expr.Vars) > idx {
   144  			if afterParenthesis {
   145  				if _, ok := expr.Vars[idx].(driver.Valuer); ok {
   146  					builder.AddVar(builder, expr.Vars[idx])
   147  				} else {
   148  					switch rv := reflect.ValueOf(expr.Vars[idx]); rv.Kind() {
   149  					case reflect.Slice, reflect.Array:
   150  						if rv.Len() == 0 {
   151  							builder.AddVar(builder, nil)
   152  						} else {
   153  							for i := 0; i < rv.Len(); i++ {
   154  								if i > 0 {
   155  									builder.WriteByte(',')
   156  								}
   157  								builder.AddVar(builder, rv.Index(i).Interface())
   158  							}
   159  						}
   160  					default:
   161  						builder.AddVar(builder, expr.Vars[idx])
   162  					}
   163  				}
   164  			} else {
   165  				builder.AddVar(builder, expr.Vars[idx])
   166  			}
   167  
   168  			idx++
   169  		} else if inName {
   170  			name = append(name, v)
   171  		} else {
   172  			if v == '(' {
   173  				afterParenthesis = true
   174  			} else {
   175  				afterParenthesis = false
   176  			}
   177  			builder.WriteByte(v)
   178  		}
   179  	}
   180  
   181  	if inName {
   182  		if nv, ok := namedMap[string(name)]; ok {
   183  			builder.AddVar(builder, nv)
   184  		} else {
   185  			builder.WriteByte('@')
   186  			builder.WriteString(string(name))
   187  		}
   188  	}
   189  }
   190  
   191  // IN Whether a value is within a set of values
   192  type IN struct {
   193  	Column interface{}
   194  	Values []interface{}
   195  }
   196  
   197  func (in IN) Build(builder Builder) {
   198  	builder.WriteQuoted(in.Column)
   199  
   200  	switch len(in.Values) {
   201  	case 0:
   202  		builder.WriteString(" IN (NULL)")
   203  	case 1:
   204  		if _, ok := in.Values[0].([]interface{}); !ok {
   205  			builder.WriteString(" = ")
   206  			builder.AddVar(builder, in.Values[0])
   207  			break
   208  		}
   209  
   210  		fallthrough
   211  	default:
   212  		builder.WriteString(" IN (")
   213  		builder.AddVar(builder, in.Values...)
   214  		builder.WriteByte(')')
   215  	}
   216  }
   217  
   218  func (in IN) NegationBuild(builder Builder) {
   219  	builder.WriteQuoted(in.Column)
   220  	switch len(in.Values) {
   221  	case 0:
   222  		builder.WriteString(" IS NOT NULL")
   223  	case 1:
   224  		if _, ok := in.Values[0].([]interface{}); !ok {
   225  			builder.WriteString(" <> ")
   226  			builder.AddVar(builder, in.Values[0])
   227  			break
   228  		}
   229  
   230  		fallthrough
   231  	default:
   232  		builder.WriteString(" NOT IN (")
   233  		builder.AddVar(builder, in.Values...)
   234  		builder.WriteByte(')')
   235  	}
   236  }
   237  
   238  // Eq equal to for where
   239  type Eq struct {
   240  	Column interface{}
   241  	Value  interface{}
   242  }
   243  
   244  func (eq Eq) Build(builder Builder) {
   245  	builder.WriteQuoted(eq.Column)
   246  
   247  	switch eq.Value.(type) {
   248  	case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
   249  		builder.WriteString(" IN (")
   250  		rv := reflect.ValueOf(eq.Value)
   251  		for i := 0; i < rv.Len(); i++ {
   252  			if i > 0 {
   253  				builder.WriteByte(',')
   254  			}
   255  			builder.AddVar(builder, rv.Index(i).Interface())
   256  		}
   257  		builder.WriteByte(')')
   258  	default:
   259  		if eqNil(eq.Value) {
   260  			builder.WriteString(" IS NULL")
   261  		} else {
   262  			builder.WriteString(" = ")
   263  			builder.AddVar(builder, eq.Value)
   264  		}
   265  	}
   266  }
   267  
   268  func (eq Eq) NegationBuild(builder Builder) {
   269  	Neq(eq).Build(builder)
   270  }
   271  
   272  // Neq not equal to for where
   273  type Neq Eq
   274  
   275  func (neq Neq) Build(builder Builder) {
   276  	builder.WriteQuoted(neq.Column)
   277  
   278  	switch neq.Value.(type) {
   279  	case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}:
   280  		builder.WriteString(" NOT IN (")
   281  		rv := reflect.ValueOf(neq.Value)
   282  		for i := 0; i < rv.Len(); i++ {
   283  			if i > 0 {
   284  				builder.WriteByte(',')
   285  			}
   286  			builder.AddVar(builder, rv.Index(i).Interface())
   287  		}
   288  		builder.WriteByte(')')
   289  	default:
   290  		if eqNil(neq.Value) {
   291  			builder.WriteString(" IS NOT NULL")
   292  		} else {
   293  			builder.WriteString(" <> ")
   294  			builder.AddVar(builder, neq.Value)
   295  		}
   296  	}
   297  }
   298  
   299  func (neq Neq) NegationBuild(builder Builder) {
   300  	Eq(neq).Build(builder)
   301  }
   302  
   303  // Gt greater than for where
   304  type Gt Eq
   305  
   306  func (gt Gt) Build(builder Builder) {
   307  	builder.WriteQuoted(gt.Column)
   308  	builder.WriteString(" > ")
   309  	builder.AddVar(builder, gt.Value)
   310  }
   311  
   312  func (gt Gt) NegationBuild(builder Builder) {
   313  	Lte(gt).Build(builder)
   314  }
   315  
   316  // Gte greater than or equal to for where
   317  type Gte Eq
   318  
   319  func (gte Gte) Build(builder Builder) {
   320  	builder.WriteQuoted(gte.Column)
   321  	builder.WriteString(" >= ")
   322  	builder.AddVar(builder, gte.Value)
   323  }
   324  
   325  func (gte Gte) NegationBuild(builder Builder) {
   326  	Lt(gte).Build(builder)
   327  }
   328  
   329  // Lt less than for where
   330  type Lt Eq
   331  
   332  func (lt Lt) Build(builder Builder) {
   333  	builder.WriteQuoted(lt.Column)
   334  	builder.WriteString(" < ")
   335  	builder.AddVar(builder, lt.Value)
   336  }
   337  
   338  func (lt Lt) NegationBuild(builder Builder) {
   339  	Gte(lt).Build(builder)
   340  }
   341  
   342  // Lte less than or equal to for where
   343  type Lte Eq
   344  
   345  func (lte Lte) Build(builder Builder) {
   346  	builder.WriteQuoted(lte.Column)
   347  	builder.WriteString(" <= ")
   348  	builder.AddVar(builder, lte.Value)
   349  }
   350  
   351  func (lte Lte) NegationBuild(builder Builder) {
   352  	Gt(lte).Build(builder)
   353  }
   354  
   355  // Like whether string matches regular expression
   356  type Like Eq
   357  
   358  func (like Like) Build(builder Builder) {
   359  	builder.WriteQuoted(like.Column)
   360  	builder.WriteString(" LIKE ")
   361  	builder.AddVar(builder, like.Value)
   362  }
   363  
   364  func (like Like) NegationBuild(builder Builder) {
   365  	builder.WriteQuoted(like.Column)
   366  	builder.WriteString(" NOT LIKE ")
   367  	builder.AddVar(builder, like.Value)
   368  }
   369  
   370  func eqNil(value interface{}) bool {
   371  	if valuer, ok := value.(driver.Valuer); ok && !eqNilReflect(valuer) {
   372  		value, _ = valuer.Value()
   373  	}
   374  
   375  	return value == nil || eqNilReflect(value)
   376  }
   377  
   378  func eqNilReflect(value interface{}) bool {
   379  	reflectValue := reflect.ValueOf(value)
   380  	return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil()
   381  }