github.com/octohelm/storage@v0.0.0-20240516030302-1ac2cc1ea347/pkg/sqlbuilder/expr.go (about)

     1  package sqlbuilder
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"database/sql"
     7  	"database/sql/driver"
     8  	"fmt"
     9  	"reflect"
    10  	"strings"
    11  	"text/scanner"
    12  
    13  	reflectx "github.com/octohelm/x/reflect"
    14  )
    15  
    16  func IsNilExpr(e SqlExpr) bool {
    17  	return e == nil || e.IsNil()
    18  }
    19  
    20  func RangeNotNilExpr(exprs []SqlExpr, each func(e SqlExpr, i int)) {
    21  	count := 0
    22  
    23  	for i := range exprs {
    24  		e := exprs[i]
    25  		if IsNilExpr(e) {
    26  			continue
    27  		}
    28  		each(e, count)
    29  		count++
    30  	}
    31  }
    32  
    33  func ExactlyExpr(query string, args ...any) *Ex {
    34  	if query != "" {
    35  		return &Ex{b: *bytes.NewBufferString(query), args: args, exactly: true}
    36  	}
    37  	return &Ex{args: args, exactly: true}
    38  }
    39  
    40  func Expr(query string, args ...any) *Ex {
    41  	if query != "" {
    42  		return &Ex{b: *bytes.NewBufferString(query), args: args}
    43  	}
    44  	return &Ex{args: args}
    45  }
    46  
    47  func ResolveExpr(v any) *Ex {
    48  	return ResolveExprContext(context.Background(), v)
    49  }
    50  
    51  func ResolveExprContext(ctx context.Context, v any) *Ex {
    52  	switch e := v.(type) {
    53  	case nil:
    54  		return nil
    55  	case SqlExpr:
    56  		if IsNilExpr(e) {
    57  			return nil
    58  		}
    59  		return e.Ex(ctx)
    60  	}
    61  	return nil
    62  }
    63  
    64  func Multi(exprs ...SqlExpr) SqlExpr {
    65  	return MultiWith(" ", exprs...)
    66  }
    67  
    68  func MultiWith(connector string, exprs ...SqlExpr) SqlExpr {
    69  	return ExprBy(func(ctx context.Context) *Ex {
    70  		e := Expr("")
    71  		e.Grow(len(exprs))
    72  
    73  		for i := range exprs {
    74  			if i != 0 {
    75  				e.WriteQuery(connector)
    76  			}
    77  			e.WriteExpr(exprs[i])
    78  		}
    79  		return e.Ex(ctx)
    80  	})
    81  }
    82  
    83  func ExprBy(build func(ctx context.Context) *Ex) SqlExpr {
    84  	return &exBy{build: build}
    85  }
    86  
    87  type exBy struct {
    88  	build func(ctx context.Context) *Ex
    89  }
    90  
    91  func (c *exBy) IsNil() bool {
    92  	return c == nil || c.build == nil
    93  }
    94  
    95  func (c *exBy) Ex(ctx context.Context) *Ex {
    96  	return c.build(ctx)
    97  }
    98  
    99  type SqlExpr interface {
   100  	IsNil() bool
   101  	Ex(ctx context.Context) *Ex
   102  }
   103  
   104  // ValuerExpr
   105  // replace ? as some query snippet
   106  //
   107  // examples:
   108  // ? => ST_GeomFromText(?)
   109  type ValuerExpr interface {
   110  	ValueEx() string
   111  }
   112  
