github.com/whtcorpsinc/MilevaDB-Prod@v0.0.0-20211104133533-f57f4be3b597/dbs/memristed/memex/generator/compare_vec.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  // +build ignore
    15  
    16  package main
    17  
    18  import (
    19  	"bytes"
    20  	"flag"
    21  	"go/format"
    22  	"io/ioutil"
    23  	"log"
    24  	"path/filepath"
    25  	"text/template"
    26  
    27  	. "github.com/whtcorpsinc/milevadb/memex/generator/helper"
    28  )
    29  
    30  const header = `// Copyright 2021 WHTCORPS INC, Inc.
    31  //
    32  // Licensed under the Apache License, Version 2.0 (the "License");
    33  // you may not use this file except in compliance with the License.
    34  // You may obtain a copy of the License at
    35  //
    36  //     http://www.apache.org/licenses/LICENSE-2.0
    37  //
    38  // Unless required by applicable law or agreed to in writing, software
    39  // distributed under the License is distributed on an "AS IS" BASIS,
    40  // See the License for the specific language governing permissions and
    41  // limitations under the License.
    42  
    43  // Code generated by go generate in memex/generator; DO NOT EDIT.
    44  
    45  package memex
    46  `
    47  
    48  const newLine = "\n"
    49  
    50  const builtinCompareImports = `import (
    51  	"github.com/whtcorpsinc/milevadb/types"
    52  	"github.com/whtcorpsinc/milevadb/types/json"
    53  	"github.com/whtcorpsinc/milevadb/soliton/chunk"
    54  )
    55  `
    56  
    57  var builtinCompareVecTpl = template.Must(template.New("").Parse(`
    58  func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vecEvalInt(input *chunk.Chunk, result *chunk.DeferredCauset) error {
    59  	n := input.NumEvents()
    60  	buf0, err := b.bufSlabPredictor.get(types.ET{{ .type.ETName }}, n)
    61  	if err != nil {
    62  		return err
    63  	}
    64  	defer b.bufSlabPredictor.put(buf0)
    65  	if err := b.args[0].VecEval{{ .type.TypeName }}(b.ctx, input, buf0); err != nil {
    66  		return err
    67  	}
    68  	buf1, err := b.bufSlabPredictor.get(types.ET{{ .type.ETName }}, n)
    69  	if err != nil {
    70  		return err
    71  	}
    72  	defer b.bufSlabPredictor.put(buf1)
    73  	if err := b.args[1].VecEval{{ .type.TypeName }}(b.ctx, input, buf1); err != nil {
    74  		return err
    75  	}
    76  
    77  {{ if .type.Fixed }}
    78  	arg0 := buf0.{{ .type.TypeNameInDeferredCauset }}s()
    79  	arg1 := buf1.{{ .type.TypeNameInDeferredCauset }}s()
    80  {{- end }}
    81  	result.ResizeInt64(n, false)
    82  	result.MergeNulls(buf0, buf1)
    83  	i64s := result.Int64s()
    84  	for i := 0; i < n; i++ {
    85  		if result.IsNull(i) {
    86  			continue
    87  		}
    88  {{- if eq .type.ETName "Json" }}
    89  		val := json.CompareBinary(buf0.GetJSON(i), buf1.GetJSON(i))
    90  {{- else if eq .type.ETName "Real" }}
    91  		val := types.CompareFloat64(arg0[i], arg1[i])
    92  {{- else if eq .type.ETName "String" }}
    93  		val := types.CompareString(buf0.GetString(i), buf1.GetString(i), b.defCauslation)
    94  {{- else if eq .type.ETName "Duration" }}
    95  		val := types.CompareDuration(arg0[i], arg1[i])
    96  {{- else if eq .type.ETName "Datetime" }}
    97  		val := arg0[i].Compare(arg1[i])
    98  {{- else if eq .type.ETName "Decimal" }}
    99  		val := arg0[i].Compare(&arg1[i])
   100  {{- end }}
   101  		if val {{ .compare.Operator }} 0 {
   102  			i64s[i] = 1
   103  		} else {
   104  			i64s[i] = 0
   105  		}
   106  	}
   107  	return nil
   108  }
   109  
   110  func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vectorized() bool {
   111  	return true
   112  }
   113  `))
   114  
   115  var builtinNullEQCompareVecTpl = template.Must(template.New("").Parse(`
   116  func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vecEvalInt(input *chunk.Chunk, result *chunk.DeferredCauset) error {
   117  	n := input.NumEvents()
   118  	buf0, err := b.bufSlabPredictor.get(types.ET{{ .type.ETName }}, n)
   119  	if err != nil {
   120  		return err
   121  	}
   122  	defer b.bufSlabPredictor.put(buf0)
   123  	if err := b.args[0].VecEval{{ .type.TypeName }}(b.ctx, input, buf0); err != nil {
   124  		return err
   125  	}
   126  	buf1, err := b.bufSlabPredictor.get(types.ET{{ .type.ETName }}, n)
   127  	if err != nil {
   128  		return err
   129  	}
   130  	defer b.bufSlabPredictor.put(buf1)
   131  	if err := b.args[1].VecEval{{ .type.TypeName }}(b.ctx, input, buf1); err != nil {
   132  		return err
   133  	}
   134  
   135  {{ if .type.Fixed }}
   136  	arg0 := buf0.{{ .type.TypeNameInDeferredCauset }}s()
   137  	arg1 := buf1.{{ .type.TypeNameInDeferredCauset }}s()
   138  {{- end }}
   139  	result.ResizeInt64(n, false)
   140  	i64s := result.Int64s()
   141  	for i := 0; i < n; i++ {
   142  		isNull0 := buf0.IsNull(i)
   143  		isNull1 := buf1.IsNull(i)
   144  		switch {
   145  		case isNull0 && isNull1:
   146  			i64s[i] = 1
   147  		case isNull0 != isNull1:
   148  			i64s[i] = 0
   149  {{- if eq .type.ETName "Json" }}
   150  		case json.CompareBinary(buf0.GetJSON(i), buf1.GetJSON(i)) == 0:
   151  {{- else if eq .type.ETName "Real" }}
   152  		case types.CompareFloat64(arg0[i], arg1[i]) == 0:
   153  {{- else if eq .type.ETName "String" }}
   154  		case types.CompareString(buf0.GetString(i), buf1.GetString(i), b.defCauslation) == 0:
   155  {{- else if eq .type.ETName "Duration" }}
   156  		case types.CompareDuration(arg0[i], arg1[i]) == 0:
   157  {{- else if eq .type.ETName "Datetime" }}
   158  		case arg0[i].Compare(arg1[i]) == 0:
   159  {{- else if eq .type.ETName "Decimal" }}
   160  		case arg0[i].Compare(&arg1[i]) == 0:
   161  {{- end }}
   162  			i64s[i] = 1
   163  		}
   164  	}
   165  	return nil
   166  }
   167  
   168  func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vectorized() bool {
   169  	return true
   170  }
   171  `))
   172  
   173  var builtinCoalesceCompareVecTpl = template.Must(template.New("").Parse(`
   174  // NOTE: Coalesce just return the first non-null item, but vectorization do each item, which would incur additional errors. If this case happen, 
   175  // the vectorization falls back to the scalar execution.
   176  func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) fallbackEval{{ .type.TypeName }}(input *chunk.Chunk, result *chunk.DeferredCauset) error {
   177  	n := input.NumEvents()
   178  	{{ if .type.Fixed }}
   179  	x := result.{{ .type.TypeNameInDeferredCauset }}s()
   180  	for i := 0; i < n; i++ {
   181  		res, isNull, err := b.eval{{ .type.TypeName }}(input.GetEvent(i))
   182  		if err != nil {
   183  			return err
   184  		}
   185  		result.SetNull(i, isNull)
   186  		if isNull {
   187  			continue
   188  		}
   189  		{{ if eq .type.TypeName "Decimal" }}
   190  			x[i] = *res
   191  		{{ else if eq .type.TypeName "Duration" }}
   192  			x[i] = res.Duration
   193  		{{ else }}
   194  			x[i] = res
   195  		{{ end }}
   196  	}
   197  	{{ else }}
   198  	result.Reserve{{ .type.TypeNameInDeferredCauset }}(n)
   199  	for i := 0; i < n; i++ {
   200  		res, isNull, err := b.eval{{ .type.TypeName }}(input.GetEvent(i))
   201  		if err != nil {
   202  			return err
   203  		}
   204  		if isNull {
   205  			result.AppendNull()
   206  			continue
   207  		}
   208  		result.Append{{ .type.TypeNameInDeferredCauset }}(res)
   209  	}
   210  	{{ end -}}
   211  	return nil
   212  }
   213  
   214  {{ if .type.Fixed }}
   215  func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vecEval{{ .type.TypeName }}(input *chunk.Chunk, result *chunk.DeferredCauset) error {
   216  	n := input.NumEvents()
   217  	result.Resize{{ .type.TypeNameInDeferredCauset }}(n, true)
   218  	i64s := result.{{ .type.TypeNameInDeferredCauset }}s()
   219  	buf1, err := b.bufSlabPredictor.get(types.ET{{ .type.ETName }}, n)
   220  	if err != nil {
   221  		return err
   222  	}
   223  	defer b.bufSlabPredictor.put(buf1)
   224  	sc := b.ctx.GetStochastikVars().StmtCtx
   225  	beforeWarns := sc.WarningCount()
   226  	for j := 0; j < len(b.args); j++{
   227  		err := b.args[j].VecEval{{ .type.TypeName }}(b.ctx, input, buf1)
   228  		afterWarns := sc.WarningCount()
   229  		if err != nil || afterWarns > beforeWarns {
   230  			if afterWarns > beforeWarns {
   231  				sc.TruncateWarnings(int(beforeWarns))
   232  			}
   233  			return b.fallbackEval{{ .type.TypeName }}(input, result)
   234  		}
   235  		args := buf1.{{ .type.TypeNameInDeferredCauset }}s()
   236  		for i := 0; i < n; i++ {
   237  			if !buf1.IsNull(i) && result.IsNull(i) {
   238  				i64s[i] = args[i]
   239  				result.SetNull(i, false)
   240  			}
   241  		}
   242  	}
   243  	return nil
   244  }
   245  {{ else }}
   246  func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vecEval{{ .type.TypeName }}(input *chunk.Chunk, result *chunk.DeferredCauset) error {
   247  	n := input.NumEvents()
   248  	argLen := len(b.args)
   249  
   250  	bufs := make([]*chunk.DeferredCauset, argLen)
   251  	sc := b.ctx.GetStochastikVars().StmtCtx
   252  	beforeWarns := sc.WarningCount()
   253  	for i := 0; i < argLen; i++ {
   254  		buf, err := b.bufSlabPredictor.get(types.ETInt, n)
   255  		if err != nil {
   256  			return err
   257  		}
   258  		defer b.bufSlabPredictor.put(buf)
   259  		err = b.args[i].VecEval{{ .type.TypeName }}(b.ctx, input, buf)
   260  		afterWarns := sc.WarningCount()
   261  		if err != nil || afterWarns > beforeWarns {
   262  			if afterWarns > beforeWarns {
   263  				sc.TruncateWarnings(int(beforeWarns))
   264  			}
   265  			return b.fallbackEval{{ .type.TypeName }}(input, result)
   266  		}
   267  		bufs[i]=buf
   268  	}
   269  	result.Reserve{{ .type.TypeName }}(n)
   270  
   271  	for i := 0; i < n; i++ {
   272  		for j := 0; j < argLen; j++ {
   273  			if !bufs[j].IsNull(i) {
   274  				result.Append{{ .type.TypeName }}(bufs[j].Get{{ .type.TypeName }}(i))
   275  				break
   276  			}
   277  			if j == argLen-1 && bufs[j].IsNull(i) {
   278  				result.AppendNull()
   279  			}
   280  		}
   281  	}
   282  	return nil
   283  }
   284  
   285  
   286  {{ end }}
   287  
   288  func (b *builtin{{ .compare.CompareName }}{{ .type.TypeName }}Sig) vectorized() bool {
   289  	return true
   290  }
   291  
   292  `))
   293  
   294  const builtinCompareVecTestHeader = `import (
   295  	"testing"
   296  
   297  	. "github.com/whtcorpsinc/check"
   298  	"github.com/whtcorpsinc/BerolinaSQL/ast"
   299  	"github.com/whtcorpsinc/milevadb/types"
   300  )
   301  
   302  var vecGeneratedBuiltinCompareCases = map[string][]vecExprBenchCase{
   303  `
   304  
   305  var builtinCompareVecTestFuncHeader = template.Must(template.New("").Parse(`	ast.{{ .CompareName }}: {
   306  `))
   307  
   308  var builtinCompareVecTestCase = template.Must(template.New("").Parse(`		{retEvalType: types.ETInt, childrenTypes: []types.EvalType{types.ET{{ .ETName }}, types.ET{{ .ETName }}}},
   309  `))
   310  
   311  var builtinCompareVecTestFuncTail = `	},
   312  `
   313  
   314  var builtinCompareVecTestTail = `}
   315  
   316  func (s *testEvaluatorSuite) TestVectorizedGeneratedBuiltinCompareEvalOneVec(c *C) {
   317  	testVectorizedEvalOneVec(c, vecGeneratedBuiltinCompareCases)
   318  }
   319  
   320  func (s *testEvaluatorSuite) TestVectorizedGeneratedBuiltinCompareFunc(c *C) {
   321  	testVectorizedBuiltinFunc(c, vecGeneratedBuiltinCompareCases)
   322  }
   323  
   324  func BenchmarkVectorizedGeneratedBuiltinCompareEvalOneVec(b *testing.B) {
   325  	benchmarkVectorizedEvalOneVec(b, vecGeneratedBuiltinCompareCases)
   326  }
   327  
   328  func BenchmarkVectorizedGeneratedBuiltinCompareFunc(b *testing.B) {
   329  	benchmarkVectorizedBuiltinFunc(b, vecGeneratedBuiltinCompareCases)
   330  }
   331  `
   332  
   333  var builtinCoalesceCompareVecTestFunc = template.Must(template.New("").Parse(`
   334  {
   335  	retEvalType: types.ET{{ .ETName }},
   336  	childrenTypes: []types.EvalType{types.ET{{ .ETName }}, types.ET{{ .ETName }}, types.ET{{ .ETName }}},
   337  	geners: []dataGenerator{
   338  		gener{*newDefaultGener(0.2, types.ET{{ .ETName }})},
   339  		gener{*newDefaultGener(0.2, types.ET{{ .ETName }})},
   340  		gener{*newDefaultGener(0.2, types.ET{{ .ETName }})},
   341  	},
   342  },
   343  
   344  
   345  `))
   346  
   347  type CompareContext struct {
   348  	// Describe the name of CompareContext(LT/LE/GT/GE/EQ/NE/NullEQ)
   349  	CompareName string
   350  	// Compare Operators
   351  	Operator string
   352  }
   353  
   354  var comparesMap = []CompareContext{
   355  	{CompareName: "LT", Operator: "<"},
   356  	{CompareName: "LE", Operator: "<="},
   357  	{CompareName: "GT", Operator: ">"},
   358  	{CompareName: "GE", Operator: ">="},
   359  	{CompareName: "EQ", Operator: "=="},
   360  	{CompareName: "NE", Operator: "!="},
   361  	{CompareName: "NullEQ"},
   362  	{CompareName: "Coalesce"},
   363  }
   364  
   365  var typesMap = []TypeContext{
   366  	TypeInt,
   367  	TypeReal,
   368  	TypeDecimal,
   369  	TypeString,
   370  	TypeDatetime,
   371  	TypeDuration,
   372  	TypeJSON,
   373  }
   374  
   375  func generateDotGo(fileName string, compares []CompareContext, types []TypeContext) (err error) {
   376  	w := new(bytes.Buffer)
   377  	w.WriteString(header)
   378  	w.WriteString(newLine)
   379  	w.WriteString(builtinCompareImports)
   380  
   381  	var ctx = make(map[string]interface{})
   382  	for _, compareCtx := range compares {
   383  		for _, typeCtx := range types {
   384  			ctx["compare"] = compareCtx
   385  			ctx["type"] = typeCtx
   386  			if compareCtx.CompareName == "NullEQ" {
   387  				if typeCtx.TypeName == TypeInt.TypeName {
   388  					continue
   389  				}
   390  				err := builtinNullEQCompareVecTpl.InterDircute(w, ctx)
   391  				if err != nil {
   392  					return err
   393  				}
   394  			} else if compareCtx.CompareName == "Coalesce" {
   395  
   396  				err := builtinCoalesceCompareVecTpl.InterDircute(w, ctx)
   397  				if err != nil {
   398  					return err
   399  				}
   400  
   401  			} else {
   402  				if typeCtx.TypeName == TypeInt.TypeName {
   403  					continue
   404  				}
   405  				err := builtinCompareVecTpl.InterDircute(w, ctx)
   406  				if err != nil {
   407  					return err
   408  				}
   409  			}
   410  		}
   411  	}
   412  	data, err := format.Source(w.Bytes())
   413  	if err != nil {
   414  		log.Println("[Warn]", fileName+": gofmt failed", err)
   415  		data = w.Bytes() // write original data for debugging
   416  	}
   417  	return ioutil.WriteFile(fileName, data, 0644)
   418  }
   419  
   420  func generateTestDotGo(fileName string, compares []CompareContext, types []TypeContext) error {
   421  	w := new(bytes.Buffer)
   422  	w.WriteString(header)
   423  	w.WriteString(builtinCompareVecTestHeader)
   424  
   425  	for _, compareCtx := range compares {
   426  		if compareCtx.CompareName == "Coalesce" {
   427  			err := builtinCompareVecTestFuncHeader.InterDircute(w, CompareContext{CompareName: "Coalesce"})
   428  			if err != nil {
   429  				return err
   430  			}
   431  			for _, typeCtx := range types {
   432  				err := builtinCoalesceCompareVecTestFunc.InterDircute(w, typeCtx)
   433  				if err != nil {
   434  					return err
   435  				}
   436  			}
   437  			w.WriteString(builtinCompareVecTestFuncTail)
   438  			continue
   439  		}
   440  		err := builtinCompareVecTestFuncHeader.InterDircute(w, compareCtx)
   441  		if err != nil {
   442  			return err
   443  		}
   444  		for _, typeCtx := range types {
   445  			if typeCtx.TypeName == TypeInt.TypeName {
   446  				continue
   447  			}
   448  			err := builtinCompareVecTestCase.InterDircute(w, typeCtx)
   449  			if err != nil {
   450  				return err
   451  			}
   452  		}
   453  		w.WriteString(builtinCompareVecTestFuncTail)
   454  	}
   455  	w.WriteString(builtinCompareVecTestTail)
   456  
   457  	data, err := format.Source(w.Bytes())
   458  	if err != nil {
   459  		log.Println("[Warn]", fileName+": gofmt failed", err)
   460  		data = w.Bytes() // write original data for debugging
   461  	}
   462  	return ioutil.WriteFile(fileName, data, 0644)
   463  }
   464  
   465  // generateOneFile generate one xxx.go file and the associated xxx_test.go file.
   466  func generateOneFile(fileNamePrefix string, compares []CompareContext,
   467  	types []TypeContext) (err error) {
   468  
   469  	err = generateDotGo(fileNamePrefix+".go", compares, types)
   470  	if err != nil {
   471  		return
   472  	}
   473  	err = generateTestDotGo(fileNamePrefix+"_test.go", compares, types)
   474  	return
   475  }
   476  
   477  func main() {
   478  	flag.Parse()
   479  	var err error
   480  	outputDir := "."
   481  	err = generateOneFile(filepath.Join(outputDir, "builtin_compare_vec_generated"),
   482  		comparesMap, typesMap)
   483  	if err != nil {
   484  		log.Fatalln("generateOneFile", err)
   485  	}
   486  }