github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/daoparse.go (about)

     1  package sqx
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"regexp"
     7  	"strconv"
     8  	"strings"
     9  
    10  	"github.com/bingoohuang/gg/pkg/sqlparse/sqlparser"
    11  	"github.com/bingoohuang/gg/pkg/ss"
    12  )
    13  
    14  func (p *SQLParsed) checkFuncInOut(numIn int, f StructField) error {
    15  	if numIn == 0 && !p.isBindBy(ByNone) {
    16  		return fmt.Errorf("sql %s required bind varialbes, but the func %v has none", p.RawStmt, f.Type)
    17  	}
    18  
    19  	if numIn != 1 && p.isBindBy(ByName) {
    20  		return fmt.Errorf("sql %s required named varialbes, but the func %v has non-one arguments",
    21  			p.RawStmt, f.Type)
    22  	}
    23  
    24  	if p.isBindBy(BySeq, ByAuto) {
    25  		if numIn < p.MaxSeq {
    26  			// nolint:goerr113
    27  			return fmt.Errorf("sql %s required max %d vars, but the func %v has only %d arguments",
    28  				p.RawStmt, p.MaxSeq, f.Type, numIn)
    29  		}
    30  	}
    31  
    32  	return nil
    33  }
    34  
    35  type bindBy int
    36  
    37  const (
    38  	// ByNone means no bind params.
    39  	ByNone bindBy = iota
    40  	// ByAuto means auto seq for bind params.
    41  	ByAuto
    42  	// BySeq means specific seq for bind params.
    43  	BySeq
    44  	// ByName means named bind params.
    45  	ByName
    46  )
    47  
    48  func (b bindBy) String() string {
    49  	switch b {
    50  	case ByNone:
    51  		return "byNone"
    52  	case ByAuto:
    53  		return "byAuto"
    54  	case BySeq:
    55  		return "bySeq"
    56  	case ByName:
    57  		return "byName"
    58  	default:
    59  		return "Unknown"
    60  	}
    61  }
    62  
// SQLParsed is the structure of the parsed SQL.
//
// ParseSQL (through fastParseSQL) fills these fields from a raw statement:
//   - ID is the name given to ParseSQL; it prefixes parseBindBy's errors.
//   - RawStmt is the statement with every ":xxx" placeholder (and any
//     single quotes around it) replaced by "?".
//   - Vars lists the placeholder names in order of appearance ("" for a
//     bare ":"). parseSQL resets and rebuilds it for the statement it
//     prepares.
//   - BindBy and MaxSeq come from parseBindBy. For BySeq, MaxSeq is the
//     largest placeholder number; for ByAuto/ByName it is the placeholder
//     count.
//   - IsQuery is the boolean result of IsQuerySQL on RawStmt.
//
// fp holds the dynamic field parts that parseSQL appends, and runSQL holds
// the statement parseSQL produces. SQL and opt are not assigned in this file.
// NOTE(review): they are presumably set by the DAO builder elsewhere in the
// package; confirm.
type SQLParsed struct {
	ID      string
	SQL     SQLPart
	BindBy  bindBy
	Vars    []string
	MaxSeq  int
	IsQuery bool

	RawStmt string

	fp     FieldParts
	runSQL string
	opt    *CreateDaoOpt
}
    78  
    79  func (p SQLParsed) replaceQuery(db *sql.DB, query string) (string, error) {
    80  	if ss.AnyOfFold(ss.FirstWord(query), "CREATE") {
    81  		return query, nil
    82  	}
    83  
    84  	dbType := sqlparser.ToDBType(DriverName(db.Driver()))
    85  	cr, err := dbType.Convert(query)
    86  	return cr.ConvertQuery(), err
    87  }
    88  
    89  func (p SQLParsed) isBindBy(by ...bindBy) bool {
    90  	for _, b := range by {
    91  		if p.BindBy == b {
    92  			return true
    93  		}
    94  	}
    95  
    96  	return false
    97  }
    98  
// sqlre matches a bind-variable placeholder: a colon followed by zero or more
// word characters (letters, digits, underscore), optionally preceded and/or
// followed by a single quote, e.g. ":", ":1", ":name" or "':name'".
//
// NOTE(review): the pattern is not SQL-aware. A colon inside a string literal
// (e.g. '12:30' matches ":30'") or a PostgreSQL "::" cast is also treated as a
// placeholder.
var sqlre = regexp.MustCompile(`'?:\w*'?`)
   100  
