github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/daoparse.go (about) 1 package sqx 2 3 import ( 4 "database/sql" 5 "fmt" 6 "regexp" 7 "strconv" 8 "strings" 9 10 "github.com/bingoohuang/gg/pkg/sqlparse/sqlparser" 11 "github.com/bingoohuang/gg/pkg/ss" 12 ) 13 14 func (p *SQLParsed) checkFuncInOut(numIn int, f StructField) error { 15 if numIn == 0 && !p.isBindBy(ByNone) { 16 return fmt.Errorf("sql %s required bind varialbes, but the func %v has none", p.RawStmt, f.Type) 17 } 18 19 if numIn != 1 && p.isBindBy(ByName) { 20 return fmt.Errorf("sql %s required named varialbes, but the func %v has non-one arguments", 21 p.RawStmt, f.Type) 22 } 23 24 if p.isBindBy(BySeq, ByAuto) { 25 if numIn < p.MaxSeq { 26 // nolint:goerr113 27 return fmt.Errorf("sql %s required max %d vars, but the func %v has only %d arguments", 28 p.RawStmt, p.MaxSeq, f.Type, numIn) 29 } 30 } 31 32 return nil 33 } 34 35 type bindBy int 36 37 const ( 38 // ByNone means no bind params. 39 ByNone bindBy = iota 40 // ByAuto means auto seq for bind params. 41 ByAuto 42 // BySeq means specific seq for bind params. 43 BySeq 44 // ByName means named bind params. 45 ByName 46 ) 47 48 func (b bindBy) String() string { 49 switch b { 50 case ByNone: 51 return "byNone" 52 case ByAuto: 53 return "byAuto" 54 case BySeq: 55 return "bySeq" 56 case ByName: 57 return "byName" 58 default: 59 return "Unknown" 60 } 61 } 62 63 // SQLParsed is the structure of the parsed SQL. 64 type SQLParsed struct { 65 ID string 66 SQL SQLPart 67 BindBy bindBy 68 Vars []string 69 MaxSeq int 70 IsQuery bool 71 72 RawStmt string 73 74 fp FieldParts 75 runSQL string 76 opt *CreateDaoOpt 77 } 78 79 func (p SQLParsed) replaceQuery(db *sql.DB, query string) (string, error) { 80 if ss.AnyOfFold(ss.FirstWord(query), "CREATE") { 81 return query, nil 82 } 83 84 dbType := sqlparser.ToDBType(DriverName(db.Driver())) 85 cr, err := dbType.Convert(query) 86 return cr.ConvertQuery(), err 87 } 88 89 func (p SQLParsed) isBindBy(by ...bindBy) bool { 90 for _, b := range by { 91 if p.BindBy == b { 92 return true 93 } 94 } 95 96 return false 97 } 98 99 var sqlre = regexp.MustCompile(`'?:\w*'?`) 100 101 type FieldParts struct { 102 fieldParts []FieldPart 103 fieldVars []interface{} 104 } 105 106 func (p *FieldParts) AddFieldSqlPart(part string, varVal []interface{}, joinedSep bool) { 107 p.fieldParts = append(p.fieldParts, FieldPart{ 108 PartSQL: part, 109 BindVal: varVal, 110 PartSQLPlTimes: strings.Count(part, "?"), 111 JoinedSep: joinedSep, 112 }) 113 } 114 115 // ParseSQL parses the sql. 116 func ParseSQL(name, stmt string) (*SQLParsed, error) { 117 p := &SQLParsed{ID: name} 118 119 if err := p.fastParseSQL(stmt); err != nil { 120 return nil, err 121 } 122 123 return p, nil 124 } 125 126 func (p *SQLParsed) fastParseSQL(stmt string) error { 127 p.Vars = make([]string, 0) 128 p.RawStmt = sqlre.ReplaceAllStringFunc(stmt, func(v string) string { 129 if v[:1] == "'" { 130 v = v[2:] 131 } else { 132 v = v[1:] 133 } 134 v = strings.TrimSuffix(v, "'") 135 136 p.Vars = append(p.Vars, v) 137 return "?" 138 }) 139 140 var err error 141 142 p.BindBy, p.MaxSeq, err = parseBindBy(p.ID, p.Vars) 143 if err != nil { 144 return err 145 } 146 147 _, p.IsQuery = IsQuerySQL(p.RawStmt) 148 return nil 149 } 150 151 // IsQuerySQL tests a sql is a query or not. 152 func IsQuerySQL(query string) (string, bool) { 153 switch f := ss.FirstWord(query); strings.ToUpper(f) { 154 case "SELECT", "SHOW", "DESC", "DESCRIBE", "EXPLAIN": 155 return f, true 156 default: // "INSERT", "DELETE", "UPDATE", "SET", "REPLACE": 157 return f, false 158 } 159 } 160 161 func (p *SQLParsed) parseSQL(runSQl string) error { 162 p.Vars = make([]string, 0) 163 p.runSQL = sqlre.ReplaceAllStringFunc(runSQl, func(v string) string { 164 if v[:1] == "'" { 165 v = v[2:] 166 } else { 167 v = v[1:] 168 } 169 v = strings.TrimSuffix(v, "'") 170 p.Vars = append(p.Vars, v) 171 return "?" 172 }) 173 174 if len(p.fp.fieldParts) > 0 { 175 parsed, err := sqlparser.Parse(p.runSQL) 176 if err != nil { 177 return err 178 } 179 180 w, hasWhere := parsed.(sqlparser.IWhere) 181 if hasWhere { 182 hasWhere = w.GetWhere() != nil 183 } 184 185 for i, f := range p.fp.fieldParts { 186 if f.JoinedSep { 187 if i == 0 && !hasWhere { 188 p.runSQL += " where " + f.PartSQL 189 } else { 190 p.runSQL += " and " + f.PartSQL 191 } 192 } else { 193 p.runSQL += " " + f.PartSQL 194 } 195 196 p.Vars = append(p.Vars, f.VarMarks()...) 197 p.fp.fieldVars = append(p.fp.fieldVars, f.Vars()...) 198 } 199 } 200 201 return nil 202 } 203 204 type FieldPart struct { 205 PartSQL string 206 BindVal []interface{} 207 PartSQLPlTimes int 208 JoinedSep bool 209 } 210 211 func (p FieldPart) VarMarks() []string { 212 vars := make([]string, p.PartSQLPlTimes) 213 214 for i := 0; i < p.PartSQLPlTimes; i++ { 215 vars[i] = "?" 216 } 217 218 return vars 219 } 220 221 func (p FieldPart) Vars() []interface{} { 222 vars := make([]interface{}, p.PartSQLPlTimes) 223 224 for i := 0; i < p.PartSQLPlTimes; i++ { 225 vars[i] = p.BindVal[i] 226 } 227 228 return vars 229 } 230 231 func parseBindBy(sqlName string, vars []string) (bindBy bindBy, maxSeq int, err error) { 232 bindBy = ByNone 233 234 for _, v := range vars { 235 if v == "" { 236 if bindBy == ByAuto { 237 maxSeq++ 238 continue 239 } 240 241 if bindBy != ByNone { 242 // nolint:goerr113 243 return 0, 0, fmt.Errorf("[%s] illegal mixed bind mod (%v-%v)", sqlName, bindBy, ByAuto) 244 } 245 246 bindBy = ByAuto 247 maxSeq++ 248 249 continue 250 } 251 252 n, err := strconv.Atoi(v) 253 if err == nil { 254 if bindBy == BySeq { 255 if maxSeq < n { 256 maxSeq = n 257 } 258 259 continue 260 } 261 262 if bindBy != ByNone { 263 // nolint:goerr113 264 return 0, 0, fmt.Errorf("[%s] illegal mixed bind mod (%v-%v)", sqlName, bindBy, BySeq) 265 } 266 267 bindBy = BySeq 268 maxSeq = n 269 270 continue 271 } 272 273 if bindBy == ByName { 274 maxSeq++ 275 continue 276 } 277 278 if bindBy != ByNone { 279 // nolint:goerr113 280 return 0, 0, fmt.Errorf("[%s] illegal mixed bind mod (%v-%v)", sqlName, bindBy, ByName) 281 } 282 283 bindBy = ByName 284 maxSeq++ 285 } 286 287 return bindBy, maxSeq, nil 288 }