github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/colexec/multi_col/group_concat/group_concat.go (about)

     1  // Copyright 2022 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //	http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package group_concat
    15  
    16  import (
    17  	"bytes"
    18  	"fmt"
    19  	"strings"
    20  	"unsafe"
    21  
    22  	"github.com/matrixorigin/matrixone/pkg/common/hashmap"
    23  	"github.com/matrixorigin/matrixone/pkg/common/mpool"
    24  	"github.com/matrixorigin/matrixone/pkg/container/nulls"
    25  	"github.com/matrixorigin/matrixone/pkg/container/types"
    26  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    27  	"github.com/matrixorigin/matrixone/pkg/sql/colexec/agg"
    28  )
    29  
// GroupConcat is the aggregator behind group_concat. Example:
//
//	+------+------+------+
//	| a    | b    | c    |
//	+------+------+------+
//	|    1 |    2 |    3 |
//	|    4 |    5 |    6 |
//	+------+------+------+
//	select group_concat(a,b,c separator "|") from t;
//	res[0]     = "123|456"
//	inserts[0] = "encode(1,2,3)|encode(4,5,6)"
//
// We need inserts to store the encoded source rows (the keys given to the
// per-group hash maps), so that we can use them again when merging.
type GroupConcat struct {
	// arg carries the group_concat arguments read in this file: Dist,
	// GroupExpr, OrderByExpr and Separator.
	arg *Argument
	// res holds the rendered text, one entry per group.
	res []string
	// inserts holds, per group, the separator-joined encoded rows; Fill
	// writes each column as a 2-byte length prefix followed by its text.
	inserts []string
	// maps holds one hash map per group for duplicate detection; Grows only
	// allocates them when arg.Dist is set.
	maps []*hashmap.StrHashMap
	// in group_concat(distinct a,b ), the itype will be a and b's types
	ityp []types.Type
	// groups records the real group number: it is incremented whenever a
	// value is appended to a group whose result is still empty.
	groups int
}
    49  
// EncodeGroupConcat is the intermediate form that MarshalBinary passes to
// types.Encode and that UnmarshalBinary fills through types.Decode.
//
// NOTE(review): every field is unexported. If types.Encode/Decode are
// reflection based (for example encoding/gob) these fields would not be
// serialized — confirm against the types package.
type EncodeGroupConcat struct {
	// res_strData is GroupConcat.res packed with types.EncodeStringSlice.
	res_strData []byte
	// inserts_strData is GroupConcat.inserts packed with types.EncodeStringSlice.
	inserts_strData []byte
	// arg, ityp and groups are copied verbatim from the GroupConcat.
	arg    *Argument
	ityp   []types.Type
	groups int
}
    57  
    58  func NewGroupConcat(arg *Argument, typs []types.Type) agg.Agg[any] {
    59  	return &GroupConcat{
    60  		arg:  arg,
    61  		ityp: typs,
    62  	}
    63  }
    64  
    65  // We need to implements the interface of Agg
    66  func (gc *GroupConcat) MarshalBinary() (data []byte, err error) {
    67  	eg := &EncodeGroupConcat{
    68  		res_strData:     types.EncodeStringSlice(gc.res),
    69  		inserts_strData: types.EncodeStringSlice(gc.inserts),
    70  		arg:             gc.arg,
    71  		groups:          gc.groups,
    72  		ityp:            gc.ityp,
    73  	}
    74  	return types.Encode(eg)
    75  }
    76  
    77  // encoding.BinaryUnmarshaler
    78  func (gc *GroupConcat) UnmarshalBinary(data []byte) error {
    79  	eg := &EncodeGroupConcat{}
    80  	types.Decode(data, eg)
    81  	m := mpool.MustNewZero()
    82  	da1, err := m.Alloc(len(eg.inserts_strData))
    83  	if err != nil {
    84  		return err
    85  	}
    86  	copy(da1, eg.inserts_strData)
    87  	gc.inserts = types.DecodeStringSlice(da1)
    88  	da2, err := m.Alloc(len(eg.res_strData))
    89  	if err != nil {
    90  		return err
    91  	}
    92  	copy(da2, eg.res_strData)
    93  	gc.res = types.DecodeStringSlice(da2)
    94  	gc.arg = eg.arg
    95  	gc.groups = eg.groups
    96  	gc.ityp = eg.ityp
    97  	gc.maps = make([]*hashmap.StrHashMap, gc.groups)
    98  	for i := 0; i < gc.groups; i++ {
    99  		gc.maps[i], err = hashmap.NewStrMap(false, 0, 0, m)
   100  		if err != nil {
   101  			return err
   102  		}
   103  		for k := range gc.inserts {
   104  			gc.maps[i].InsertValue(gc.inserts[k])
   105  		}
   106  	}
   107  	return nil
   108  }
   109  
   110  // Dup will duplicate a new agg with the same type.
   111  func (gc *GroupConcat) Dup() agg.Agg[any] {
   112  	var newRes []string = make([]string, len(gc.res))
   113  	copy(newRes, gc.res)
   114  	var newItyp []types.Type = make([]types.Type, 0, len(gc.ityp))
   115  	copy(newItyp, gc.ityp)
   116  	var inserts []string = make([]string, 0, len(gc.inserts))
   117  	return &GroupConcat{
   118  		arg:     gc.arg,
   119  		res:     newRes,
   120  		inserts: inserts,
   121  		ityp:    newItyp,
   122  	}
   123  }
   124  
   125  // Type return the type of the agg's result.
   126  func (gc *GroupConcat) OutputType() types.Type {
   127  	typ := types.T_text.ToType()
   128  	// set to largest length
   129  	typ.Width = types.MaxVarcharLen
   130  	return typ
   131  }
   132  
