github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/dot.go (about) 1 package sqx 2 3 import ( 4 "bufio" 5 "bytes" 6 "fmt" 7 "io" 8 "os" 9 "regexp" 10 "strings" 11 "unicode" 12 13 "github.com/bingoohuang/gg/pkg/mapp" 14 "github.com/expr-lang/expr" 15 "github.com/expr-lang/expr/vm" 16 funk "github.com/thoas/go-funk" 17 ) 18 19 type Dot struct { 20 Query string 21 Vars []interface{} 22 23 CountQuery string 24 CuntVars []interface{} 25 } 26 27 // DotItem tells the item details. 28 type DotItem struct { 29 Content []string 30 Name string 31 Attrs map[string]string 32 } 33 34 var re = regexp.MustCompile(`\s*(\w+)\s*(:\s*(\S+))?`) 35 36 // ParseDotTag parses the tag like name:value age:34 adult to map 37 // returns the map and main tag's value. 38 func ParseDotTag(line, prefix, mainTag string) (map[string]string, string) { 39 l := strings.TrimSpace(line) 40 if !strings.HasPrefix(l, prefix) { 41 return nil, "" 42 } 43 44 l = strings.TrimSpace(l[2:]) 45 m := make(map[string]string) 46 47 for _, subs := range re.FindAllStringSubmatch(l, -1) { 48 m[subs[1]] = subs[3] 49 } 50 51 return m, m[mainTag] 52 } 53 54 // DotScanner scans the SQL statements from .sql files. 55 type DotScanner struct { 56 line string 57 queries map[string]DotItem 58 current DotItem 59 } 60 61 func (s *DotScanner) createNewItem(name string, tag map[string]string) { 62 s.current = DotItem{Name: name, Attrs: tag, Content: make([]string, 0)} 63 } 64 65 type stateFn func() stateFn 66 67 func (s *DotScanner) initialState() stateFn { 68 if tag, name := ParseDotTag(s.line, "--", "name"); name != "" { 69 s.createNewItem(name, tag) 70 71 return s.queryState 72 } 73 74 return s.initialState 75 } 76 77 func (s *DotScanner) queryState() stateFn { 78 if tag, name := ParseDotTag(s.line, "--", "name"); name != "" { 79 s.createNewItem(name, tag) 80 } else { 81 s.appendQueryLine() 82 } 83 84 return s.queryState 85 } 86 87 func (s *DotScanner) appendQueryLine() { 88 line := strings.Trim(s.line, " \t") 89 if len(line) == 0 { 90 return 91 } 92 93 s.current.Content = append(s.current.Content, strings.TrimSpace(line)) 94 s.queries[s.current.Name] = s.current 95 } 96 97 // Run runs the scanner. 98 func (s *DotScanner) Run(io *bufio.Scanner) map[string]DotItem { 99 s.queries = make(map[string]DotItem) 100 101 for state := s.initialState; io.Scan(); { 102 s.line = io.Text() 103 state = state() 104 } 105 106 return s.queries 107 } 108 109 // DotSQL is the set of SQL statements. 110 type DotSQL struct { 111 Dots map[string]DotItem 112 } 113 114 // Raw returns the query, everything after the --name tag. 115 func (d DotSQL) Raw(name string) (SQLPart, error) { 116 v, err := d.lookupQuery(name) 117 118 return v, err 119 } 120 121 func (d DotSQL) lookupQuery(name string) (query SQLPart, err error) { 122 s, ok := d.Dots[name] 123 if !ok { 124 return nil, fmt.Errorf("dotsql: '%s' could not be found", name) // nolint:goerr113 125 } 126 127 query, err = s.DynamicSQL() 128 129 return query, err 130 } 131 132 // RawSQL returns the raw SQL. 133 func (d DotItem) RawSQL() string { 134 delimiter := d.Attrs["delimiter"] 135 if delimiter == "" { 136 delimiter = ";" 137 } 138 139 return TrimSQL(strings.Join(d.Content, "\n"), delimiter) 140 } 141 142 // TrimSQL trims the delimiter from the string s. 143 func TrimSQL(s, delimiter string) string { 144 s = strings.TrimSpace(s) 145 146 for strings.HasPrefix(s, delimiter) || strings.HasSuffix(s, delimiter) { 147 s = strings.TrimPrefix(s, delimiter) 148 s = strings.TrimSuffix(s, delimiter) 149 s = strings.TrimSpace(s) 150 } 151 152 return s 153 } 154 155 // DynamicSQL returns the dynamic SQL. 156 func (d DotItem) DynamicSQL() (SQLPart, error) { 157 lines := ConvertSQLLines(d.Content) 158 159 _, part, err := ParseDynamicSQL(lines) 160 if err != nil { 161 return nil, err 162 } 163 164 if err := part.Compile(); err != nil { 165 return nil, err 166 } 167 168 return &PostProcessingSQLPart{ 169 Part: part, 170 Attrs: d.Attrs, 171 }, nil 172 } 173 174 // DotSQLLoad imports sql queries from any io.Reader. 175 func DotSQLLoad(r io.Reader) (*DotSQL, error) { 176 return &DotSQL{(&DotScanner{}).Run(bufio.NewScanner(r))}, nil 177 } 178 179 // DotSQLLoadFile imports SQL queries from the file. 180 func DotSQLLoadFile(sqlFile string) (*DotSQL, error) { 181 f, err := os.Open(sqlFile) 182 if err != nil { 183 return nil, err 184 } 185 186 defer f.Close() 187 188 return DotSQLLoad(f) 189 } 190 191 // DotSQLLoadString imports SQL queries from the string. 192 func DotSQLLoadString(s string) (*DotSQL, error) { return DotSQLLoad(bytes.NewBufferString(s)) } 193 194 // SQLPart defines the dynamic SQL part. 195 type SQLPart interface { 196 // Compile compiles the condition int advance. 197 Compile() error 198 // Eval evaluates the SQL part to a real SQL. 199 Eval(m map[string]interface{}) (string, error) 200 // Raw returns the raw content. 201 Raw() string 202 } 203 204 // PostProcessingSQLPart defines the SQLPart for post-processing like delimiter trimming. 205 type PostProcessingSQLPart struct { 206 Part SQLPart 207 Attrs map[string]string 208 } 209 210 // Compile compiles the condition int advance. 211 func (p *PostProcessingSQLPart) Compile() error { 212 return p.Part.Compile() 213 } 214 215 // Eval evaluated the dynamic sql with env. 216 func (p *PostProcessingSQLPart) Eval(env map[string]interface{}) (string, error) { 217 eval, err := p.Part.Eval(env) 218 if err != nil { 219 return "", err 220 } 221 222 delimiter := mapp.GetStringOr(p.Attrs, "delimiter", ";") 223 224 return TrimSQL(eval, delimiter), nil 225 } 226 227 // Raw returns the raw content. 228 func (p *PostProcessingSQLPart) Raw() string { 229 raw := p.Part.Raw() 230 231 delimiter := mapp.GetStringOr(p.Attrs, "delimiter", ";") 232 233 return TrimSQL(raw, delimiter) 234 } 235 236 // LiteralPart define literal SQL part that no eval required. 237 type LiteralPart struct { 238 Literal string 239 } 240 241 // MakeLiteralPart makes a MakeLiteralPart. 242 func MakeLiteralPart(s string) SQLPart { 243 return &LiteralPart{Literal: s} 244 } 245 246 // Compile compiles the condition int advance. 247 func (p *LiteralPart) Compile() error { return nil } 248 249 // Raw returns the raw content. 250 func (p *LiteralPart) Raw() string { return p.Literal } 251 252 // Eval evaluates the SQL part to a real SQL. 253 func (p *LiteralPart) Eval(map[string]interface{}) (string, error) { return p.Literal, nil } 254 255 // IfCondition defines a single condition that makes up a conditions-set for IfPart/SwitchPart. 256 type IfCondition struct { 257 Expr string 258 CompiledExpr *vm.Program 259 Part SQLPart 260 } 261 262 // IfPart is the part that has the format of if ... else if ... else ... end. 263 type IfPart struct { 264 Conditions []IfCondition 265 Else SQLPart 266 } 267 268 // Compile compiles the condition int advance. 269 func (p *IfPart) Compile() (err error) { 270 for i, c := range p.Conditions { 271 if c.CompiledExpr, err = expr.Compile(c.Expr); err != nil { 272 return err 273 } 274 275 p.Conditions[i] = c 276 } 277 278 return nil 279 } 280 281 // MakeIfPart makes a new IfPart. 282 func MakeIfPart() *IfPart { 283 return &IfPart{Conditions: make([]IfCondition, 0)} 284 } 285 286 // AddElse adds an else part to the IfPart. 287 func (p *IfPart) AddElse(part SQLPart) { 288 p.Else = part 289 } 290 291 // AddCondition adds a condition to the IfPart. 292 func (p *IfPart) AddCondition(conditionExpr string, part SQLPart) { 293 p.Conditions = append(p.Conditions, IfCondition{ 294 Expr: conditionExpr, 295 Part: part, 296 }) 297 } 298 299 // Eval evaluates the SQL part to a real SQL. 300 func (p *IfPart) Eval(env map[string]interface{}) (string, error) { 301 for _, c := range p.Conditions { 302 output, err := expr.Run(c.CompiledExpr, env) 303 if err != nil { 304 return "", err 305 } 306 307 if yes, ok := output.(bool); !ok { 308 return "", fmt.Errorf("%s is not a bool expression", c.Expr) // nolint:goerr113 309 } else if yes { 310 return c.Part.Eval(env) 311 } 312 } 313 314 if p.Else != nil { 315 return p.Else.Eval(env) 316 } 317 318 return "", nil 319 } 320 321 // Raw returns the raw content. 322 func (p *IfPart) Raw() string { 323 raw := "" 324 325 for _, c := range p.Conditions { 326 raw += c.Expr + "\n" + c.Part.Raw() 327 } 328 329 if p.Else != nil { 330 raw += "\n" + p.Else.Raw() 331 } 332 333 return raw 334 } 335 336 // MultiPart is the multi SQLParts. 337 type MultiPart struct { 338 Parts []SQLPart 339 } 340 341 // MakeMultiPart makes MultiPart. 342 func MakeMultiPart() *MultiPart { 343 return &MultiPart{Parts: make([]SQLPart, 0)} 344 } 345 346 // Eval evaluates the SQL part to a real SQL. 347 func (p *MultiPart) Eval(env map[string]interface{}) (string, error) { 348 value := "" 349 350 for _, p := range p.Parts { 351 v, err := p.Eval(env) 352 if err != nil { 353 return "", err 354 } 355 356 if value != "" { 357 value += " " 358 } 359 360 value += v 361 } 362 363 return value, nil 364 } 365 366 // Raw returns the raw content. 367 func (p *MultiPart) Raw() string { 368 raw := "" 369 370 for _, c := range p.Parts { 371 if raw != "" { 372 raw += "\n" 373 } 374 375 raw += c.Raw() 376 } 377 378 return raw 379 } 380 381 // AddPart adds a part to the current MultiPart. 382 func (p *MultiPart) AddPart(part SQLPart) { 383 p.Parts = append(p.Parts, part) 384 } 385 386 // Compile compiles the condition int advance. 387 func (p *MultiPart) Compile() error { 388 for _, part := range p.Parts { 389 if err := part.Compile(); err != nil { 390 return err 391 } 392 } 393 394 return nil 395 } 396 397 var ( 398 _ SQLPart = (*LiteralPart)(nil) 399 _ SQLPart = (*IfPart)(nil) 400 _ SQLPart = (*MultiPart)(nil) 401 _ SQLPart = (*PostProcessingSQLPart)(nil) 402 ) 403 404 // ParseDynamicSQL parses the dynamic sqls to structured SQLPart. 405 func ParseDynamicSQL(lines []string, terminators ...string) (int, SQLPart, error) { 406 multiPart := MakeMultiPart() 407 408 for i := 0; i < len(lines); i++ { 409 l := lines[i] 410 411 if !strings.HasPrefix(l, "--") { 412 multiPart.AddPart(MakeLiteralPart(l)) 413 continue 414 } 415 416 commentLine := strings.TrimSpace(l[2:]) 417 word := firstWord(commentLine, 1) 418 parser := CreateParser(word, strings.TrimSpace(commentLine[len(word):])) 419 420 if parser == nil { // no parser found, ignore comment line 421 if funk.ContainsString(terminators, word) { 422 return i, multiPart, nil 423 } 424 425 continue 426 } 427 428 partLines, part, err := parser.Parse(lines[i+1:]) 429 if err != nil { 430 return 0, nil, err 431 } 432 433 multiPart.AddPart(part) 434 435 i += partLines - 1 436 } 437 438 return len(lines), multiPart, nil 439 } 440 441 // ConvertSQLLines converts the inline comments to line comments 442 // and merge to uncomment lines together. 443 func ConvertSQLLines(lines []string) []string { 444 inlineCommentMode := false 445 noneComment := "" 446 inlineCommentContent := "" 447 converted := make([]string, 0) 448 449 for _, l := range lines { 450 if strings.HasPrefix(l, "--") { 451 if noneComment != "" { 452 converted = append(converted, noneComment) 453 noneComment = "" 454 } 455 456 converted = append(converted, l) 457 458 continue 459 } 460 461 inlineCommentGo: 462 l = strings.TrimSpace(l) 463 464 if l == "" { 465 continue 466 } 467 468 if !inlineCommentMode { 469 inlineCommentStart := strings.Index(l, "/*") 470 if inlineCommentStart < 0 { 471 noneComment = appendNoneComment(noneComment, l) 472 473 continue 474 } 475 476 inlineCommentMode = true 477 478 if before := strings.TrimSpace(l[0:inlineCommentStart]); before != "" { 479 noneComment = appendNoneComment(noneComment, before) 480 } 481 482 l = l[inlineCommentStart+2:] 483 } 484 485 inlineCommentStop := strings.Index(l, "*/") 486 if inlineCommentStop >= 0 { 487 inlineCommentMode = false 488 inlineCommentContent += l[:inlineCommentStop] 489 490 if inlineComment := strings.TrimSpace(inlineCommentContent); inlineComment != "" { 491 if noneComment != "" { 492 converted = append(converted, noneComment) 493 noneComment = "" 494 } 495 496 converted = append(converted, "-- "+inlineComment) 497 } 498 499 l = l[inlineCommentStop+2:] 500 inlineCommentContent = "" 501 502 goto inlineCommentGo 503 } 504 505 inlineCommentContent += l 506 } 507 508 if noneComment != "" { 509 converted = append(converted, noneComment) 510 } 511 512 return converted 513 } 514 515 func appendNoneComment(noneComment string, l string) string { 516 if noneComment != "" { 517 noneComment += "\n" 518 } 519 520 return noneComment + l 521 } 522 523 // SQLPartParser defines the parser of SQLPart. 524 type SQLPartParser interface { 525 // Parse parses the lines to SQLPart. 526 Parse(lines []string) (partLines int, part SQLPart, err error) 527 } 528 529 // IfSQLPartParser defines the Parser of IfPart. 530 type IfSQLPartParser struct { 531 Condition string 532 Else string 533 } 534 535 // MakeIfSQLPartParser makes a IfSQLPartParser. 536 func MakeIfSQLPartParser(condition string) *IfSQLPartParser { 537 return &IfSQLPartParser{ 538 Condition: condition, 539 } 540 } 541 542 // Parse parses the lines to SQLPart. 543 func (p *IfSQLPartParser) Parse(lines []string) (partLines int, part SQLPart, err error) { 544 ifPart := MakeIfPart() 545 condition := p.Condition 546 547 for i := 0; i < len(lines); i++ { 548 l := lines[i] 549 550 if !strings.HasPrefix(l, "--") { 551 ifPart.AddCondition(condition, MakeLiteralMultiPart(l)) 552 continue 553 } 554 555 commentLine := strings.TrimSpace(l[2:]) 556 word := firstWord(commentLine, 1) 557 558 if word == "end" { 559 return i + 2 /*包括if 行*/, ifPart, nil 560 } 561 562 if word == "elseif" { 563 condition = strings.TrimSpace(commentLine[len(word):]) 564 565 processLines, sqlPart, err := ParseDynamicSQL(lines[i+1:], "end", "elseif", "else") 566 if err != nil { 567 return 0, nil, err 568 } 569 570 ifPart.AddCondition(condition, sqlPart) 571 572 i += processLines 573 574 continue 575 } 576 577 if word == "else" { 578 processLines, sqlPart, err := ParseDynamicSQL(lines[i+1:], "end") 579 if err != nil { 580 return 0, nil, err 581 } 582 583 ifPart.AddElse(sqlPart) 584 585 return i + 2 + processLines, ifPart, nil 586 } 587 588 processLines, sqlPart, err := ParseDynamicSQL(lines[i:], "end", "elseif", "else") 589 if err != nil { 590 return 0, nil, err 591 } 592 593 ifPart.AddCondition(condition, sqlPart) 594 595 i += processLines - 1 596 } 597 598 return 0, nil, fmt.Errorf("no end found for if expr") // nolint:goerr113 599 } 600 601 // MakeLiteralMultiPart makes a MultiPart. 602 func MakeLiteralMultiPart(l string) *MultiPart { 603 return &MultiPart{Parts: []SQLPart{&LiteralPart{l}}} 604 } 605 606 var _ SQLPartParser = (*IfSQLPartParser)(nil) 607 608 // CreateParser creates a SQLPartParser. 609 // If no parser found, nil returned. 610 func CreateParser(word string, l string) SQLPartParser { 611 if word == "if" { 612 return MakeIfSQLPartParser(l) 613 } 614 615 return nil 616 } 617 618 func firstWord(value string, count int) string { 619 // Loop over all indexes in the string. 620 for i := range value { 621 // If we encounter a space, reduce the count. 622 if unicode.IsSpace(rune(value[i])) { 623 count-- 624 // When no more words required, return a substring. 625 if count == 0 { 626 return value[0:i] 627 } 628 } 629 } 630 631 // Return the entire string. 632 return value 633 }