// FieldParts collects the dynamic SQL fragments appended to a statement.
//
// fieldParts holds the fragments in the order AddFieldSqlPart added them.
// fieldVars receives the fragments' bind values when parseSQL appends them
// to the statement.
type FieldParts struct {
	fieldParts []FieldPart
	fieldVars  []interface{}
}
   105  
   106  func (p *FieldParts) AddFieldSqlPart(part string, varVal []interface{}, joinedSep bool) {
   107  	p.fieldParts = append(p.fieldParts, FieldPart{
   108  		PartSQL:        part,
   109  		BindVal:        varVal,
   110  		PartSQLPlTimes: strings.Count(part, "?"),
   111  		JoinedSep:      joinedSep,
   112  	})
   113  }
   114  
   115  // ParseSQL parses the sql.
   116  func ParseSQL(name, stmt string) (*SQLParsed, error) {
   117  	p := &SQLParsed{ID: name}
   118  
   119  	if err := p.fastParseSQL(stmt); err != nil {
   120  		return nil, err
   121  	}
   122  
   123  	return p, nil
   124  }
   125  
   126  func (p *SQLParsed) fastParseSQL(stmt string) error {
   127  	p.Vars = make([]string, 0)
   128  	p.RawStmt = sqlre.ReplaceAllStringFunc(stmt, func(v string) string {
   129  		if v[:1] == "'" {
   130  			v = v[2:]
   131  		} else {
   132  			v = v[1:]
   133  		}
   134  		v = strings.TrimSuffix(v, "'")
   135  
   136  		p.Vars = append(p.Vars, v)
   137  		return "?"
   138  	})
   139  
   140  	var err error
   141  
   142  	p.BindBy, p.MaxSeq, err = parseBindBy(p.ID, p.Vars)
   143  	if err != nil {
   144  		return err
   145  	}
   146  
   147  	_, p.IsQuery = IsQuerySQL(p.RawStmt)
   148  	return nil
   149  }
   150  
   151  // IsQuerySQL tests a sql is a query or not.
   152  func IsQuerySQL(query string) (string, bool) {
   153  	switch f := ss.FirstWord(query); strings.ToUpper(f) {
   154  	case "SELECT", "SHOW", "DESC", "DESCRIBE", "EXPLAIN":
   155  		return f, true
   156  	default: // "INSERT", "DELETE", "UPDATE", "SET", "REPLACE":
   157  		return f, false
   158  	}
   159  }
   160  
   161  func (p *SQLParsed) parseSQL(runSQl string) error {
   162  	p.Vars = make([]string, 0)
   163  	p.runSQL = sqlre.ReplaceAllStringFunc(runSQl, func(v string) string {
   164  		if v[:1] == "'" {
   165  			v = v[2:]
   166  		} else {
   167  			v = v[1:]
   168  		}
   169  		v = strings.TrimSuffix(v, "'")
   170  		p.Vars = append(p.Vars, v)
   171  		return "?"
   172  	})
   173  
   174  	if len(p.fp.fieldParts) > 0 {
   175  		parsed, err := sqlparser.Parse(p.runSQL)
   176  		if err != nil {
   177  			return err
   178  		}
   179  
   180  		w, hasWhere := parsed.(sqlparser.IWhere)
   181  		if hasWhere {
   182  			hasWhere = w.GetWhere() != nil
   183  		}
   184  
   185  		for i, f := range p.fp.fieldParts {
   186  			if f.JoinedSep {
   187  				if i == 0 && !hasWhere {
   188  					p.runSQL += " where " + f.PartSQL
   189  				} else {
   190  					p.runSQL += " and " + f.PartSQL
   191  				}
   192  			} else {
   193  				p.runSQL += " " + f.PartSQL
   194  			}
   195  
   196  			p.Vars = append(p.Vars, f.VarMarks()...)
   197  			p.fp.fieldVars = append(p.fp.fieldVars, f.Vars()...)
   198  		}
   199  	}
   200  
   201  	return nil
   202  }
   203  
   204  type FieldPart struct {
   205  	PartSQL        string
   206  	BindVal        []interface{}
   207  	PartSQLPlTimes int
   208  	JoinedSep      bool
   209  }
   210  
   211  func (p FieldPart) VarMarks() []string {
   212  	vars := make([]string, p.PartSQLPlTimes)
   213  
   214  	for i := 0; i < p.PartSQLPlTimes; i++ {
   215  		vars[i] = "?"
   216  	}
   217  
   218  	return vars
   219  }
   220  
   221  func (p FieldPart) Vars() []interface{} {
   222  	vars := make([]interface{}, p.PartSQLPlTimes)
   223  
   224  	for i := 0; i < p.PartSQLPlTimes; i++ {
   225  		vars[i] = p.BindVal[i]
   226  	}
   227  
   228  	return vars
   229  }
   230  
   231  func parseBindBy(sqlName string, vars []string) (bindBy bindBy, maxSeq int, err error) {
   232  	bindBy = ByNone
   233  
   234  	for _, v := range vars {
   235  		if v == "" {
   236  			if bindBy == ByAuto {
   237  				maxSeq++
   238  				continue
   239  			}
   240  
   241  			if bindBy != ByNone {
   242  				// nolint:goerr113
   243  				return 0, 0, fmt.Errorf("[%s] illegal mixed bind mod (%v-%v)", sqlName, bindBy, ByAuto)
   244  			}
   245  
   246  			bindBy = ByAuto
   247  			maxSeq++
   248  
   249  			continue
   250  		}
   251  
   252  		n, err := strconv.Atoi(v)
   253  		if err == nil {
   254  			if bindBy == BySeq {
   255  				if maxSeq < n {
   256  					maxSeq = n
   257  				}
   258  
   259  				continue
   260  			}
   261  
   262  			if bindBy != ByNone {
   263  				// nolint:goerr113
   264  				return 0, 0, fmt.Errorf("[%s] illegal mixed bind mod (%v-%v)", sqlName, bindBy, BySeq)
   265  			}
   266  
   267  			bindBy = BySeq
   268  			maxSeq = n
   269  
   270  			continue
   271  		}
   272  
   273  		if bindBy == ByName {
   274  			maxSeq++
   275  			continue
   276  		}
   277  
   278  		if bindBy != ByNone {
   279  			// nolint:goerr113
   280  			return 0, 0, fmt.Errorf("[%s] illegal mixed bind mod (%v-%v)", sqlName, bindBy, ByName)
   281  		}
   282  
   283  		bindBy = ByName
   284  		maxSeq++
   285  	}
   286  
   287  	return bindBy, maxSeq, nil
   288  }