github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/dot.go (about)

     1  package sqx
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  	"io"
     8  	"os"
     9  	"regexp"
    10  	"strings"
    11  	"unicode"
    12  
    13  	"github.com/bingoohuang/gg/pkg/mapp"
    14  	"github.com/expr-lang/expr"
    15  	"github.com/expr-lang/expr/vm"
    16  	funk "github.com/thoas/go-funk"
    17  )
    18  
// Dot pairs a SQL query and its variables with an optional count query and
// that query's variables.
//
// NOTE(review): nothing in this file reads or writes Dot; the field meanings
// below are inferred from their names — confirm against callers.
// NOTE(review): CuntVars looks like a misspelling of CountVars; renaming an
// exported field would break callers, so it is left as is.
type Dot struct {
	Query string        // the main SQL statement
	Vars  []interface{} // presumably the bind variables for Query

	CountQuery string        // presumably a companion count statement (e.g. for paging)
	CuntVars   []interface{} // presumably the bind variables for CountQuery
}
    26  
// DotItem is one named query read from a .sql source by DotScanner.
type DotItem struct {
	// Content holds the body lines of the query, whitespace-trimmed, with
	// blank lines skipped.
	Content []string

	// Name is the value of the "name" attribute of the "-- name: ..." tag line.
	Name string

	// Attrs holds every key[:value] attribute of the tag line (for example
	// "delimiter"); a key written without a value maps to "".
	Attrs map[string]string
}
    33  
    34  var re = regexp.MustCompile(`\s*(\w+)\s*(:\s*(\S+))?`)
    35  
    36  // ParseDotTag parses the tag like name:value age:34 adult to map
    37  // returns the map and main tag's value.
    38  func ParseDotTag(line, prefix, mainTag string) (map[string]string, string) {
    39  	l := strings.TrimSpace(line)
    40  	if !strings.HasPrefix(l, prefix) {
    41  		return nil, ""
    42  	}
    43  
    44  	l = strings.TrimSpace(l[2:])
    45  	m := make(map[string]string)
    46  
    47  	for _, subs := range re.FindAllStringSubmatch(l, -1) {
    48  		m[subs[1]] = subs[3]
    49  	}
    50  
    51  	return m, m[mainTag]
    52  }
    53  
// DotScanner scans the SQL statements from .sql files.
//
// It is a line-driven state machine (see Run). Lines before the first
// "-- name: <name>" tag are ignored. Every later non-blank line is appended to
// the query named by the most recent tag.
type DotScanner struct {
	line    string             // the input line currently being handled
	queries map[string]DotItem // queries collected so far, keyed by name
	current DotItem            // the query currently receiving lines
}
    60  
    61  func (s *DotScanner) createNewItem(name string, tag map[string]string) {
    62  	s.current = DotItem{Name: name, Attrs: tag, Content: make([]string, 0)}
    63  }
    64  
// stateFn is one state of DotScanner's state machine: it handles the current
// line (DotScanner.line) and returns the state to use for the next line.
type stateFn func() stateFn
    66  
    67  func (s *DotScanner) initialState() stateFn {
    68  	if tag, name := ParseDotTag(s.line, "--", "name"); name != "" {
    69  		s.createNewItem(name, tag)
    70  
    71  		return s.queryState
    72  	}
    73  
    74  	return s.initialState
    75  }
    76  
    77  func (s *DotScanner) queryState() stateFn {
    78  	if tag, name := ParseDotTag(s.line, "--", "name"); name != "" {
    79  		s.createNewItem(name, tag)
    80  	} else {
    81  		s.appendQueryLine()
    82  	}
    83  
    84  	return s.queryState
    85  }
    86  
    87  func (s *DotScanner) appendQueryLine() {
    88  	line := strings.Trim(s.line, " \t")
    89  	if len(line) == 0 {
    90  		return
    91  	}
    92  
    93  	s.current.Content = append(s.current.Content, strings.TrimSpace(line))
    94  	s.queries[s.current.Name] = s.current
    95  }
    96  
    97  // Run runs the scanner.
    98  func (s *DotScanner) Run(io *bufio.Scanner) map[string]DotItem {
    99  	s.queries = make(map[string]DotItem)
   100  
   101  	for state := s.initialState; io.Scan(); {
   102  		s.line = io.Text()
   103  		state = state()
   104  	}
   105  
   106  	return s.queries
   107  }
   108  