// InputTypes returns the input column types the aggregator was created with.
// group_concat is not a normal agg func, we don't need this func.
func (gc *GroupConcat) InputTypes() []types.Type {
	return gc.ityp
}
   137  
   138  // String return related information of the agg.
   139  // used to show query plans.
   140  func (gc *GroupConcat) String() string {
   141  	buf := new(bytes.Buffer)
   142  	buf.WriteString("group_concat( ")
   143  	if gc.arg.Dist {
   144  		buf.WriteString("distinct ")
   145  	}
   146  	for i, expr := range gc.arg.GroupExpr {
   147  		if i > 0 {
   148  			buf.WriteString(", ")
   149  		}
   150  		buf.WriteString(fmt.Sprintf("%v", expr))
   151  	}
   152  	if len(gc.arg.OrderByExpr) > 0 {
   153  		buf.WriteString(" order by ")
   154  	}
   155  	for i, expr := range gc.arg.OrderByExpr {
   156  		if i > 0 {
   157  			buf.WriteString(", ")
   158  		}
   159  		buf.WriteString(fmt.Sprintf("%v", expr))
   160  	}
   161  	buf.WriteString(fmt.Sprintf(" separtor %v)", gc.arg.Separator))
   162  	return buf.String()
   163  }
   164  
   165  // Free the agg.
   166  func (gc *GroupConcat) Free(*mpool.MPool) {
   167  	for _, mp := range gc.maps {
   168  		mp.Free()
   169  	}
   170  	gc.maps = nil
   171  }
   172  
   173  // Grows allocates n groups for the agg.
   174  func (gc *GroupConcat) Grows(n int, m *mpool.MPool) error {
   175  	if len(gc.res) == 0 {
   176  		gc.res = make([]string, 0, n)
   177  		gc.inserts = make([]string, 0, n)
   178  		for i := 0; i < n; i++ {
   179  			gc.res = append(gc.res, "")
   180  			gc.inserts = append(gc.inserts, "")
   181  		}
   182  		if gc.arg.Dist {
   183  			gc.maps = make([]*hashmap.StrHashMap, 0, n)
   184  			for i := 0; i < n; i++ {
   185  				mp, err := hashmap.NewStrMap(false, 0, 0, m)
   186  				if err != nil {
   187  					return err
   188  				}
   189  				gc.maps = append(gc.maps, mp)
   190  			}
   191  		}
   192  	} else {
   193  		for i := 0; i < n; i++ {
   194  			gc.res = append(gc.res, "")
   195  			gc.inserts = append(gc.inserts, "")
   196  			if gc.arg.Dist {
   197  				mp, err := hashmap.NewStrMap(false, 0, 0, m)
   198  				if err != nil {
   199  					return err
   200  				}
   201  				gc.maps = append(gc.maps, mp)
   202  			}
   203  		}
   204  	}
   205  	return nil
   206  }
   207  
   208  // Eval method calculates and returns the final result of the aggregate function.
   209  func (gc *GroupConcat) Eval(m *mpool.MPool) (*vector.Vector, error) {
   210  	vec := vector.New(gc.OutputType())
   211  	nsp := nulls.NewWithSize(gc.groups)
   212  	vec.Nsp = nsp
   213  	for _, v := range gc.res {
   214  		if err := vec.Append([]byte(v), false, m); err != nil {
   215  			vec.Free(m)
   216  			return nil, err
   217  		}
   218  	}
   219  	return vec, nil
   220  }
   221  
   222  // Fill use the rowIndex-rows of vector to update the data of groupIndex-group.
   223  // rowCount indicates the number of times the rowIndex-row is repeated.
   224  // for group_concat(distinct a,b,c separator '|'); vecs is: a,b,c
   225  // remember that, we won't do evalExpr here, so the groupExpr is not used here
   226  func (gc *GroupConcat) Fill(groupIndex int64, rowIndex int64, rowCount int64, vecs []*vector.Vector) error {
   227  	if hasNull(vecs, rowIndex) {
   228  		return nil
   229  	}
   230  	length := len(gc.arg.GroupExpr)
   231  	var res_row string
   232  	var insert_row string
   233  	var flag bool
   234  	var err error
   235  	for i := 0; i < length; i++ {
   236  		s, _ := VectorToString(vecs[i], int(rowIndex))
   237  		res_row += s
   238  		// prefix length + data
   239  		length := uint16(len(s))
   240  		insert_row += string(unsafe.Slice((*byte)(unsafe.Pointer(&length)), 2)) + s
   241  	}
   242  	if gc.arg.Dist {
   243  		if flag, err = gc.maps[groupIndex].InsertValue(insert_row); err != nil {
   244  			return err
   245  		}
   246  		if flag {
   247  			if len(gc.res[groupIndex]) != 0 {
   248  				gc.res[groupIndex] += gc.arg.Separator
   249  				gc.inserts[groupIndex] += gc.arg.Separator
   250  			} else {
   251  				gc.groups++
   252  			}
   253  			gc.res[groupIndex] += res_row
   254  			gc.inserts[groupIndex] += insert_row
   255  		}
   256  	} else {
   257  		for k := 0; k < int(rowCount); k++ {
   258  			if len(gc.res[groupIndex]) != 0 {
   259  				gc.res[groupIndex] += gc.arg.Separator
   260  				gc.inserts[groupIndex] += gc.arg.Separator
   261  			} else {
   262  				gc.groups++
   263  			}
   264  			gc.res[groupIndex] += res_row
   265  			gc.inserts[groupIndex] += insert_row
   266  		}
   267  	}
   268  	return nil
   269  }
   270  
   271  // BulkFill use a whole vector to update the data of agg's group
   272  // groupIndex is the index number of the group
   273  // rowCounts is the count number of each row.
   274  func (gc *GroupConcat) BulkFill(groupIndex int64, rowCounts []int64, vecs []*vector.Vector) error {
   275  	length := vecs[0].Length()
   276  	for i := 0; i < length; i++ {
   277  		if err := gc.Fill(groupIndex, int64(i), rowCounts[i], vecs); err != nil {
   278  			return err
   279  		}
   280  	}
   281  	return nil
   282  }
   283  
   284  // BatchFill use part of the vector to update the data of agg's group
   285  //
   286  //	os(origin-s) records information about which groups need to be updated
   287  //	if length of os is N, we use first N of vps to do update work.
   288  //	And if os[i] > 0, it means the agg's (vps[i]-1)th group is a new one (never been assigned a value),
   289  //	Maybe this feature can help us to do some optimization work.
   290  //	So we use the os as a parameter but not len(os).
   291  //
   292  //	agg's (vps[i]-1)th group is related to vector's (offset+i)th row.
   293  //	rowCounts[i] is count number of the row[i]
   294  //
   295  // For a more detailed introduction of rowCounts, please refer to comments of Function Fill.
   296  func (gc *GroupConcat) BatchFill(offset int64, os []uint8, vps []uint64, rowCounts []int64, vecs []*vector.Vector) error {
   297  	for i := range os {
   298  		if vps[i] == 0 {
   299  			continue
   300  		}
   301  		if err := gc.Fill(int64(vps[i]-1), offset+int64(i), rowCounts[i+int(offset)], vecs); err != nil {
   302  			return err
   303  		}
   304  	}
   305  	return nil
   306  }
   307  
   308  // Merge will merge a couple of group between 2 aggregate function structures.
   309  // It merges the groupIndex1-group of agg1 and
   310  // groupIndex2-group of agg2
   311  func (gc *GroupConcat) Merge(agg2 agg.Agg[any], groupIndex1 int64, groupIndex2 int64) error {
   312  	gc2 := agg2.(*GroupConcat)
   313  	if gc.arg.Dist {
   314  		rows := strings.Split(gc2.inserts[groupIndex2], gc2.arg.Separator)
   315  		ress := strings.Split(gc2.res[groupIndex2], gc2.arg.Separator)
   316  		for i, row := range rows {
   317  			if len(row) == 0 {
   318  				continue
   319  			}
   320  			flag, err := gc.maps[groupIndex1].InsertValue(row)
   321  			if err != nil {
   322  				return err
   323  			}
   324  			if flag {
   325  				if len(gc.res[groupIndex1]) > 0 {
   326  					gc.res[groupIndex1] += gc.arg.Separator
   327  				} else {
   328  					gc.groups++
   329  				}
   330  				gc.res[groupIndex1] += ress[i]
   331  			}
   332  		}
   333  	} else {
   334  		if len(gc.res[groupIndex1]) > 0 {
   335  			gc.res[groupIndex1] += gc.arg.Separator
   336  		} else {
   337  			gc.groups++
   338  		}
   339  		gc.res[groupIndex1] += gc2.res[groupIndex2]
   340  	}
   341  	return nil
   342  }
   343  
   344  // BatchMerge merges multi groups of agg1 and agg2
   345  //
   346  //	agg1's (vps[i]-1)th group is related to agg2's (start+i)th group
   347  //
   348  // For more introduction of os, please refer to comments of Function BatchFill.
   349  func (gc *GroupConcat) BatchMerge(agg2 agg.Agg[any], start int64, os []uint8, vps []uint64) error {
   350  	gc2 := agg2.(*GroupConcat)
   351  	for i := range os {
   352  		if vps[i] == 0 {
   353  			continue
   354  		}
   355  		if err := gc.Merge(gc2, int64(vps[i]-1), int64(i+int(start))); err != nil {
   356  			return err
   357  		}
   358  	}
   359  	return nil
   360  }
   361  
