github.com/dolthub/go-mysql-server@v0.18.0/sql/expression/function/aggregation/group_concat_test.go (about) 1 // Copyright 2021 Dolthub, Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 package aggregation 16 17 import ( 18 "testing" 19 20 "github.com/dolthub/vitess/go/vt/proto/query" 21 "github.com/stretchr/testify/require" 22 23 "github.com/dolthub/go-mysql-server/sql" 24 "github.com/dolthub/go-mysql-server/sql/expression" 25 "github.com/dolthub/go-mysql-server/sql/types" 26 ) 27 28 func TestGroupConcat_FunctionName(t *testing.T) { 29 assert := require.New(t) 30 31 m := NewGroupConcat("field", nil, ",", nil, 1024) 32 33 assert.Equal("group_concat(distinct field separator ',')", m.String()) 34 35 m = NewGroupConcat("field", nil, "-", nil, 1024) 36 37 assert.Equal("group_concat(distinct field separator '-')", m.String()) 38 39 sf := sql.SortFields{ 40 {Column: expression.NewUnresolvedColumn("field"), Order: sql.Ascending}, 41 {Column: expression.NewUnresolvedColumn("field2"), Order: sql.Descending}, 42 } 43 44 m = NewGroupConcat("field", sf, "-", nil, 1024) 45 46 assert.Equal("group_concat(distinct field order by field ASC, field2 DESC separator '-')", m.String()) 47 } 48 49 // Validates that the return length of GROUP_CONCAT is bounded by group_concat_max_len (default 1024) 50 func TestGroupConcat_PastMaxLen(t *testing.T) { 51 var rows []sql.Row 52 ctx := sql.NewEmptyContext() 53 54 for i := 0; i < 2000; i++ { 55 rows = append(rows, sql.Row{int64(i)}) 56 } 57 58 maxLenInt, err := ctx.GetSessionVariable(ctx, "group_concat_max_len") 59 require.NoError(t, err) 60 maxLen := maxLenInt.(uint64) 61 62 gc := NewGroupConcat("", nil, ",", []sql.Expression{expression.NewGetField(0, types.Int64, "int", true)}, int(maxLen)) 63 64 buf, _ := gc.NewBuffer() 65 for _, row := range rows { 66 require.NoError(t, buf.Update(ctx, row)) 67 } 68 69 result, err := buf.Eval(ctx) 70 rs := result.(string) 71 72 require.NoError(t, err) 73 require.Equal(t, int(maxLen), len(rs)) 74 } 75 76 // Validate that group_concat returns the correct return type 77 func TestGroupConcat_ReturnType(t *testing.T) { 78 ctx := sql.NewEmptyContext() 79 80 testCases := []struct { 81 expression []sql.Expression 82 maxLen int 83 returnType sql.Type 84 row sql.Row 85 }{ 86 {[]sql.Expression{expression.NewGetField(0, types.LongText, "test", true)}, 200, types.MustCreateString(query.Type_VARCHAR, 512, sql.Collation_Default), sql.Row{int64(1)}}, 87 {[]sql.Expression{expression.NewGetField(0, types.Text, "text", true)}, 1020, types.Text, sql.Row{int64(1)}}, 88 {[]sql.Expression{expression.NewGetField(0, types.Blob, "myblob", true)}, 200, types.MustCreateString(query.Type_VARBINARY, 512, sql.Collation_binary), sql.Row{"hi"}}, 89 {[]sql.Expression{expression.NewGetField(0, types.Blob, "myblob", true)}, 1020, types.Blob, sql.Row{"hi"}}, 90 } 91 92 for _, tt := range testCases { 93 gc := NewGroupConcat("", nil, ",", tt.expression, tt.maxLen) 94 95 buf, _ := gc.NewBuffer() 96 97 err := buf.Update(ctx, tt.row) 98 require.NoError(t, err) 99 100 _, err = buf.Eval(ctx) 101 require.NoError(t, err) 102 103 require.Equal(t, tt.returnType, gc.Type()) 104 } 105 }