// DotSQL is the set of named SQL statements loaded by DotSQLLoad and friends.
type DotSQL struct {
	// Dots maps each query name to its parsed item.
	Dots map[string]DotItem
}
   113  
   114  // Raw returns the query, everything after the --name tag.
   115  func (d DotSQL) Raw(name string) (SQLPart, error) {
   116  	v, err := d.lookupQuery(name)
   117  
   118  	return v, err
   119  }
   120  
   121  func (d DotSQL) lookupQuery(name string) (query SQLPart, err error) {
   122  	s, ok := d.Dots[name]
   123  	if !ok {
   124  		return nil, fmt.Errorf("dotsql: '%s' could not be found", name) // nolint:goerr113
   125  	}
   126  
   127  	query, err = s.DynamicSQL()
   128  
   129  	return query, err
   130  }
   131  
   132  // RawSQL returns the raw SQL.
   133  func (d DotItem) RawSQL() string {
   134  	delimiter := d.Attrs["delimiter"]
   135  	if delimiter == "" {
   136  		delimiter = ";"
   137  	}
   138  
   139  	return TrimSQL(strings.Join(d.Content, "\n"), delimiter)
   140  }
   141  
   142  // TrimSQL trims the delimiter from the string s.
   143  func TrimSQL(s, delimiter string) string {
   144  	s = strings.TrimSpace(s)
   145  
   146  	for strings.HasPrefix(s, delimiter) || strings.HasSuffix(s, delimiter) {
   147  		s = strings.TrimPrefix(s, delimiter)
   148  		s = strings.TrimSuffix(s, delimiter)
   149  		s = strings.TrimSpace(s)
   150  	}
   151  
   152  	return s
   153  }
   154  
   155  // DynamicSQL returns the dynamic SQL.
   156  func (d DotItem) DynamicSQL() (SQLPart, error) {
   157  	lines := ConvertSQLLines(d.Content)
   158  
   159  	_, part, err := ParseDynamicSQL(lines)
   160  	if err != nil {
   161  		return nil, err
   162  	}
   163  
   164  	if err := part.Compile(); err != nil {
   165  		return nil, err
   166  	}
   167  
   168  	return &PostProcessingSQLPart{
   169  		Part:  part,
   170  		Attrs: d.Attrs,
   171  	}, nil
   172  }
   173  
   174  // DotSQLLoad imports sql queries from any io.Reader.
   175  func DotSQLLoad(r io.Reader) (*DotSQL, error) {
   176  	return &DotSQL{(&DotScanner{}).Run(bufio.NewScanner(r))}, nil
   177  }
   178  
   179  // DotSQLLoadFile imports SQL queries from the file.
   180  func DotSQLLoadFile(sqlFile string) (*DotSQL, error) {
   181  	f, err := os.Open(sqlFile)
   182  	if err != nil {
   183  		return nil, err
   184  	}
   185  
   186  	defer f.Close()
   187  
   188  	return DotSQLLoad(f)
   189  }
   190  
   191  // DotSQLLoadString imports SQL queries from the string.
   192  func DotSQLLoadString(s string) (*DotSQL, error) { return DotSQLLoad(bytes.NewBufferString(s)) }
   193  
// SQLPart defines the dynamic SQL part: a node of the tree that
// ParseDynamicSQL builds from a query body.
type SQLPart interface {
	// Compile compiles, in advance, any condition expressions the part holds.
	// It must be called before Eval.
	Compile() error
	// Eval evaluates the SQL part to a real SQL string using the variables in m.
	Eval(m map[string]interface{}) (string, error)
	// Raw returns the raw content.
	Raw() string
}
   203  
   204  // PostProcessingSQLPart defines the SQLPart for post-processing like delimiter trimming.
   205  type PostProcessingSQLPart struct {
   206  	Part  SQLPart
   207  	Attrs map[string]string
   208  }
   209  
   210  // Compile compiles the condition int advance.
   211  func (p *PostProcessingSQLPart) Compile() error {
   212  	return p.Part.Compile()
   213  }
   214  
   215  // Eval evaluated the dynamic sql with env.
   216  func (p *PostProcessingSQLPart) Eval(env map[string]interface{}) (string, error) {
   217  	eval, err := p.Part.Eval(env)
   218  	if err != nil {
   219  		return "", err
   220  	}
   221  
   222  	delimiter := mapp.GetStringOr(p.Attrs, "delimiter", ";")
   223  
   224  	return TrimSQL(eval, delimiter), nil
   225  }
   226  
   227  // Raw returns the raw content.
   228  func (p *PostProcessingSQLPart) Raw() string {
   229  	raw := p.Part.Raw()
   230  
   231  	delimiter := mapp.GetStringOr(p.Attrs, "delimiter", ";")
   232  
   233  	return TrimSQL(raw, delimiter)
   234  }
   235  
   236  // LiteralPart define literal SQL part that no eval required.
   237  type LiteralPart struct {
   238  	Literal string
   239  }
   240  
   241  // MakeLiteralPart makes a MakeLiteralPart.
   242  func MakeLiteralPart(s string) SQLPart {
   243  	return &LiteralPart{Literal: s}
   244  }
   245  
   246  // Compile compiles the condition int advance.
   247  func (p *LiteralPart) Compile() error { return nil }
   248  
   249  // Raw returns the raw content.
   250  func (p *LiteralPart) Raw() string { return p.Literal }
   251  
   252  // Eval evaluates the SQL part to a real SQL.
   253  func (p *LiteralPart) Eval(map[string]interface{}) (string, error) { return p.Literal, nil }
   254  
