github.com/cockroachdb/cockroachdb-parser@v0.23.3-0.20240213214944-911057d40c9a/pkg/sql/colexec/execgen/cmd/execgen/count_agg_gen.go (about)

     1  // Copyright 2019 The Cockroach Authors.
     2  //
     3  // Use of this software is governed by the Business Source License
     4  // included in the file licenses/BSL.txt.
     5  //
     6  // As of the Change Date specified in that file, in accordance with
     7  // the Business Source License, use of this software will be governed
     8  // by the Apache License, Version 2.0, included in the file
     9  // licenses/APL.txt.
    10  
    11  package main
    12  
    13  import (
    14  	"io"
    15  	"strings"
    16  	"text/template"
    17  
    18  	"github.com/cockroachdb/cockroachdb-parser/pkg/sql/types"
    19  )
    20  
// countAggTmpl is the path of the count aggregate template file. It is handed
// to registerAggGenerator together with genCountAgg in this file's init.
// NOTE(review): the path looks repository-root relative — confirm against the
// generator driver.
const countAggTmpl = "pkg/sql/colexec/colexecagg/count_agg_tmpl.go"
    22  
    23  func genCountAgg(inputFileContents string, wr io.Writer) error {
    24  	s := strings.ReplaceAll(inputFileContents, "_COUNTKIND", "{{.CountKind}}")
    25  
    26  	accumulateSum := makeFunctionRegex("_ACCUMULATE_COUNT", 5)
    27  	s = accumulateSum.ReplaceAllString(s, `{{template "accumulateCount" buildDict "Global" . "ColWithNulls" $4 "HasSel" $5}}`)
    28  
    29  	removeRow := makeFunctionRegex("_REMOVE_ROW", 4)
    30  	s = removeRow.ReplaceAllString(s, `{{template "removeRow" buildDict "Global" . "ColWithNulls" $4}}`)
    31  
    32  	s = replaceManipulationFuncs(s)
    33  
    34  	tmpl, err := template.New("count_agg").Funcs(template.FuncMap{"buildDict": buildDict}).Parse(s)
    35  	if err != nil {
    36  		return err
    37  	}
    38  
    39  	return tmpl.Execute(wr, []struct {
    40  		aggTmplInfoBase
    41  		CountKind string
    42  	}{
    43  		// "Rows" count aggregate performs COUNT(*) aggregation, which counts
    44  		// every row in the result unconditionally.
    45  		{
    46  			aggTmplInfoBase: aggTmplInfoBase{canonicalTypeFamily: types.IntFamily},
    47  			CountKind:       "Rows",
    48  		},
    49  		// "" ("pure") count aggregate performs COUNT(col) aggregation, which
    50  		// counts every row in the result where the value of col is not null.
    51  		{
    52  			aggTmplInfoBase: aggTmplInfoBase{canonicalTypeFamily: types.IntFamily},
    53  			CountKind:       "",
    54  		},
    55  	})
    56  }
    57  
    58  func init() {
    59  	registerAggGenerator(genCountAgg, "count_agg.eg.go", countAggTmpl, true /* genWindowVariant */)
    60  }