github.com/kunlun-qilian/sqlx/v3@v3.0.0/builder/expr.go (about)

     1  package builder
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql/driver"
     7  	"fmt"
     8  	"reflect"
     9  	"strings"
    10  
    11  	reflectx "github.com/go-courier/x/reflect"
    12  )
    13  
    14  type SqlExpr interface {
    15  	IsNil() bool
    16  	Ex(ctx context.Context) *Ex
    17  }
    18  
    19  func IsNilExpr(e SqlExpr) bool {
    20  	return e == nil || e.IsNil()
    21  }
    22  
    23  func RangeNotNilExpr(exprs []SqlExpr, each func(e SqlExpr, i int)) {
    24  	count := 0
    25  
    26  	for i := range exprs {
    27  		e := exprs[i]
    28  		if IsNilExpr(e) {
    29  			continue
    30  		}
    31  		each(e, count)
    32  		count++
    33  	}
    34  }
    35  
    36  func ExactlyExpr(query string, args ...interface{}) *Ex {
    37  	if query != "" {
    38  		return &Ex{b: *bytes.NewBufferString(query), args: args, exactly: true}
    39  	}
    40  	return &Ex{args: args, exactly: true}
    41  }
    42  
    43  func Expr(query string, args ...interface{}) *Ex {
    44  	if query != "" {
    45  		return &Ex{b: *bytes.NewBufferString(query), args: args}
    46  	}
    47  	return &Ex{args: args}
    48  }
    49  
    50  func ResolveExpr(v interface{}) *Ex {
    51  	return ResolveExprContext(context.Background(), v)
    52  }
    53  
    54  func ResolveExprContext(ctx context.Context, v interface{}) *Ex {
    55  	switch e := v.(type) {
    56  	case nil:
    57  		return nil
    58  	case SqlExpr:
    59  		if IsNilExpr(e) {
    60  			return nil
    61  		}
    62  		return e.Ex(ctx)
    63  	}
    64  	return nil
    65  }
    66  
    67  func Multi(exprs ...SqlExpr) SqlExpr {
    68  	return MultiWith(" ", exprs...)
    69  }
    70  
    71  func MultiWith(connector string, exprs ...SqlExpr) SqlExpr {
    72  	return ExprBy(func(ctx context.Context) *Ex {
    73  		e := Expr("")
    74  		e.Grow(len(exprs))
    75  
    76  		for i := range exprs {
    77  			if i != 0 {
    78  				e.WriteQuery(connector)
    79  			}
    80  			e.WriteExpr(exprs[i])
    81  		}
    82  		return e.Ex(ctx)
    83  	})
    84  }
    85  
    86  func ExprBy(build func(ctx context.Context) *Ex) SqlExpr {
    87  	return &exBy{build: build}
    88  }
    89  
    90  type exBy struct {
    91  	build func(ctx context.Context) *Ex
    92  }
    93  
    94  func (c *exBy) IsNil() bool {
    95  	return c == nil || c.build == nil
    96  }
    97  
    98  func (c *exBy) Ex(ctx context.Context) *Ex {
    99  	return c.build(ctx)
   100  }
   101  
   102  type Ex struct {
   103  	b       bytes.Buffer
   104  	args    []interface{}
   105  	err     error
   106  	exactly bool
   107  }
   108  
   109  func (e *Ex) IsNil() bool {
   110  	return e == nil || e.b.Len() == 0
   111  }
   112  
   113  func (e *Ex) Query() string {
   114  	if e == nil {
   115  		return ""
   116  	}
   117  	return e.b.String()
   118  }
   119  
   120  func (e *Ex) Args() []interface{} {
   121  	if len(e.args) == 0 {
   122  		return nil
   123  	}
   124  	return e.args
   125  }
   126  
   127  func (e *Ex) Err() error {
   128  	return e.err
   129  }
   130  
   131  func (e *Ex) AppendArgs(args ...interface{}) {
   132  	e.args = append(e.args, args...)
   133  }
   134  
   135  func (e *Ex) ArgsLen() int {
   136  	return len(e.args)
   137  }
   138  
   139  func (e *Ex) WriteString(s string) (int, error) {
   140  	return e.b.WriteString(s)
   141  }
   142  
   143  func (e *Ex) WriteByte(b byte) error {
   144  	return e.b.WriteByte(b)
   145  }
   146  
   147  func (e *Ex) QueryGrow(n int) {
   148  	e.b.Grow(n)
   149  }
   150  
   151  func (e *Ex) Grow(n int) {
   152  	if n > 0 && cap(e.args)-len(e.args) < n {
   153  		args := make([]interface{}, len(e.args), 2*cap(e.args)+n)
   154  		copy(args, e.args)
   155  		e.args = args
   156  	}
   157  }
   158  
   159  func (e *Ex) WriteQuery(s string) {
   160  	_, _ = e.b.WriteString(s)
   161  }
   162  
   163  func (e *Ex) WriteQueryByte(b byte) {
   164  	_ = e.b.WriteByte(b)
   165  }
   166  
   167  func (e *Ex) WriteGroup(fn func(e *Ex)) {
   168  	e.WriteQueryByte('(')
   169  	fn(e)
   170  	e.WriteQueryByte(')')
   171  }
   172  
   173  func (e *Ex) WhiteComments(comments []byte) {
   174  	_, _ = e.b.WriteString("/* ")
   175  	_, _ = e.b.Write(comments)
   176  	_, _ = e.b.WriteString(" */")
   177  }
   178  
   179  func (e *Ex) WriteExpr(expr SqlExpr) {
   180  	if IsNilExpr(expr) {
   181  		return
   182  	}
   183  
   184  	e.WriteHolder(0)
   185  	e.AppendArgs(expr)
   186  }
   187  
   188  func (e *Ex) WriteEnd() {
   189  	e.WriteQueryByte(';')
   190  }
   191  
   192  func (e *Ex) WriteHolder(idx int) {
   193  	if idx > 0 {
   194  		e.b.WriteByte(',')
   195  	}
   196  	e.b.WriteByte('?')
   197  }
   198  
   199  func (e *Ex) SetExactly(exactly bool) {
   200  	e.exactly = exactly
   201  }
   202  
   203  func (e *Ex) Ex(ctx context.Context) *Ex {
   204  	if e.IsNil() {
   205  		return nil
   206  	}
   207  
   208  	args, n := e.args, len(e.args)
   209  
   210  	eb := Expr("")
   211  	eb.Grow(n)
   212  
   213  	query := e.Query()
   214  
   215  	if e.exactly {
   216  		eb.WriteQuery(query)
   217  		eb.AppendArgs(args...)
   218  		eb.exactly = true
   219  		return eb
   220  	}
   221  
   222  	shouldResolveArgs := preprocessArgs(args)
   223  
   224  	if !shouldResolveArgs {
   225  		eb.WriteQuery(query)
   226  		eb.AppendArgs(args...)
   227  		eb.SetExactly(true)
   228  		return eb
   229  	}
   230  
   231  	argIndex := 0
   232  
   233  	for i := range query {
   234  		switch c := query[i]; c {
   235  		case '?':
   236  			if argIndex >= n {
   237  				panic(fmt.Errorf("missing arg %d of %s", argIndex, query))
   238  			}
   239  
   240  			switch arg := args[argIndex].(type) {
   241  			case SqlExpr:
   242  				if !IsNilExpr(arg) {
   243  					subExpr := arg.Ex(ctx)
   244  
   245  					if subExpr != eb && !IsNilExpr(subExpr) {
   246  						eb.WriteQuery(subExpr.Query())
   247  						eb.AppendArgs(subExpr.Args()...)
   248  					}
   249  				}
   250  			default:
   251  				eb.WriteHolder(0)
   252  				eb.AppendArgs(arg)
   253  			}
   254  			argIndex++
   255  		default:
   256  			eb.WriteQueryByte(c)
   257  		}
   258  	}
   259  
   260  	eb.SetExactly(true)
   261  
   262  	return eb
   263  }
   264  
   265  func exactlyExprFromSlice(values []interface{}) *Ex {
   266  	if n := len(values); n > 0 {
   267  		return ExactlyExpr(strings.Repeat(",?", n)[1:], values...)
   268  	}
   269  	return ExactlyExpr("")
   270  }
   271  
   272  func preprocessArgs(args []interface{}) bool {
   273  	shouldResolve := false
   274  
   275  	for i := range args {
   276  		switch arg := args[i].(type) {
   277  		case ValuerExpr:
   278  			args[i] = ExactlyExpr(arg.ValueEx(), arg)
   279  			shouldResolve = true
   280  		case SqlExpr:
   281  			shouldResolve = true
   282  		case driver.Valuer:
   283  
   284  		case []interface{}:
   285  			args[i] = exactlyExprFromSlice(arg)
   286  			shouldResolve = true
   287  		default:
   288  			if typ := reflect.TypeOf(arg); typ.Kind() == reflect.Slice {
   289  				if !reflectx.IsBytes(typ) {
   290  					args[i] = exactlyExprFromSlice(toInterfaceSlice(arg))
   291  					shouldResolve = true
   292  				}
   293  			}
   294  		}
   295  	}
   296  
   297  	return shouldResolve
   298  }
   299  
   300  func toInterfaceSlice(arg interface{}) []interface{} {
   301  	switch x := (arg).(type) {
   302  	case []bool:
   303  		values := make([]interface{}, len(x))
   304  		for i := range values {
   305  			values[i] = x[i]
   306  		}
   307  		return values
   308  	case []string:
   309  		values := make([]interface{}, len(x))
   310  		for i := range values {
   311  			values[i] = x[i]
   312  		}
   313  		return values
   314  	case []float32:
   315  		values := make([]interface{}, len(x))
   316  		for i := range values {
   317  			values[i] = x[i]
   318  		}
   319  		return values
   320  	case []float64:
   321  		values := make([]interface{}, len(x))
   322  		for i := range values {
   323  			values[i] = x[i]
   324  		}
   325  		return values
   326  	case []int:
   327  		values := make([]interface{}, len(x))
   328  		for i := range values {
   329  			values[i] = x[i]
   330  		}
   331  		return values
   332  	case []int8:
   333  		values := make([]interface{}, len(x))
   334  		for i := range values {
   335  			values[i] = x[i]
   336  		}
   337  		return values
   338  	case []int16:
   339  		values := make([]interface{}, len(x))
   340  		for i := range values {
   341  			values[i] = x[i]
   342  		}
   343  		return values
   344  	case []int32:
   345  		values := make([]interface{}, len(x))
   346  		for i := range values {
   347  			values[i] = x[i]
   348  		}
   349  		return values
   350  	case []int64:
   351  		values := make([]interface{}, len(x))
   352  		for i := range values {
   353  			values[i] = x[i]
   354  		}
   355  		return values
   356  	case []uint:
   357  		values := make([]interface{}, len(x))
   358  		for i := range values {
   359  			values[i] = x[i]
   360  		}
   361  		return values
   362  	case []uint8:
   363  		values := make([]interface{}, len(x))
   364  		for i := range values {
   365  			values[i] = x[i]
   366  		}
   367  		return values
   368  	case []uint16:
   369  		values := make([]interface{}, len(x))
   370  		for i := range values {
   371  			values[i] = x[i]
   372  		}
   373  		return values
   374  	case []uint32:
   375  		values := make([]interface{}, len(x))
   376  		for i := range values {
   377  			values[i] = x[i]
   378  		}
   379  		return values
   380  	case []uint64:
   381  		values := make([]interface{}, len(x))
   382  		for i := range values {
   383  			values[i] = x[i]
   384  		}
   385  		return values
   386  	case []interface{}:
   387  		return x
   388  	}
   389  	sliceRv := reflect.ValueOf(arg)
   390  	values := make([]interface{}, sliceRv.Len())
   391  	for i := range values {
   392  		values[i] = sliceRv.Index(i).Interface()
   393  	}
   394  	return values
   395  }