// IfCondition defines a single condition that makes up a conditions-set for IfPart/SwitchPart.
// Part is used when Expr evaluates to true.
// NOTE(review): no SwitchPart type exists in this file.
type IfCondition struct {
	// Expr is the condition source text, compiled with expr.Compile.
	Expr string

	// CompiledExpr is Expr compiled by IfPart.Compile; nil until then.
	CompiledExpr *vm.Program

	// Part is the SQL evaluated when Expr yields true.
	Part SQLPart
}
   261  
   262  // IfPart is the part that has the format of if ... else if ... else ... end.
   263  type IfPart struct {
   264  	Conditions []IfCondition
   265  	Else       SQLPart
   266  }
   267  
   268  // Compile compiles the condition int advance.
   269  func (p *IfPart) Compile() (err error) {
   270  	for i, c := range p.Conditions {
   271  		if c.CompiledExpr, err = expr.Compile(c.Expr); err != nil {
   272  			return err
   273  		}
   274  
   275  		p.Conditions[i] = c
   276  	}
   277  
   278  	return nil
   279  }
   280  
   281  // MakeIfPart makes a new IfPart.
   282  func MakeIfPart() *IfPart {
   283  	return &IfPart{Conditions: make([]IfCondition, 0)}
   284  }
   285  
   286  // AddElse adds an else part to the IfPart.
   287  func (p *IfPart) AddElse(part SQLPart) {
   288  	p.Else = part
   289  }
   290  
   291  // AddCondition adds a condition to the IfPart.
   292  func (p *IfPart) AddCondition(conditionExpr string, part SQLPart) {
   293  	p.Conditions = append(p.Conditions, IfCondition{
   294  		Expr: conditionExpr,
   295  		Part: part,
   296  	})
   297  }
   298  
   299  // Eval evaluates the SQL part to a real SQL.
   300  func (p *IfPart) Eval(env map[string]interface{}) (string, error) {
   301  	for _, c := range p.Conditions {
   302  		output, err := expr.Run(c.CompiledExpr, env)
   303  		if err != nil {
   304  			return "", err
   305  		}
   306  
   307  		if yes, ok := output.(bool); !ok {
   308  			return "", fmt.Errorf("%s is not a bool expression", c.Expr) // nolint:goerr113
   309  		} else if yes {
   310  			return c.Part.Eval(env)
   311  		}
   312  	}
   313  
   314  	if p.Else != nil {
   315  		return p.Else.Eval(env)
   316  	}
   317  
   318  	return "", nil
   319  }
   320  
   321  // Raw returns the raw content.
   322  func (p *IfPart) Raw() string {
   323  	raw := ""
   324  
   325  	for _, c := range p.Conditions {
   326  		raw += c.Expr + "\n" + c.Part.Raw()
   327  	}
   328  
   329  	if p.Else != nil {
   330  		raw += "\n" + p.Else.Raw()
   331  	}
   332  
   333  	return raw
   334  }
   335  
   336  // MultiPart is the multi SQLParts.
   337  type MultiPart struct {
   338  	Parts []SQLPart
   339  }
   340  
   341  // MakeMultiPart makes MultiPart.
   342  func MakeMultiPart() *MultiPart {
   343  	return &MultiPart{Parts: make([]SQLPart, 0)}
   344  }
   345  
   346  // Eval evaluates the SQL part to a real SQL.
   347  func (p *MultiPart) Eval(env map[string]interface{}) (string, error) {
   348  	value := ""
   349  
   350  	for _, p := range p.Parts {
   351  		v, err := p.Eval(env)
   352  		if err != nil {
   353  			return "", err
   354  		}
   355  
   356  		if value != "" {
   357  			value += " "
   358  		}
   359  
   360  		value += v
   361  	}
   362  
   363  	return value, nil
   364  }
   365  
   366  // Raw returns the raw content.
   367  func (p *MultiPart) Raw() string {
   368  	raw := ""
   369  
   370  	for _, c := range p.Parts {
   371  		if raw != "" {
   372  			raw += "\n"
   373  		}
   374  
   375  		raw += c.Raw()
   376  	}
   377  
   378  	return raw
   379  }
   380  
   381  // AddPart adds a part to the current MultiPart.
   382  func (p *MultiPart) AddPart(part SQLPart) {
   383  	p.Parts = append(p.Parts, part)
   384  }
   385  
   386  // Compile compiles the condition int advance.
   387  func (p *MultiPart) Compile() error {
   388  	for _, part := range p.Parts {
   389  		if err := part.Compile(); err != nil {
   390  			return err
   391  		}
   392  	}
   393  
   394  	return nil
   395  }
   396  
