github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/aggregation/group_concat.go (about)

     1  // Copyright 2021 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package aggregation
    16  
import (
	"fmt"
	"sort"
	"strings"
	"unicode/utf8"

	"github.com/dolthub/vitess/go/vt/proto/query"

	"github.com/dolthub/go-mysql-server/sql"
	"github.com/dolthub/go-mysql-server/sql/expression"
	"github.com/dolthub/go-mysql-server/sql/types"
)
    28  
// GroupConcat is the GROUP_CONCAT aggregate function. For each row of a group it evaluates
// its select expression, optionally de-duplicates (DISTINCT) and orders (ORDER BY) the
// values, joins them with a separator, and caps the result at maxLen bytes.
type GroupConcat struct {
	// distinct is non-empty when DISTINCT was requested; the buffer then skips repeated
	// values, and String prints it after "distinct ".
	distinct    string
	// sf holds the ORDER BY sort fields applied to the collected rows before joining.
	sf          sql.SortFields
	// separator is written between consecutive values.
	separator   string
	// selectExprs are the expressions being concatenated; the buffer's Update only uses the
	// first evaluated value.
	selectExprs []sql.Expression
	// maxLen is the byte limit applied to the joined result; it also picks the result type
	// in Type (VARCHAR/VARBINARY(512) when <= 512, TEXT/BLOB otherwise).
	maxLen      int
	// returnType is written by groupConcatBuffer.Update: types.Blob when every select
	// expression has type types.Blob, types.Text otherwise. It is nil until Update runs.
	returnType  sql.Type
	// window is the window definition attached via WithWindow, if any.
	window      *sql.WindowDefinition
	// id is the column id attached via WithId.
	id          sql.ColumnId
}
    39  
    40  var _ sql.FunctionExpression = &GroupConcat{}
    41  var _ sql.Aggregation = &GroupConcat{}
    42  var _ sql.WindowAdaptableExpression = (*GroupConcat)(nil)
    43  
    44  func NewEmptyGroupConcat() sql.Expression {
    45  	return &GroupConcat{}
    46  }
    47  
    48  // FunctionName implements sql.FunctionExpression
    49  func (g *GroupConcat) FunctionName() string {
    50  	return "group_concat"
    51  }
    52  
    53  // Description implements sql.FunctionExpression
    54  func (g *GroupConcat) Description() string {
    55  	return "returns a string result with the concatenated non-NULL values from a group."
    56  }
    57  
    58  func NewGroupConcat(distinct string, orderBy sql.SortFields, separator string, selectExprs []sql.Expression, maxLen int) *GroupConcat {
    59  	return &GroupConcat{distinct: distinct, sf: orderBy, separator: separator, selectExprs: selectExprs, maxLen: maxLen}
    60  }
    61  
    62  // Id implements the Aggregation interface
    63  func (a *GroupConcat) Id() sql.ColumnId {
    64  	return a.id
    65  }
    66  
    67  // WithId implements the Aggregation interface
    68  func (a *GroupConcat) WithId(id sql.ColumnId) sql.IdExpression {
    69  	ret := *a
    70  	ret.id = id
    71  	return &ret
    72  }
    73  
    74  // WithWindow implements sql.Aggregation
    75  func (g *GroupConcat) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression {
    76  	ng := *g
    77  	ng.window = window
    78  	return &ng
    79  }
    80  
    81  // Window implements sql.Aggregation
    82  func (g *GroupConcat) Window() *sql.WindowDefinition {
    83  	return g.window
    84  }
    85  
    86  // NewBuffer creates a new buffer for the aggregation.
    87  func (g *GroupConcat) NewBuffer() (sql.AggregationBuffer, error) {
    88  	var rows []sql.Row
    89  	distinctSet := make(map[string]bool)
    90  	return &groupConcatBuffer{g, rows, distinctSet}, nil
    91  }
    92  
    93  // NewWindowFunctionAggregation implements sql.WindowAdaptableExpression
    94  func (g *GroupConcat) NewWindowFunction() (sql.WindowFunction, error) {
    95  	return NewGroupConcatAgg(g), nil
    96  }
    97  
    98  // Eval implements the Expression interface.
    99  func (g *GroupConcat) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
   100  	return nil, ErrEvalUnsupportedOnAggregation.New("GroupConcat")
   101  }
   102  
   103  // Resolved implements the Expression interface.
   104  func (g *GroupConcat) Resolved() bool {
   105  	for _, se := range g.selectExprs {
   106  		if !se.Resolved() {
   107  			return false
   108  		}
   109  	}
   110  
   111  	sfs := g.sf.ToExpressions()
   112  
   113  	for _, sf := range sfs {
   114  		if !sf.Resolved() {
   115  			return false
   116  		}
   117  	}
   118  
   119  	return true
   120  }
   121  
   122  func (g *GroupConcat) String() string {
   123  	sb := strings.Builder{}
   124  	sb.WriteString("group_concat(")
   125  	if g.distinct != "" {
   126  		sb.WriteString(fmt.Sprintf("distinct %s", g.distinct))
   127  	}
   128  
   129  	if g.selectExprs != nil {
   130  		var exprs = make([]string, len(g.selectExprs))
   131  		for i, expr := range g.selectExprs {
   132  			exprs[i] = expr.String()
   133  		}
   134  
   135  		sb.WriteString(strings.Join(exprs, ", "))
   136  	}
   137  
   138  	if len(g.sf) > 0 {
   139  		sb.WriteString(" order by ")
   140  		for i, ob := range g.sf {
   141  			if i > 0 {
   142  				sb.WriteString(", ")
   143  			}
   144  			sb.WriteString(ob.String())
   145  		}
   146  	}
   147  
   148  	sb.WriteString(" separator ")
   149  	sb.WriteString(fmt.Sprintf("'%s'", g.separator))
   150  
   151  	sb.WriteString(")")
   152  
   153  	return sb.String()
   154  }
   155  
   156  // Type implements the Expression interface.
   157  // cc: https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html#function_group-concat for explanations
   158  // on return type.
   159  func (g *GroupConcat) Type() sql.Type {
   160  	if g.returnType == types.Blob {
   161  		if g.maxLen <= 512 {
   162  			return types.MustCreateString(query.Type_VARBINARY, 512, sql.Collation_binary)
   163  		} else {
   164  			return types.Blob
   165  		}
   166  	} else {
   167  		if g.maxLen <= 512 {
   168  			return types.MustCreateString(query.Type_VARCHAR, 512, sql.Collation_Default)
   169  		} else {
   170  			return types.Text
   171  		}
   172  	}
   173  }
   174  
   175  // IsNullable implements the Expression interface.
   176  func (g *GroupConcat) IsNullable() bool {
   177  	return false
   178  }
   179  
   180  // Children implements the Expression interface.
   181  func (g *GroupConcat) Children() []sql.Expression {
   182  	return append(g.sf.ToExpressions(), g.selectExprs...)
   183  }
   184  
   185  // WithChildren implements the Expression interface.
   186  func (g *GroupConcat) WithChildren(children ...sql.Expression) (sql.Expression, error) {
   187  	if len(children) == 0 {
   188  		return nil, sql.ErrInvalidChildrenNumber.New(GroupConcat{}, len(children), 2)
   189  	}
   190  
   191  	// Get the order by expression using the length of the sort fields.
   192  	sortFieldMarker := len(g.sf)
   193  	orderByExpr := children[:len(g.sf)]
   194  
   195  	return NewGroupConcat(g.distinct, g.sf.FromExpressions(orderByExpr...), g.separator, children[sortFieldMarker:], g.maxLen), nil
   196  }
   197  
