github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/zutil/args.go (about)

     1  package zutil
     2  
     3  import (
     4  	"bytes"
     5  	"database/sql"
     6  	"fmt"
     7  	"sort"
     8  	"strconv"
     9  	"strings"
    10  
    11  	"github.com/sohaha/zlsgo/zstring"
    12  	"github.com/sohaha/zlsgo/ztype"
    13  )
    14  
    15  // Args stores arguments associated
    16  type Args struct {
    17  	namedArgs      map[string]int
    18  	sqlNamedArgs   map[string]int
    19  	compileHandler ArgsCompileHandler
    20  	args           []argsArr
    21  	onlyNamed      bool
    22  }
    23  
    24  type argsArr struct {
    25  	Fn  func(k string) interface{}
    26  	Arg interface{}
    27  }
    28  type ArgsOpt func(*Args)
    29  type ArgsCompileHandler func(buf *bytes.Buffer, values []interface{}, arg interface{}) ([]interface{}, bool)
    30  
    31  const maxPredefinedArgs = 64
    32  
    33  var predefinedArgs []string
    34  
    35  func init() {
    36  	predefinedArgs = make([]string, 0, maxPredefinedArgs)
    37  	for i := 0; i < maxPredefinedArgs; i++ {
    38  		predefinedArgs = append(predefinedArgs, fmt.Sprintf("$%v", i))
    39  	}
    40  }
    41  
    42  func WithOnlyNamed() func(args *Args) {
    43  	return func(args *Args) {
    44  		args.onlyNamed = true
    45  	}
    46  }
    47  
    48  func WithCompileHandler(fn ArgsCompileHandler) func(args *Args) {
    49  	return func(args *Args) {
    50  		args.compileHandler = fn
    51  	}
    52  }
    53  
    54  // NewArgs returns a new Args
    55  func NewArgs(opt ...ArgsOpt) *Args {
    56  	args := &Args{}
    57  	for _, o := range opt {
    58  		o(args)
    59  	}
    60  	return args
    61  }
    62  
    63  // Var adds an arg to Args and returns a placeholder
    64  func (args *Args) Var(arg interface{}) string {
    65  	idx := args.add(arg, nil)
    66  	if idx < maxPredefinedArgs {
    67  		return predefinedArgs[idx]
    68  	}
    69  	return fmt.Sprintf("$%v", idx)
    70  }
    71  
    72  func (args *Args) add(arg interface{}, fn func(k string) interface{}) int {
    73  	idx := len(args.args)
    74  
    75  	switch a := arg.(type) {
    76  	case namedArgs:
    77  		if args.namedArgs == nil {
    78  			args.namedArgs = map[string]int{}
    79  		}
    80  		if p, ok := args.namedArgs[a.name]; ok {
    81  			arg = args.args[p]
    82  			break
    83  		}
    84  		arg := a.arg
    85  		switch v := a.arg.(type) {
    86  		default:
    87  			idx = args.add(arg, nil)
    88  		case func() interface{}:
    89  			idx = args.add(arg, func(_ string) interface{} { return v() })
    90  		case func(k string) interface{}:
    91  			idx = args.add(arg, v)
    92  		}
    93  
    94  		args.namedArgs[a.name] = idx
    95  		return idx
    96  	case sql.NamedArg:
    97  		if args.sqlNamedArgs == nil {
    98  			args.sqlNamedArgs = map[string]int{}
    99  		}
   100  		if p, ok := args.sqlNamedArgs[a.Name]; ok {
   101  			arg = args.args[p]
   102  			break
   103  		}
   104  
   105  		args.sqlNamedArgs[a.Name] = idx
   106  	}
   107  
   108  	args.args = append(args.args, argsArr{Arg: arg, Fn: fn})
   109  	return idx
   110  }
   111  
   112  // CompileString returns a string representation of Args
   113  func (args *Args) CompileString(format string, initialValue ...interface{}) string {
   114  	old := args.compileHandler
   115  	args.compileHandler = func(buf *bytes.Buffer, values []interface{}, arg interface{}) ([]interface{}, bool) {
   116  		switch v := arg.(type) {
   117  		case string:
   118  			buf.WriteString(v)
   119  		case sql.NamedArg:
   120  			buf.WriteString(ztype.ToString(v.Value))
   121  		default:
   122  			val := ztype.ToString(v)
   123  			buf.WriteString(val)
   124  		}
   125  		return values, true
   126  	}
   127  	defer func() {
   128  		if old != nil {
   129  			args.compileHandler = old
   130  		}
   131  	}()
   132  	query, _ := args.Compile(format, initialValue...)
   133  
   134  	return query
   135  }
   136  
   137  // Compile compiles builder's format to standard sql and returns associated args
   138  func (args *Args) Compile(format string, initialValue ...interface{}) (query string, values []interface{}) {
   139  	buf := GetBuff(256)
   140  	idx := strings.IndexRune(format, '$')
   141  	offset := 0
   142  	values = initialValue
   143  
   144  	for idx >= 0 && len(format) > 0 {
   145  		if idx > 0 {
   146  			buf.WriteString(format[:idx])
   147  		}
   148  
   149  		format = format[idx+1:]
   150  		if len(format) == 0 {
   151  			buf.WriteRune('$')
   152  			break
   153  		}
   154  
   155  		if r := format[0]; r == '$' {
   156  			buf.WriteRune('$')
   157  			format = format[1:]
   158  		} else if r == '{' {
   159  			format, values = args.compileNamed(buf, format, values)
   160  		} else if !args.onlyNamed && '0' <= r && r <= '9' {
   161  			format, values, offset = args.compileDigits(buf, format, values, offset)
   162  		} else if !args.onlyNamed && r == '?' {
   163  			format, values, offset = args.compileSuccessive(buf, format[1:], values, offset, "")
   164  		} else {
   165  			buf.WriteRune('$')
   166  		}
   167  
   168  		idx = strings.IndexRune(format, '$')
   169  	}
   170  
   171  	if len(format) > 0 {
   172  		buf.WriteString(format)
   173  	}
   174  
   175  	query = buf.String()
   176  
   177  	PutBuff(buf)
   178  
   179  	if len(args.sqlNamedArgs) > 0 {
   180  		ints := make([]int, 0, len(args.sqlNamedArgs))
   181  		for _, p := range args.sqlNamedArgs {
   182  			ints = append(ints, p)
   183  		}
   184  		sort.Ints(ints)
   185  
   186  		for _, i := range ints {
   187  			values = append(values, args.args[i].Arg)
   188  		}
   189  	}
   190  
   191  	return
   192  }
   193  
   194  func (args *Args) compileNamed(buf *bytes.Buffer, format string, values []interface{}) (string, []interface{}) {
   195  	i := 1
   196  	for ; i < len(format) && format[i] != '}'; i++ {
   197  	}
   198  	if i == len(format) {
   199  		return format, values
   200  	}
   201  
   202  	name := format[1:i]
   203  	format = format[i+1:]
   204  
   205  	if p, ok := args.namedArgs[name]; ok {
   206  		format, values, _ = args.compileSuccessive(buf, format, values, p, "")
   207  	} else if strings.IndexRune(name, '.') > 0 {
   208  		for n := range args.namedArgs {
   209  			if zstring.Match(name, n) {
   210  				p := args.namedArgs[n]
   211  				format, values, _ = args.compileSuccessive(buf, format, values, p, name)
   212  			}
   213  		}
   214  	}
   215  
   216  	return format, values
   217  }
   218  
   219  func (args *Args) compileDigits(buf *bytes.Buffer, format string, values []interface{}, offset int) (string, []interface{}, int) {
   220  	i := 1
   221  	for ; i < len(format) && '0' <= format[i] && format[i] <= '9'; i++ {
   222  	}
   223  
   224  	digits := format[:i]
   225  	format = format[i:]
   226  
   227  	if pointer, err := strconv.Atoi(digits); err == nil {
   228  		return args.compileSuccessive(buf, format, values, pointer, "")
   229  	}
   230  
   231  	return format, values, offset
   232  }
   233  
   234  func (args *Args) compileSuccessive(buf *bytes.Buffer, format string, values []interface{}, offset int, name string) (string, []interface{}, int) {
   235  	if offset >= len(args.args) {
   236  		return format, values, offset
   237  	}
   238  
   239  	arg := args.args[offset]
   240  	if arg.Fn != nil {
   241  		values = args.CompileArg(buf, values, arg.Fn(name))
   242  	} else {
   243  		values = args.CompileArg(buf, values, arg.Arg)
   244  	}
   245  
   246  	return format, values, offset + 1
   247  }
   248  
   249  func (args *Args) CompileArg(buf *bytes.Buffer, values []interface{}, arg interface{}) []interface{} {
   250  	if args.compileHandler != nil {
   251  		if values, ok := args.compileHandler(buf, values, arg); ok {
   252  			return values
   253  		}
   254  	}
   255  	switch a := arg.(type) {
   256  	case sql.NamedArg:
   257  		buf.WriteRune('@')
   258  		buf.WriteString(a.Name)
   259  	default:
   260  		buf.WriteRune('?')
   261  		values = append(values, arg)
   262  	}
   263  
   264  	return values
   265  }