// Compile-time checks that every part type implements SQLPart.
var (
	_ SQLPart = (*LiteralPart)(nil)
	_ SQLPart = (*IfPart)(nil)
	_ SQLPart = (*MultiPart)(nil)
	_ SQLPart = (*PostProcessingSQLPart)(nil)
)
   403  
   404  // ParseDynamicSQL parses the dynamic sqls to structured SQLPart.
   405  func ParseDynamicSQL(lines []string, terminators ...string) (int, SQLPart, error) {
   406  	multiPart := MakeMultiPart()
   407  
   408  	for i := 0; i < len(lines); i++ {
   409  		l := lines[i]
   410  
   411  		if !strings.HasPrefix(l, "--") {
   412  			multiPart.AddPart(MakeLiteralPart(l))
   413  			continue
   414  		}
   415  
   416  		commentLine := strings.TrimSpace(l[2:])
   417  		word := firstWord(commentLine, 1)
   418  		parser := CreateParser(word, strings.TrimSpace(commentLine[len(word):]))
   419  
   420  		if parser == nil { // no parser found, ignore comment line
   421  			if funk.ContainsString(terminators, word) {
   422  				return i, multiPart, nil
   423  			}
   424  
   425  			continue
   426  		}
   427  
   428  		partLines, part, err := parser.Parse(lines[i+1:])
   429  		if err != nil {
   430  			return 0, nil, err
   431  		}
   432  
   433  		multiPart.AddPart(part)
   434  
   435  		i += partLines - 1
   436  	}
   437  
   438  	return len(lines), multiPart, nil
   439  }
   440  
   441  // ConvertSQLLines converts the inline comments to line comments
   442  // and merge to uncomment lines together.
   443  func ConvertSQLLines(lines []string) []string {
   444  	inlineCommentMode := false
   445  	noneComment := ""
   446  	inlineCommentContent := ""
   447  	converted := make([]string, 0)
   448  
   449  	for _, l := range lines {
   450  		if strings.HasPrefix(l, "--") {
   451  			if noneComment != "" {
   452  				converted = append(converted, noneComment)
   453  				noneComment = ""
   454  			}
   455  
   456  			converted = append(converted, l)
   457  
   458  			continue
   459  		}
   460  
   461  	inlineCommentGo:
   462  		l = strings.TrimSpace(l)
   463  
   464  		if l == "" {
   465  			continue
   466  		}
   467  
   468  		if !inlineCommentMode {
   469  			inlineCommentStart := strings.Index(l, "/*")
   470  			if inlineCommentStart < 0 {
   471  				noneComment = appendNoneComment(noneComment, l)
   472  
   473  				continue
   474  			}
   475  
   476  			inlineCommentMode = true
   477  
   478  			if before := strings.TrimSpace(l[0:inlineCommentStart]); before != "" {
   479  				noneComment = appendNoneComment(noneComment, before)
   480  			}
   481  
   482  			l = l[inlineCommentStart+2:]
   483  		}
   484  
   485  		inlineCommentStop := strings.Index(l, "*/")
   486  		if inlineCommentStop >= 0 {
   487  			inlineCommentMode = false
   488  			inlineCommentContent += l[:inlineCommentStop]
   489  
   490  			if inlineComment := strings.TrimSpace(inlineCommentContent); inlineComment != "" {
   491  				if noneComment != "" {
   492  					converted = append(converted, noneComment)
   493  					noneComment = ""
   494  				}
   495  
   496  				converted = append(converted, "-- "+inlineComment)
   497  			}
   498  
   499  			l = l[inlineCommentStop+2:]
   500  			inlineCommentContent = ""
   501  
   502  			goto inlineCommentGo
   503  		}
   504  
   505  		inlineCommentContent += l
   506  	}
   507  
   508  	if noneComment != "" {
   509  		converted = append(converted, noneComment)
   510  	}
   511  
   512  	return converted
   513  }
   514  
   515  func appendNoneComment(noneComment string, l string) string {
   516  	if noneComment != "" {
   517  		noneComment += "\n"
   518  	}
   519  
   520  	return noneComment + l
   521  }
   522  
