github.com/sohaha/zlsgo@v1.7.13-0.20240501141223-10dd1a906f76/zutil/args.go (about) 1 package zutil 2 3 import ( 4 "bytes" 5 "database/sql" 6 "fmt" 7 "sort" 8 "strconv" 9 "strings" 10 11 "github.com/sohaha/zlsgo/zstring" 12 "github.com/sohaha/zlsgo/ztype" 13 ) 14 15 // Args stores arguments associated 16 type Args struct { 17 namedArgs map[string]int 18 sqlNamedArgs map[string]int 19 compileHandler ArgsCompileHandler 20 args []argsArr 21 onlyNamed bool 22 } 23 24 type argsArr struct { 25 Fn func(k string) interface{} 26 Arg interface{} 27 } 28 type ArgsOpt func(*Args) 29 type ArgsCompileHandler func(buf *bytes.Buffer, values []interface{}, arg interface{}) ([]interface{}, bool) 30 31 const maxPredefinedArgs = 64 32 33 var predefinedArgs []string 34 35 func init() { 36 predefinedArgs = make([]string, 0, maxPredefinedArgs) 37 for i := 0; i < maxPredefinedArgs; i++ { 38 predefinedArgs = append(predefinedArgs, fmt.Sprintf("$%v", i)) 39 } 40 } 41 42 func WithOnlyNamed() func(args *Args) { 43 return func(args *Args) { 44 args.onlyNamed = true 45 } 46 } 47 48 func WithCompileHandler(fn ArgsCompileHandler) func(args *Args) { 49 return func(args *Args) { 50 args.compileHandler = fn 51 } 52 } 53 54 // NewArgs returns a new Args 55 func NewArgs(opt ...ArgsOpt) *Args { 56 args := &Args{} 57 for _, o := range opt { 58 o(args) 59 } 60 return args 61 } 62 63 // Var adds an arg to Args and returns a placeholder 64 func (args *Args) Var(arg interface{}) string { 65 idx := args.add(arg, nil) 66 if idx < maxPredefinedArgs { 67 return predefinedArgs[idx] 68 } 69 return fmt.Sprintf("$%v", idx) 70 } 71 72 func (args *Args) add(arg interface{}, fn func(k string) interface{}) int { 73 idx := len(args.args) 74 75 switch a := arg.(type) { 76 case namedArgs: 77 if args.namedArgs == nil { 78 args.namedArgs = map[string]int{} 79 } 80 if p, ok := args.namedArgs[a.name]; ok { 81 arg = args.args[p] 82 break 83 } 84 arg := a.arg 85 switch v := a.arg.(type) { 86 default: 87 idx = args.add(arg, nil) 88 case func() interface{}: 89 idx = args.add(arg, func(_ string) interface{} { return v() }) 90 case func(k string) interface{}: 91 idx = args.add(arg, v) 92 } 93 94 args.namedArgs[a.name] = idx 95 return idx 96 case sql.NamedArg: 97 if args.sqlNamedArgs == nil { 98 args.sqlNamedArgs = map[string]int{} 99 } 100 if p, ok := args.sqlNamedArgs[a.Name]; ok { 101 arg = args.args[p] 102 break 103 } 104 105 args.sqlNamedArgs[a.Name] = idx 106 } 107 108 args.args = append(args.args, argsArr{Arg: arg, Fn: fn}) 109 return idx 110 } 111 112 // CompileString returns a string representation of Args 113 func (args *Args) CompileString(format string, initialValue ...interface{}) string { 114 old := args.compileHandler 115 args.compileHandler = func(buf *bytes.Buffer, values []interface{}, arg interface{}) ([]interface{}, bool) { 116 switch v := arg.(type) { 117 case string: 118 buf.WriteString(v) 119 case sql.NamedArg: 120 buf.WriteString(ztype.ToString(v.Value)) 121 default: 122 val := ztype.ToString(v) 123 buf.WriteString(val) 124 } 125 return values, true 126 } 127 defer func() { 128 if old != nil { 129 args.compileHandler = old 130 } 131 }() 132 query, _ := args.Compile(format, initialValue...) 133 134 return query 135 } 136 137 // Compile compiles builder's format to standard sql and returns associated args 138 func (args *Args) Compile(format string, initialValue ...interface{}) (query string, values []interface{}) { 139 buf := GetBuff(256) 140 idx := strings.IndexRune(format, '$') 141 offset := 0 142 values = initialValue 143 144 for idx >= 0 && len(format) > 0 { 145 if idx > 0 { 146 buf.WriteString(format[:idx]) 147 } 148 149 format = format[idx+1:] 150 if len(format) == 0 { 151 buf.WriteRune('$') 152 break 153 } 154 155 if r := format[0]; r == '$' { 156 buf.WriteRune('$') 157 format = format[1:] 158 } else if r == '{' { 159 format, values = args.compileNamed(buf, format, values) 160 } else if !args.onlyNamed && '0' <= r && r <= '9' { 161 format, values, offset = args.compileDigits(buf, format, values, offset) 162 } else if !args.onlyNamed && r == '?' { 163 format, values, offset = args.compileSuccessive(buf, format[1:], values, offset, "") 164 } else { 165 buf.WriteRune('$') 166 } 167 168 idx = strings.IndexRune(format, '$') 169 } 170 171 if len(format) > 0 { 172 buf.WriteString(format) 173 } 174 175 query = buf.String() 176 177 PutBuff(buf) 178 179 if len(args.sqlNamedArgs) > 0 { 180 ints := make([]int, 0, len(args.sqlNamedArgs)) 181 for _, p := range args.sqlNamedArgs { 182 ints = append(ints, p) 183 } 184 sort.Ints(ints) 185 186 for _, i := range ints { 187 values = append(values, args.args[i].Arg) 188 } 189 } 190 191 return 192 } 193 194 func (args *Args) compileNamed(buf *bytes.Buffer, format string, values []interface{}) (string, []interface{}) { 195 i := 1 196 for ; i < len(format) && format[i] != '}'; i++ { 197 } 198 if i == len(format) { 199 return format, values 200 } 201 202 name := format[1:i] 203 format = format[i+1:] 204 205 if p, ok := args.namedArgs[name]; ok { 206 format, values, _ = args.compileSuccessive(buf, format, values, p, "") 207 } else if strings.IndexRune(name, '.') > 0 { 208 for n := range args.namedArgs { 209 if zstring.Match(name, n) { 210 p := args.namedArgs[n] 211 format, values, _ = args.compileSuccessive(buf, format, values, p, name) 212 } 213 } 214 } 215 216 return format, values 217 } 218 219 func (args *Args) compileDigits(buf *bytes.Buffer, format string, values []interface{}, offset int) (string, []interface{}, int) { 220 i := 1 221 for ; i < len(format) && '0' <= format[i] && format[i] <= '9'; i++ { 222 } 223 224 digits := format[:i] 225 format = format[i:] 226 227 if pointer, err := strconv.Atoi(digits); err == nil { 228 return args.compileSuccessive(buf, format, values, pointer, "") 229 } 230 231 return format, values, offset 232 } 233 234 func (args *Args) compileSuccessive(buf *bytes.Buffer, format string, values []interface{}, offset int, name string) (string, []interface{}, int) { 235 if offset >= len(args.args) { 236 return format, values, offset 237 } 238 239 arg := args.args[offset] 240 if arg.Fn != nil { 241 values = args.CompileArg(buf, values, arg.Fn(name)) 242 } else { 243 values = args.CompileArg(buf, values, arg.Arg) 244 } 245 246 return format, values, offset + 1 247 } 248 249 func (args *Args) CompileArg(buf *bytes.Buffer, values []interface{}, arg interface{}) []interface{} { 250 if args.compileHandler != nil { 251 if values, ok := args.compileHandler(buf, values, arg); ok { 252 return values 253 } 254 } 255 switch a := arg.(type) { 256 case sql.NamedArg: 257 buf.WriteRune('@') 258 buf.WriteString(a.Name) 259 default: 260 buf.WriteRune('?') 261 values = append(values, arg) 262 } 263 264 return values 265 }