github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/aggregation/group_concat.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package aggregation 16 17 import ( 18 "fmt" 19 "sort" 20 "strings" 21 22 "github.com/dolthub/vitess/go/vt/proto/query" 23 24 "github.com/dolthub/go-mysql-server/sql" 25 "github.com/dolthub/go-mysql-server/sql/expression" 26 "github.com/dolthub/go-mysql-server/sql/types" 27 ) 28 29 type GroupConcat struct { 30 distinct string 31 sf sql.SortFields 32 separator string 33 selectExprs []sql.Expression 34 maxLen int 35 returnType sql.Type 36 window *sql.WindowDefinition 37 id sql.ColumnId 38 } 39 40 var _ sql.FunctionExpression = &GroupConcat{} 41 var _ sql.Aggregation = &GroupConcat{} 42 var _ sql.WindowAdaptableExpression = (*GroupConcat)(nil) 43 44 func NewEmptyGroupConcat() sql.Expression { 45 return &GroupConcat{} 46 } 47 48 // FunctionName implements sql.FunctionExpression 49 func (g *GroupConcat) FunctionName() string { 50 return "group_concat" 51 } 52 53 // Description implements sql.FunctionExpression 54 func (g *GroupConcat) Description() string { 55 return "returns a string result with the concatenated non-NULL values from a group." 56 } 57 58 func NewGroupConcat(distinct string, orderBy sql.SortFields, separator string, selectExprs []sql.Expression, maxLen int) *GroupConcat { 59 return &GroupConcat{distinct: distinct, sf: orderBy, separator: separator, selectExprs: selectExprs, maxLen: maxLen} 60 } 61 62 // Id implements the Aggregation interface 63 func (a *GroupConcat) Id() sql.ColumnId { 64 return a.id 65 } 66 67 // WithId implements the Aggregation interface 68 func (a *GroupConcat) WithId(id sql.ColumnId) sql.IdExpression { 69 ret := *a 70 ret.id = id 71 return &ret 72 } 73 74 // WithWindow implements sql.Aggregation 75 func (g *GroupConcat) WithWindow(window *sql.WindowDefinition) sql.WindowAdaptableExpression { 76 ng := *g 77 ng.window = window 78 return &ng 79 } 80 81 // Window implements sql.Aggregation 82 func (g *GroupConcat) Window() *sql.WindowDefinition { 83 return g.window 84 } 85 86 // NewBuffer creates a new buffer for the aggregation. 87 func (g *GroupConcat) NewBuffer() (sql.AggregationBuffer, error) { 88 var rows []sql.Row 89 distinctSet := make(map[string]bool) 90 return &groupConcatBuffer{g, rows, distinctSet}, nil 91 } 92 93 // NewWindowFunctionAggregation implements sql.WindowAdaptableExpression 94 func (g *GroupConcat) NewWindowFunction() (sql.WindowFunction, error) { 95 return NewGroupConcatAgg(g), nil 96 } 97 98 // Eval implements the Expression interface. 99 func (g *GroupConcat) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { 100 return nil, ErrEvalUnsupportedOnAggregation.New("GroupConcat") 101 } 102 103 // Resolved implements the Expression interface. 104 func (g *GroupConcat) Resolved() bool { 105 for _, se := range g.selectExprs { 106 if !se.Resolved() { 107 return false 108 } 109 } 110 111 sfs := g.sf.ToExpressions() 112 113 for _, sf := range sfs { 114 if !sf.Resolved() { 115 return false 116 } 117 } 118 119 return true 120 } 121 122 func (g *GroupConcat) String() string { 123 sb := strings.Builder{} 124 sb.WriteString("group_concat(") 125 if g.distinct != "" { 126 sb.WriteString(fmt.Sprintf("distinct %s", g.distinct)) 127 } 128 129 if g.selectExprs != nil { 130 var exprs = make([]string, len(g.selectExprs)) 131 for i, expr := range g.selectExprs { 132 exprs[i] = expr.String() 133 } 134 135 sb.WriteString(strings.Join(exprs, ", ")) 136 } 137 138 if len(g.sf) > 0 { 139 sb.WriteString(" order by ") 140 for i, ob := range g.sf { 141 if i > 0 { 142 sb.WriteString(", ") 143 } 144 sb.WriteString(ob.String()) 145 } 146 } 147 148 sb.WriteString(" separator ") 149 sb.WriteString(fmt.Sprintf("'%s'", g.separator)) 150 151 sb.WriteString(")") 152 153 return sb.String() 154 } 155 156 // Type implements the Expression interface. 157 // cc: https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html#function_group-concat for explanations 158 // on return type. 159 func (g *GroupConcat) Type() sql.Type { 160 if g.returnType == types.Blob { 161 if g.maxLen <= 512 { 162 return types.MustCreateString(query.Type_VARBINARY, 512, sql.Collation_binary) 163 } else { 164 return types.Blob 165 } 166 } else { 167 if g.maxLen <= 512 { 168 return types.MustCreateString(query.Type_VARCHAR, 512, sql.Collation_Default) 169 } else { 170 return types.Text 171 } 172 } 173 } 174 175 // IsNullable implements the Expression interface. 176 func (g *GroupConcat) IsNullable() bool { 177 return false 178 } 179 180 // Children implements the Expression interface. 181 func (g *GroupConcat) Children() []sql.Expression { 182 return append(g.sf.ToExpressions(), g.selectExprs...) 183 } 184 185 // WithChildren implements the Expression interface. 186 func (g *GroupConcat) WithChildren(children ...sql.Expression) (sql.Expression, error) { 187 if len(children) == 0 { 188 return nil, sql.ErrInvalidChildrenNumber.New(GroupConcat{}, len(children), 2) 189 } 190 191 // Get the order by expression using the length of the sort fields. 192 sortFieldMarker := len(g.sf) 193 orderByExpr := children[:len(g.sf)] 194 195 return NewGroupConcat(g.distinct, g.sf.FromExpressions(orderByExpr...), g.separator, children[sortFieldMarker:], g.maxLen), nil 196 } 197 198 type groupConcatBuffer struct { 199 gc *GroupConcat 200 rows []sql.Row 201 distinctSet map[string]bool 202 } 203 204 // Update implements the AggregationBuffer interface. 205 func (g *groupConcatBuffer) Update(ctx *sql.Context, originalRow sql.Row) error { 206 evalRow, retType, err := evalExprs(ctx, g.gc.selectExprs, originalRow) 207 if err != nil { 208 return err 209 } 210 211 g.gc.returnType = retType 212 213 // Skip if this is a null row 214 if evalRow == nil { 215 return nil 216 } 217 218 var v interface{} 219 var vs string 220 if types.IsBlobType(retType) { 221 v, _, err = types.Blob.Convert(evalRow[0]) 222 if err != nil { 223 return err 224 } 225 vs = string(v.([]byte)) 226 if len(vs) == 0 { 227 return nil 228 } 229 } else { 230 v, _, err = types.LongText.Convert(evalRow[0]) 231 if err != nil { 232 return err 233 } 234 if v == nil { 235 return nil 236 } 237 vs = v.(string) 238 } 239 240 // Get the current array of rows and the map 241 // Check if distinct is active if so look at and update our map 242 if g.gc.distinct != "" { 243 // If this value exists go ahead and return nil 244 if _, ok := g.distinctSet[vs]; ok { 245 return nil 246 } else { 247 g.distinctSet[vs] = true 248 } 249 } 250 251 // Append the current value to the end of the row. We want to preserve the row's original structure for 252 // for sort ordering in the final step. 253 g.rows = append(g.rows, append(originalRow, nil, vs)) 254 255 return nil 256 } 257 258 // Eval implements the AggregationBuffer interface. 259 // cc: https://dev.mysql.com/doc/refman/8.0/en/aggregate-functions.html#function_group-concat 260 func (g *groupConcatBuffer) Eval(ctx *sql.Context) (interface{}, error) { 261 rows := g.rows 262 263 if len(rows) == 0 { 264 return nil, nil 265 } 266 267 // Execute the order operation if it exists. 268 if g.gc.sf != nil { 269 sorter := &expression.Sorter{ 270 SortFields: g.gc.sf, 271 Rows: rows, 272 Ctx: ctx, 273 } 274 275 sort.Stable(sorter) 276 if sorter.LastError != nil { 277 return nil, sorter.LastError 278 } 279 } 280 281 sb := strings.Builder{} 282 for i, row := range rows { 283 lastIdx := len(row) - 1 284 if i == 0 { 285 sb.WriteString(row[lastIdx].(string)) 286 } else { 287 sb.WriteString(g.gc.separator) 288 sb.WriteString(row[lastIdx].(string)) 289 } 290 291 // Don't allow the string to cross maxlen 292 if sb.Len() >= g.gc.maxLen { 293 break 294 } 295 } 296 297 ret := sb.String() 298 299 // There might be a couple of character differences even if we broke early in the loop 300 if len(ret) > g.gc.maxLen { 301 ret = ret[:g.gc.maxLen] 302 } 303 304 // Add this to handle any one off errors. 305 return ret, nil 306 } 307 308 // Dispose implements the Disposable interface. 309 func (g *groupConcatBuffer) Dispose() { 310 } 311 312 func evalExprs(ctx *sql.Context, exprs []sql.Expression, row sql.Row) (sql.Row, sql.Type, error) { 313 result := make(sql.Row, len(exprs)) 314 retType := types.Blob 315 for i, expr := range exprs { 316 var err error 317 result[i], err = expr.Eval(ctx, row) 318 if err != nil { 319 return nil, nil, err 320 } 321 322 // If every expression returns Blob type return Blob otherwise return Text. 323 if expr.Type() != types.Blob { 324 retType = types.Text 325 } 326 } 327 328 return result, retType, nil 329 }