// groupConcatBuffer accumulates the values of a single group for a GroupConcat.
type groupConcatBuffer struct {
	// gc is the GroupConcat expression this buffer aggregates for.
	gc          *GroupConcat
	// rows holds the collected input rows; each stored row ends with the string value to
	// concatenate, which Eval reads from the last column after sorting.
	rows        []sql.Row
	// distinctSet records values already collected when DISTINCT is in effect.
	distinctSet map[string]bool
}
   203  
   204  // Update implements the AggregationBuffer interface.
   205  func (g *groupConcatBuffer) Update(ctx *sql.Context, originalRow sql.Row) error {
   206  	evalRow, retType, err := evalExprs(ctx, g.gc.selectExprs, originalRow)
   207  	if err != nil {
   208  		return err
   209  	}
   210  
   211  	g.gc.returnType = retType
   212  
   213  	// Skip if this is a null row
   214  	if evalRow == nil {
   215  		return nil
   216  	}
   217  
   218  	var v interface{}
   219  	var vs string
   220  	if types.IsBlobType(retType) {
   221  		v, _, err = types.Blob.Convert(evalRow[0])
   222  		if err != nil {
   223  			return err
   224  		}
   225  		vs = string(v.([]byte))
   226  		if len(vs) == 0 {
   227  			return nil
   228  		}
   229  	} else {
   230  		v, _, err = types.LongText.Convert(evalRow[0])
   231  		if err != nil {
   232  			return err
   233  		}
   234  		if v == nil {
   235  			return nil
   236  		}
   237  		vs = v.(string)
   238  	}
   239  
   240  	// Get the current array of rows and the map
   241  	// Check if distinct is active if so look at and update our map
   242  	if g.gc.distinct != "" {
   243  		// If this value exists go ahead and return nil
   244  		if _, ok := g.distinctSet[vs]; ok {
   245  			return nil
   246  		} else {
   247  			g.distinctSet[vs] = true
   248  		}
   249  	}
   250  
   251  	// Append the current value to the end of the row. We want to preserve the row's original structure for
   252  	// for sort ordering in the final step.
   253  	g.rows = append(g.rows, append(originalRow, nil, vs))
   254  
   255  	return nil
   256  }
   257  
   258  // Eval implements the AggregationBuffer interface.
   259  // cc: https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html#function_group-concat
   260  func (g *groupConcatBuffer) Eval(ctx *sql.Context) (interface{}, error) {
   261  	rows := g.rows
   262  
   263  	if len(rows) == 0 {
   264  		return nil, nil
   265  	}
   266  
   267  	// Execute the order operation if it exists.
   268  	if g.gc.sf != nil {
   269  		sorter := &expression.Sorter{
   270  			SortFields: g.gc.sf,
   271  			Rows:       rows,
   272  			Ctx:        ctx,
   273  		}
   274  
   275  		sort.Stable(sorter)
   276  		if sorter.LastError != nil {
   277  			return nil, sorter.LastError
   278  		}
   279  	}
   280  
   281  	sb := strings.Builder{}
   282  	for i, row := range rows {
   283  		lastIdx := len(row) - 1
   284  		if i == 0 {
   285  			sb.WriteString(row[lastIdx].(string))
   286  		} else {
   287  			sb.WriteString(g.gc.separator)
   288  			sb.WriteString(row[lastIdx].(string))
   289  		}
   290  
   291  		// Don't allow the string to cross maxlen
   292  		if sb.Len() >= g.gc.maxLen {
   293  			break
   294  		}
   295  	}
   296  
   297  	ret := sb.String()
   298  
   299  	// There might be a couple of character differences even if we broke early in the loop
   300  	if len(ret) > g.gc.maxLen {
   301  		ret = ret[:g.gc.maxLen]
   302  	}
   303  
   304  	// Add this to handle any one off errors.
   305  	return ret, nil
   306  }
   307  
// Dispose implements the Disposable interface. The buffer holds only a pointer, a slice and a
// map, all ordinary Go memory, so there is nothing to release.
func (g *groupConcatBuffer) Dispose() {
}
   311  
   312  func evalExprs(ctx *sql.Context, exprs []sql.Expression, row sql.Row) (sql.Row, sql.Type, error) {
   313  	result := make(sql.Row, len(exprs))
   314  	retType := types.Blob
   315  	for i, expr := range exprs {
   316  		var err error
   317  		result[i], err = expr.Eval(ctx, row)
   318  		if err != nil {
   319  			return nil, nil, err
   320  		}
   321  
   322  		// If every expression returns Blob type return Blob otherwise return Text.
   323  		if expr.Type() != types.Blob {
   324  			retType = types.Text
   325  		}
   326  	}
   327  
   328  	return result, retType, nil
   329  }