github.com/bingoohuang/gg@v0.0.0-20240325092523-45da7dee9335/pkg/sqx/count.go (about)

     1  package sqx
     2  
     3  import (
     4  	"errors"
     5  
     6  	"github.com/bingoohuang/gg/pkg/sqlparse/sqlparser"
     7  )
     8  
     9  // ErrNotSelect shows an error that the query is not a select statement.
    10  var ErrNotSelect = errors.New("not a select query statement")
    11  
    12  // CreateCount creates a count query sql.
    13  func (s SQL) CreateCount() (*SQL, error) {
    14  	parsed, err := sqlparser.Parse(s.Q)
    15  	if err != nil {
    16  		return nil, err
    17  	}
    18  
    19  	sel, ok := parsed.(*sqlparser.Select)
    20  	if !ok {
    21  		return nil, ErrNotSelect
    22  	}
    23  
    24  	limitVarsCount := 0
    25  	if sel.Limit != nil {
    26  		limitVarsCount++
    27  		if sel.Limit.Offset != nil {
    28  			limitVarsCount++
    29  		}
    30  	}
    31  
    32  	sel.SelectExprs = countStar
    33  	sel.OrderBy = nil
    34  	sel.Having = nil
    35  	sel.Limit = nil
    36  
    37  	c := &SQL{
    38  		Q:    sqlparser.String(sel),
    39  		Vars: s.Vars,
    40  		Ctx:  s.Ctx,
    41  	}
    42  
    43  	if limitVarsCount > 0 && len(s.Vars) >= limitVarsCount {
    44  		c.Vars = s.Vars[:len(s.Vars)-limitVarsCount]
    45  	}
    46  
    47  	return c, nil
    48  }
    49  
    50  var countStar = func() sqlparser.SelectExprs {
    51  	p, _ := sqlparser.Parse(`select count(*)`)
    52  	return p.(*sqlparser.Select).SelectExprs
    53  }()