// GetInputTypes gets the types of the aggregate's input arguments. It
// returns the same slice as InputTypes.
func (gc *GroupConcat) GetInputTypes() []types.Type {
	return gc.ityp
}
   366  
// GetOperatorId gets the aggregate id of this aggregate, which is always
// agg.AggregateGroupConcat.
// This is used to print log in group string().
func (gc *GroupConcat) GetOperatorId() int {
	return agg.AggregateGroupConcat
}
   372  
// IsDistinct reports whether the aggregate's argument has Dist set, i.e.
// whether duplicate rows are dropped when filling and merging.
func (gc *GroupConcat) IsDistinct() bool {
	return gc.arg.Dist
}
   376  
   377  // WildAggReAlloc reallocate for agg structure from memory pool.
   378  func (gc *GroupConcat) WildAggReAlloc(m *mpool.MPool) error {
   379  	for i := 0; i < len(gc.res); i++ {
   380  		d, err := m.Alloc(len(gc.res[i]))
   381  		if err != nil {
   382  			return err
   383  		}
   384  		copy(d, []byte(gc.res[i]))
   385  		gc.res[i] = string(d)
   386  	}
   387  	for i := 0; i < len(gc.inserts); i++ {
   388  		d, err := m.Alloc(len(gc.inserts[i]))
   389  		if err != nil {
   390  			return err
   391  		}
   392  		copy(d, []byte(gc.inserts[i]))
   393  		gc.inserts[i] = string(d)
   394  	}
   395  	return nil
   396  }
   397  
   398  func VectorToString(vec *vector.Vector, rowIndex int) (string, error) {
   399  	if nulls.Any(vec.Nsp) {
   400  		return "", nil
   401  	}
   402  	switch vec.Typ.Oid {
   403  	case types.T_bool:
   404  		flag := vector.GetValueAt[bool](vec, int64(rowIndex))
   405  		if flag {
   406  			return "1", nil
   407  		}
   408  		return "0", nil
   409  	case types.T_int8:
   410  		return fmt.Sprintf("%v", vector.GetValueAt[int8](vec, int64(rowIndex))), nil
   411  	case types.T_int16:
   412  		return fmt.Sprintf("%v", vector.GetValueAt[int16](vec, int64(rowIndex))), nil
   413  	case types.T_int32:
   414  		return fmt.Sprintf("%v", vector.GetValueAt[int32](vec, int64(rowIndex))), nil
   415  	case types.T_int64:
   416  		return fmt.Sprintf("%v", vector.GetValueAt[int64](vec, int64(rowIndex))), nil
   417  	case types.T_uint8:
   418  		return fmt.Sprintf("%v", vector.GetValueAt[uint8](vec, int64(rowIndex))), nil
   419  	case types.T_uint16:
   420  		return fmt.Sprintf("%v", vector.GetValueAt[uint16](vec, int64(rowIndex))), nil
   421  	case types.T_uint32:
   422  		return fmt.Sprintf("%v", vector.GetValueAt[uint32](vec, int64(rowIndex))), nil
   423  	case types.T_uint64:
   424  		return fmt.Sprintf("%v", vector.GetValueAt[uint64](vec, int64(rowIndex))), nil
   425  	case types.T_float32:
   426  		return fmt.Sprintf("%v", vector.GetValueAt[float32](vec, int64(rowIndex))), nil
   427  	case types.T_float64:
   428  		return fmt.Sprintf("%v", vector.GetValueAt[float64](vec, int64(rowIndex))), nil
   429  	case types.T_char, types.T_varchar, types.T_text, types.T_blob:
   430  		return vec.GetString(int64(rowIndex)), nil
   431  	case types.T_decimal64:
   432  		val := vector.GetValueAt[types.Decimal64](vec, int64(rowIndex))
   433  		return val.String(), nil
   434  	case types.T_decimal128:
   435  		val := vector.GetValueAt[types.Decimal128](vec, int64(rowIndex))
   436  		return val.String(), nil
   437  	case types.T_json:
   438  		val := vec.GetBytes(int64(rowIndex))
   439  		byteJson := types.DecodeJson(val)
   440  		return byteJson.String(), nil
   441  	case types.T_uuid:
   442  		val := vector.GetValueAt[types.Uuid](vec, int64(rowIndex))
   443  		return val.ToString(), nil
   444  	case types.T_date:
   445  		val := vector.GetValueAt[types.Date](vec, int64(rowIndex))
   446  		return val.String(), nil
   447  	case types.T_time:
   448  		val := vector.GetValueAt[types.Time](vec, int64(rowIndex))
   449  		return val.String(), nil
   450  	case types.T_datetime:
   451  		val := vector.GetValueAt[types.Datetime](vec, int64(rowIndex))
   452  		return val.String(), nil
   453  	default:
   454  		return "", nil
   455  	}
   456  }
   457  
   458  func hasNull(vecs []*vector.Vector, rowIdx int64) bool {
   459  	for i := 0; i < len(vecs); i++ {
   460  		if vecs[i].Nsp.Contains(uint64(rowIdx)) {
   461  			return true
   462  		}
   463  	}
   464  	return false
   465  }