// SQLPartParser parses the lines that follow a directive comment line (such as
// "-- if <expr>") into a SQLPart.
type SQLPartParser interface {
	// Parse parses lines, which begin right after the directive line, into a
	// SQLPart. partLines is the number of lines consumed counting the
	// directive line itself, so that ParseDynamicSQL can skip the whole
	// construct.
	Parse(lines []string) (partLines int, part SQLPart, err error)
}
   528  
// IfSQLPartParser defines the Parser of IfPart.
type IfSQLPartParser struct {
	// Condition is the expression from the opening "-- if" line.
	Condition string

	// NOTE(review): Else is never read or written in this file.
	Else string
}
   534  
   535  // MakeIfSQLPartParser makes a IfSQLPartParser.
   536  func MakeIfSQLPartParser(condition string) *IfSQLPartParser {
   537  	return &IfSQLPartParser{
   538  		Condition: condition,
   539  	}
   540  }
   541  
   542  // Parse parses the lines to SQLPart.
   543  func (p *IfSQLPartParser) Parse(lines []string) (partLines int, part SQLPart, err error) {
   544  	ifPart := MakeIfPart()
   545  	condition := p.Condition
   546  
   547  	for i := 0; i < len(lines); i++ {
   548  		l := lines[i]
   549  
   550  		if !strings.HasPrefix(l, "--") {
   551  			ifPart.AddCondition(condition, MakeLiteralMultiPart(l))
   552  			continue
   553  		}
   554  
   555  		commentLine := strings.TrimSpace(l[2:])
   556  		word := firstWord(commentLine, 1)
   557  
   558  		if word == "end" {
   559  			return i + 2 /*包括if 行*/, ifPart, nil
   560  		}
   561  
   562  		if word == "elseif" {
   563  			condition = strings.TrimSpace(commentLine[len(word):])
   564  
   565  			processLines, sqlPart, err := ParseDynamicSQL(lines[i+1:], "end", "elseif", "else")
   566  			if err != nil {
   567  				return 0, nil, err
   568  			}
   569  
   570  			ifPart.AddCondition(condition, sqlPart)
   571  
   572  			i += processLines
   573  
   574  			continue
   575  		}
   576  
   577  		if word == "else" {
   578  			processLines, sqlPart, err := ParseDynamicSQL(lines[i+1:], "end")
   579  			if err != nil {
   580  				return 0, nil, err
   581  			}
   582  
   583  			ifPart.AddElse(sqlPart)
   584  
   585  			return i + 2 + processLines, ifPart, nil
   586  		}
   587  
   588  		processLines, sqlPart, err := ParseDynamicSQL(lines[i:], "end", "elseif", "else")
   589  		if err != nil {
   590  			return 0, nil, err
   591  		}
   592  
   593  		ifPart.AddCondition(condition, sqlPart)
   594  
   595  		i += processLines - 1
   596  	}
   597  
   598  	return 0, nil, fmt.Errorf("no end found for if expr") // nolint:goerr113
   599  }
   600  
   601  // MakeLiteralMultiPart makes a MultiPart.
   602  func MakeLiteralMultiPart(l string) *MultiPart {
   603  	return &MultiPart{Parts: []SQLPart{&LiteralPart{l}}}
   604  }
   605  
// Compile-time check that IfSQLPartParser implements SQLPartParser.
var _ SQLPartParser = (*IfSQLPartParser)(nil)
   607  
   608  // CreateParser creates a SQLPartParser.
   609  // If no parser found, nil returned.
   610  func CreateParser(word string, l string) SQLPartParser {
   611  	if word == "if" {
   612  		return MakeIfSQLPartParser(l)
   613  	}
   614  
   615  	return nil
   616  }
   617  
   618  func firstWord(value string, count int) string {
   619  	// Loop over all indexes in the string.
   620  	for i := range value {
   621  		// If we encounter a space, reduce the count.
   622  		if unicode.IsSpace(rune(value[i])) {
   623  			count--
   624  			// When no more words required, return a substring.
   625  			if count == 0 {
   626  				return value[0:i]
   627  			}
   628  		}
   629  	}
   630  
   631  	// Return the entire string.
   632  	return value
   633  }