// Ex is a SQL expression under construction.
//
// Fields:
//   - b holds the query text, which may contain `?` and `@name` placeholders.
//   - args holds the values (possibly nested SqlExprs, NamedArg or
//     NamedArgSet entries) bound to those placeholders.
//   - err is returned by Err; nothing visible here assigns it.
//     NOTE(review): confirm where it is set.
//   - exactly marks the query/args as already resolved, so Ex(ctx) copies
//     them through instead of rewriting placeholders.
type Ex struct {
	b       bytes.Buffer
	args    []any
	err     error
	exactly bool
}
   119  
   120  func (e *Ex) IsNil() bool {
   121  	return e == nil || e.b.Len() == 0
   122  }
   123  
   124  func (e *Ex) Query() string {
   125  	if e == nil {
   126  		return ""
   127  	}
   128  	return e.b.String()
   129  }
   130  
   131  func (e *Ex) Args() []any {
   132  	if len(e.args) == 0 {
   133  		return nil
   134  	}
   135  	return e.args
   136  }
   137  
   138  func (e *Ex) Err() error {
   139  	return e.err
   140  }
   141  
   142  func (e *Ex) AppendArgs(args ...any) {
   143  	e.args = append(e.args, args...)
   144  }
   145  
   146  func (e *Ex) ArgsLen() int {
   147  	return len(e.args)
   148  }
   149  
   150  func (e *Ex) WriteString(s string) (int, error) {
   151  	return e.b.WriteString(s)
   152  }
   153  
   154  func (e *Ex) WriteByte(b byte) error {
   155  	return e.b.WriteByte(b)
   156  }
   157  
   158  func (e *Ex) QueryGrow(n int) {
   159  	e.b.Grow(n)
   160  }
   161  
   162  func (e *Ex) Grow(n int) {
   163  	if n > 0 && cap(e.args)-len(e.args) < n {
   164  		args := make([]any, len(e.args), 2*cap(e.args)+n)
   165  		copy(args, e.args)
   166  		e.args = args
   167  	}
   168  }
   169  
   170  func (e *Ex) WriteQuery(s string) {
   171  	_, _ = e.b.WriteString(s)
   172  }
   173  
   174  func (e *Ex) WriteQueryByte(b byte) {
   175  	_ = e.b.WriteByte(b)
   176  }
   177  
   178  func (e *Ex) WriteGroup(fn func(e *Ex)) {
   179  	e.WriteQueryByte('(')
   180  	fn(e)
   181  	e.WriteQueryByte(')')
   182  }
   183  
   184  func (e *Ex) WhiteComments(comments []byte) {
   185  	_, _ = e.b.WriteString("/* ")
   186  	_, _ = e.b.Write(comments)
   187  	_, _ = e.b.WriteString(" */")
   188  }
   189  
   190  func (e *Ex) WriteExpr(expr SqlExpr) {
   191  	if IsNilExpr(expr) {
   192  		return
   193  	}
   194  
   195  	e.WriteHolder(0)
   196  	e.AppendArgs(expr)
   197  }
   198  
   199  func (e *Ex) WriteEnd() {
   200  	e.WriteQueryByte(';')
   201  }
   202  
   203  func (e *Ex) WriteHolder(idx int) {
   204  	if idx > 0 {
   205  		e.b.WriteByte(',')
   206  	}
   207  	e.b.WriteByte('?')
   208  }
   209  
   210  func (e *Ex) SetExactly(exactly bool) {
   211  	e.exactly = exactly
   212  }
   213  
// NamedArg is an alias of sql.NamedArg. When passed as an arg, its Value is
// substituted for `@Name` in the query when the expression is resolved by Ex.
type NamedArg = sql.NamedArg

