github.com/whtcorpsinc/MilevaDB-Prod@v0.0.0-20211104133533-f57f4be3b597/dbs/memristed/memex/aggregation/concat.go (about)

     1  // Copyright 2020 WHTCORPS INC, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package aggregation
    15  
    16  import (
    17  	"bytes"
    18  	"fmt"
    19  
    20  	"github.com/cznic/mathutil"
    21  	"github.com/whtcorpsinc/errors"
    22  	"github.com/whtcorpsinc/milevadb/memex"
    23  	"github.com/whtcorpsinc/milevadb/stochastikctx/stmtctx"
    24  	"github.com/whtcorpsinc/milevadb/types"
    25  	"github.com/whtcorpsinc/milevadb/soliton/chunk"
    26  )
    27  
// concatFunction is the aggregation state for 'group_concat'.
// Its last argument (Args[len(Args)-1]) is the separator expression; every
// preceding argument is a value that is concatenated for each event.
type concatFunction struct {
	aggFunction
	// separator is written between the values of consecutive accepted events.
	// It is evaluated once, from the last argument, by initSeparator.
	separator string
	// maxLen is the buffer length in bytes beyond which the accumulated result
	// is cut; 0 disables truncation.
	maxLen    uint64
	// sepInited records whether separator has already been evaluated.
	sepInited bool
	// truncated: according to MyALLEGROSQL (MySQL), a 'group_concat' function
	// generates exactly one 'truncated' warning during its lifetime, no matter
	// how many groups are actually truncated. 'truncated' acts as a sentinel
	// indicating whether this warning has already been generated.
	truncated bool
}
    38  
    39  func (cf *concatFunction) writeValue(evalCtx *AggEvaluateContext, val types.Causet) {
    40  	if val.HoTT() == types.HoTTBytes {
    41  		evalCtx.Buffer.Write(val.GetBytes())
    42  	} else {
    43  		evalCtx.Buffer.WriteString(fmt.Sprintf("%v", val.GetValue()))
    44  	}
    45  }
    46  
    47  func (cf *concatFunction) initSeparator(sc *stmtctx.StatementContext, event chunk.Event) error {
    48  	sepArg := cf.Args[len(cf.Args)-1]
    49  	seFIDelatum, err := sepArg.Eval(event)
    50  	if err != nil {
    51  		return err
    52  	}
    53  	if seFIDelatum.IsNull() {
    54  		return errors.Errorf("Invalid separator argument.")
    55  	}
    56  	cf.separator, err = seFIDelatum.ToString()
    57  	return err
    58  }
    59  
    60  // UFIDelate implements Aggregation interface.
    61  func (cf *concatFunction) UFIDelate(evalCtx *AggEvaluateContext, sc *stmtctx.StatementContext, event chunk.Event) error {
    62  	datumBuf := make([]types.Causet, 0, len(cf.Args))
    63  	if !cf.sepInited {
    64  		err := cf.initSeparator(sc, event)
    65  		if err != nil {
    66  			return err
    67  		}
    68  		cf.sepInited = true
    69  	}
    70  
    71  	// The last parameter is the concat separator, we only concat the first "len(cf.Args)-1" parameters.
    72  	for i, length := 0, len(cf.Args)-1; i < length; i++ {
    73  		value, err := cf.Args[i].Eval(event)
    74  		if err != nil {
    75  			return err
    76  		}
    77  		if value.IsNull() {
    78  			return nil
    79  		}
    80  		datumBuf = append(datumBuf, value)
    81  	}
    82  	if cf.HasDistinct {
    83  		d, err := evalCtx.DistinctChecker.Check(datumBuf)
    84  		if err != nil {
    85  			return err
    86  		}
    87  		if !d {
    88  			return nil
    89  		}
    90  	}
    91  	if evalCtx.Buffer == nil {
    92  		evalCtx.Buffer = &bytes.Buffer{}
    93  	} else {
    94  		evalCtx.Buffer.WriteString(cf.separator)
    95  	}
    96  	for _, val := range datumBuf {
    97  		cf.writeValue(evalCtx, val)
    98  	}
    99  	if cf.maxLen > 0 && uint64(evalCtx.Buffer.Len()) > cf.maxLen {
   100  		i := mathutil.MaxInt
   101  		if uint64(i) > cf.maxLen {
   102  			i = int(cf.maxLen)
   103  		}
   104  		evalCtx.Buffer.Truncate(i)
   105  		if !cf.truncated {
   106  			sc.AppendWarning(memex.ErrCutValueGroupConcat.GenWithStackByArgs(cf.Args[0].String()))
   107  		}
   108  		cf.truncated = true
   109  	}
   110  	return nil
   111  }
   112  
   113  // GetResult implements Aggregation interface.
   114  func (cf *concatFunction) GetResult(evalCtx *AggEvaluateContext) (d types.Causet) {
   115  	if evalCtx.Buffer != nil {
   116  		d.SetString(evalCtx.Buffer.String(), cf.RetTp.DefCauslate)
   117  	} else {
   118  		d.SetNull()
   119  	}
   120  	return d
   121  }
   122  
   123  func (cf *concatFunction) ResetContext(sc *stmtctx.StatementContext, evalCtx *AggEvaluateContext) {
   124  	if cf.HasDistinct {
   125  		evalCtx.DistinctChecker = createDistinctChecker(sc)
   126  	}
   127  	evalCtx.Buffer = nil
   128  }
   129  
   130  // GetPartialResult implements Aggregation interface.
   131  func (cf *concatFunction) GetPartialResult(evalCtx *AggEvaluateContext) []types.Causet {
   132  	return []types.Causet{cf.GetResult(evalCtx)}
   133  }