// NamedArgSet maps names to values. When passed as an arg, each entry is
// substituted for `@name` in the query when the expression is resolved by Ex.
type NamedArgSet map[string]any
   217  
   218  func (e *Ex) Ex(ctx context.Context) *Ex {
   219  	if e.IsNil() {
   220  		return nil
   221  	}
   222  
   223  	allArgs, n := e.args, len(e.args)
   224  
   225  	eb := Expr("")
   226  	eb.Grow(n)
   227  
   228  	query := e.Query()
   229  
   230  	if e.exactly {
   231  		eb.WriteQuery(query)
   232  		eb.AppendArgs(allArgs...)
   233  		eb.exactly = true
   234  		return eb
   235  	}
   236  
   237  	namedArgSet, args, shouldResolveArgs := preprocessArgs(allArgs)
   238  
   239  	if !shouldResolveArgs {
   240  		eb.WriteQuery(query)
   241  		eb.AppendArgs(args...)
   242  		eb.SetExactly(true)
   243  		return eb
   244  	}
   245  
   246  	argIndex := 0
   247  
   248  	s := &scanner.Scanner{}
   249  	s.Init(bytes.NewBuffer([]byte(query)))
   250  	s.Error = func(s *scanner.Scanner, msg string) {}
   251  
   252  	for c := s.Next(); c != scanner.EOF; c = s.Next() {
   253  		switch c {
   254  		case '@':
   255  			named := bytes.NewBuffer(nil)
   256  
   257  			for {
   258  				c = s.Next()
   259  
   260  				if c == scanner.EOF {
   261  					break
   262  				}
   263  
   264  				if (c >= 'A' && c <= 'Z') ||
   265  					(c >= 'a' && c <= 'z') ||
   266  					(c >= '0' && c <= '9') ||
   267  					c == '_' {
   268  
   269  					named.WriteRune(c)
   270  					continue
   271  				}
   272  				break
   273  			}
   274  
   275  			if named.Len() > 0 {
   276  				name := named.String()
   277  
   278  				if v, ok := namedArgSet[name]; ok {
   279  					switch arg := v.(type) {
   280  					case SqlExpr:
   281  						if !IsNilExpr(arg) {
   282  							subExpr := arg.Ex(ctx)
   283  
   284  							if subExpr != eb && !IsNilExpr(subExpr) {
   285  								eb.WriteQuery(subExpr.Query())
   286  								eb.AppendArgs(subExpr.Args()...)
   287  							}
   288  						}
   289  					default:
   290  						eb.WriteHolder(0)
   291  						eb.AppendArgs(arg)
   292  					}
   293  				} else {
   294  					panic(fmt.Sprintf("missing named arg `%s`", name))
   295  				}
   296  			}
   297  
   298  			if c != scanner.EOF {
   299  				eb.WriteQueryByte(byte(c))
   300  			}
   301  		case '?':
   302  			if argIndex >= n {
   303  				panic(fmt.Errorf("missing arg %d of %s", argIndex, query))
   304  			}
   305  
   306  			switch arg := args[argIndex].(type) {
   307  			case SqlExpr:
   308  				if !IsNilExpr(arg) {
   309  					subExpr := arg.Ex(ctx)
   310  
   311  					if subExpr != eb && !IsNilExpr(subExpr) {
   312  						eb.WriteQuery(subExpr.Query())
   313  						eb.AppendArgs(subExpr.Args()...)
   314  					}
   315  				}
   316  			default:
   317  				eb.WriteHolder(0)
   318  				eb.AppendArgs(arg)
   319  			}
   320  			argIndex++
   321  		default:
   322  			eb.WriteQueryByte(byte(c))
   323  		}
   324  	}
   325  
   326  	eb.SetExactly(true)
   327  
   328  	return eb
   329  }
   330  
   331  func exactlyExprFromSlice(values []any) *Ex {
   332  	if n := len(values); n > 0 {
   333  		return ExactlyExpr(strings.Repeat(",?", n)[1:], values...)
   334  	}
   335  	return ExactlyExpr("")
   336  }
   337  
   338  func preprocessArgs(args []any) (NamedArgSet, []any, bool) {
   339  	namedArgSet := NamedArgSet{}
   340  	finalArgs := make([]any, 0, len(args))
   341  
   342  	shouldResolve := false
   343  
   344  	for i := range args {
   345  		switch arg := args[i].(type) {
   346  		case NamedArgSet:
   347  			for k := range arg {
   348  				namedArgSet[k] = arg[k]
   349  			}
   350  			shouldResolve = true
   351  		case NamedArg:
   352  			namedArgSet[arg.Name] = arg.Value
   353  			shouldResolve = true
   354  		case ValuerExpr:
   355  			finalArgs = append(finalArgs, ExactlyExpr(arg.ValueEx(), arg))
   356  			shouldResolve = true
   357  		case SqlExpr:
   358  			finalArgs = append(finalArgs, arg)
   359  			shouldResolve = true
   360  		case driver.Valuer:
   361  			finalArgs = append(finalArgs, arg)
   362  		case []any:
   363  			finalArgs = append(finalArgs, exactlyExprFromSlice(arg))
   364  			shouldResolve = true
   365  		default:
   366  			if typ := reflect.TypeOf(arg); typ.Kind() == reflect.Slice {
   367  				if !reflectx.IsBytes(typ) {
   368  					finalArgs = append(finalArgs, exactlyExprFromSlice(toInterfaceSlice(arg)))
   369  					shouldResolve = true
   370  					continue
   371  				}
   372  			}
   373  			finalArgs = append(finalArgs, arg)
   374  		}
   375  	}
   376  
   377  	return namedArgSet, finalArgs, shouldResolve
   378  }
   379  
   380  func toInterfaceSlice(arg any) []any {
   381  	switch x := (arg).(type) {
   382  	case []bool:
   383  		values := make([]any, len(x))
   384  		for i := range values {
   385  			values[i] = x[i]
   386  		}
   387  		return values
   388  	case []string:
   389  		values := make([]any, len(x))
   390  		for i := range values {
   391  			values[i] = x[i]
   392  		}
   393  		return values
   394  	case []float32:
   395  		values := make([]any, len(x))
   396  		for i := range values {
   397  			values[i] = x[i]
   398  		}
   399  		return values
   400  	case []float64:
   401  		values := make([]any, len(x))
   402  		for i := range values {
   403  			values[i] = x[i]
   404  		}
   405  		return values
   406  	case []int:
   407  		values := make([]any, len(x))
   408  		for i := range values {
   409  			values[i] = x[i]
   410  		}
   411  		return values
   412  	case []int8:
   413  		values := make([]any, len(x))
   414  		for i := range values {
   415  			values[i] = x[i]
   416  		}
   417  		return values
   418  	case []int16:
   419  		values := make([]any, len(x))
   420  		for i := range values {
   421  			values[i] = x[i]
   422  		}
   423  		return values
   424  	case []int32:
   425  		values := make([]any, len(x))
   426  		for i := range values {
   427  			values[i] = x[i]
   428  		}
   429  		return values
   430  	case []int64:
   431  		values := make([]any, len(x))
   432  		for i := range values {
   433  			values[i] = x[i]
   434  		}
   435  		return values
   436  	case []uint:
   437  		values := make([]any, len(x))
   438  		for i := range values {
   439  			values[i] = x[i]
   440  		}
   441  		return values
   442  	case []uint8:
   443  		values := make([]any, len(x))
   444  		for i := range values {
   445  			values[i] = x[i]
   446  		}
   447  		return values
   448  	case []uint16:
   449  		values := make([]any, len(x))
   450  		for i := range values {
   451  			values[i] = x[i]
   452  		}
   453  		return values
   454  	case []uint32:
   455  		values := make([]any, len(x))
   456  		for i := range values {
   457  			values[i] = x[i]
   458  		}
   459  		return values
   460  	case []uint64:
   461  		values := make([]any, len(x))
   462  		for i := range values {
   463  			values[i] = x[i]
   464  		}
   465  		return values
   466  	case []any:
   467  		return x
   468  	}
   469  	sliceRv := reflect.ValueOf(arg)
   470  	values := make([]any, sliceRv.Len())
   471  	for i := range values {
   472  		values[i] = sliceRv.Index(i).Interface()
   473  	}
   474  	return values
   475  }