github.com/matrixorigin/matrixone@v0.7.0/pkg/sql/colexec/agg/agg.go (about)

     1  // Copyright 2021 Matrix Origin
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package agg
    16  
    17  import (
    18  	"fmt"
    19  
    20  	"github.com/matrixorigin/matrixone/pkg/common/mpool"
    21  	"github.com/matrixorigin/matrixone/pkg/container/nulls"
    22  	"github.com/matrixorigin/matrixone/pkg/container/types"
    23  	"github.com/matrixorigin/matrixone/pkg/container/vector"
    24  )
    25  
    26  func NewUnaryAgg[T1, T2 any](op int, priv AggStruct, isCount bool, ityp, otyp types.Type, grows func(int),
    27  	eval func([]T2) []T2, merge func(int64, int64, T2, T2, bool, bool, any) (T2, bool),
    28  	fill func(int64, T1, T2, int64, bool, bool) (T2, bool),
    29  	batchFill func(any, any, int64, int64, []uint64, []int64, *nulls.Nulls) error) Agg[*UnaryAgg[T1, T2]] {
    30  	return &UnaryAgg[T1, T2]{
    31  		op:        op,
    32  		priv:      priv,
    33  		otyp:      otyp,
    34  		eval:      eval,
    35  		fill:      fill,
    36  		merge:     merge,
    37  		grows:     grows,
    38  		batchFill: batchFill,
    39  		isCount:   isCount,
    40  		ityps:     []types.Type{ityp},
    41  	}
    42  }
    43  
    44  func (a *UnaryAgg[T1, T2]) String() string {
    45  	return fmt.Sprintf("%v", a.vs)
    46  }
    47  
    48  func (a *UnaryAgg[T1, T2]) Free(m *mpool.MPool) {
    49  	if a.da != nil {
    50  		m.Free(a.da)
    51  		a.da = nil
    52  		a.vs = nil
    53  	}
    54  }
    55  
    56  func (a *UnaryAgg[T1, T2]) Dup() Agg[any] {
    57  	return &UnaryAgg[T1, T2]{
    58  		otyp:  a.otyp,
    59  		ityps: a.ityps,
    60  		fill:  a.fill,
    61  		merge: a.merge,
    62  		grows: a.grows,
    63  		eval:  a.eval,
    64  	}
    65  }
    66  
    67  func (a *UnaryAgg[T1, T2]) OutputType() types.Type {
    68  	return a.otyp
    69  }
    70  
    71  func (a *UnaryAgg[T1, T2]) InputTypes() []types.Type {
    72  	return a.ityps
    73  }
    74  
    75  func (a *UnaryAgg[T1, T2]) Grows(size int, m *mpool.MPool) error {
    76  	if a.otyp.IsString() {
    77  		if len(a.vs) == 0 {
    78  			a.es = make([]bool, 0, size)
    79  			a.vs = make([]T2, 0, size)
    80  			a.vs = a.vs[:size]
    81  			for i := 0; i < size; i++ {
    82  				a.es = append(a.es, true)
    83  			}
    84  		} else {
    85  			var v T2
    86  			for i := 0; i < size; i++ {
    87  				a.es = append(a.es, true)
    88  				a.vs = append(a.vs, v)
    89  			}
    90  		}
    91  		a.grows(size)
    92  		return nil
    93  	}
    94  	sz := a.otyp.TypeSize()
    95  	n := len(a.vs)
    96  	if n == 0 {
    97  		data, err := m.Alloc(size * sz)
    98  		if err != nil {
    99  			return err
   100  		}
   101  		a.da = data
   102  		a.es = make([]bool, 0, size)
   103  		a.vs = types.DecodeSlice[T2](a.da)
   104  	} else if n+size >= cap(a.vs) {
   105  		a.da = a.da[:n*sz]
   106  		data, err := m.Grow(a.da, (n+size)*sz)
   107  		if err != nil {
   108  			return err
   109  		}
   110  		a.da = data
   111  		a.vs = types.DecodeSlice[T2](a.da)
   112  	}
   113  	a.vs = a.vs[:n+size]
   114  	a.da = a.da[:(n+size)*sz]
   115  	for i := 0; i < size; i++ {
   116  		a.es = append(a.es, true)
   117  	}
   118  	a.grows(size)
   119  	return nil
   120  }
   121  
   122  func (a *UnaryAgg[T1, T2]) Fill(i int64, sel, z int64, vecs []*vector.Vector) error {
   123  	vec := vecs[0]
   124  	hasNull := vec.GetNulls().Contains(uint64(sel))
   125  	if vec.Typ.IsString() {
   126  		a.vs[i], a.es[i] = a.fill(i, (any)(vec.GetBytes(sel)).(T1), a.vs[i], z, a.es[i], hasNull)
   127  	} else {
   128  		a.vs[i], a.es[i] = a.fill(i, vector.GetColumn[T1](vec)[sel], a.vs[i], z, a.es[i], hasNull)
   129  	}
   130  	return nil
   131  }
   132  
   133  func (a *UnaryAgg[T1, T2]) BatchFill(start int64, os []uint8, vps []uint64, zs []int64, vecs []*vector.Vector) error {
   134  	vec := vecs[0]
   135  	if vec.GetType().IsString() {
   136  		for i := range os {
   137  			hasNull := vec.GetNulls().Contains(uint64(i) + uint64(start))
   138  			if vps[i] == 0 {
   139  				continue
   140  			}
   141  			j := vps[i] - 1
   142  			if !vec.IsConst() {
   143  				a.vs[j], a.es[j] = a.fill(int64(j), (any)(vec.GetBytes(int64(i)+start)).(T1), a.vs[j], zs[int64(i)+start], a.es[j], hasNull)
   144  			} else {
   145  				a.vs[j], a.es[j] = a.fill(int64(j), (any)(vec.GetBytes(0)).(T1), a.vs[j], zs[int64(i)+start], a.es[j], hasNull)
   146  			}
   147  
   148  		}
   149  		return nil
   150  	}
   151  	vs := vector.GetColumn[T1](vec)
   152  	if a.batchFill != nil {
   153  		if err := a.batchFill(a.vs, vs, start, int64(len(os)), vps, zs, vec.GetNulls()); err != nil {
   154  			return err
   155  		}
   156  		nsp := vec.GetNulls()
   157  		if nsp.Any() {
   158  			for i := range os {
   159  				if !nsp.Contains(uint64(i) + uint64(start)) {
   160  					if vps[i] == 0 {
   161  						continue
   162  					}
   163  					a.es[vps[i]-1] = false
   164  				}
   165  			}
   166  		} else {
   167  			for i := range os {
   168  				if vps[i] == 0 {
   169  					continue
   170  				}
   171  				a.es[vps[i]-1] = false
   172  			}
   173  		}
   174  		return nil
   175  	}
   176  	for i := range os {
   177  		hasNull := vec.GetNulls().Contains(uint64(i) + uint64(start))
   178  		if vps[i] == 0 {
   179  			continue
   180  		}
   181  		j := vps[i] - 1
   182  		a.vs[j], a.es[j] = a.fill(int64(j), vs[int64(i)+start], a.vs[j], zs[int64(i)+start], a.es[j], hasNull)
   183  	}
   184  	return nil
   185  }
   186  
   187  func (a *UnaryAgg[T1, T2]) BulkFill(i int64, zs []int64, vecs []*vector.Vector) error {
   188  	vec := vecs[0]
   189  	if vec.GetType().IsString() {
   190  		len := vec.Length()
   191  		for j := 0; j < len; j++ {
   192  			hasNull := vec.GetNulls().Contains(uint64(j))
   193  			if !vec.IsConst() {
   194  				a.vs[i], a.es[i] = a.fill(i, (any)(vec.GetBytes(int64(j))).(T1), a.vs[i], zs[j], a.es[i], hasNull)
   195  			} else {
   196  				a.vs[i], a.es[i] = a.fill(i, (any)(vec.GetBytes(0)).(T1), a.vs[i], zs[j], a.es[i], hasNull)
   197  			}
   198  		}
   199  
   200  		return nil
   201  	}
   202  	vs := vector.GetColumn[T1](vec)
   203  	for j, v := range vs {
   204  		hasNull := vec.GetNulls().Contains(uint64(j))
   205  		a.vs[i], a.es[i] = a.fill(i, v, a.vs[i], zs[j], a.es[i], hasNull)
   206  	}
   207  	return nil
   208  }
   209  
   210  // Merge a[x] += b[y]
   211  func (a *UnaryAgg[T1, T2]) Merge(b Agg[any], x, y int64) error {
   212  	b0 := b.(*UnaryAgg[T1, T2])
   213  	if a.es[x] && !b0.es[y] {
   214  		a.otyp = b0.otyp
   215  	}
   216  	a.vs[x], a.es[x] = a.merge(x, y, a.vs[x], b0.vs[y], a.es[x], b0.es[y], b0.priv)
   217  	return nil
   218  }
   219  
   220  func (a *UnaryAgg[T1, T2]) BatchMerge(b Agg[any], start int64, os []uint8, vps []uint64) error {
   221  	b0 := b.(*UnaryAgg[T1, T2])
   222  	for i := range os {
   223  		if vps[i] == 0 {
   224  			continue
   225  		}
   226  		j := vps[i] - 1
   227  		if a.es[j] && !b0.es[int64(i)+start] {
   228  			a.otyp = b0.otyp
   229  		}
   230  		a.vs[j], a.es[j] = a.merge(int64(j), int64(i)+start, a.vs[j], b0.vs[int64(i)+start], a.es[j], b0.es[int64(i)+start], b0.priv)
   231  	}
   232  	return nil
   233  }
   234  
   235  func (a *UnaryAgg[T1, T2]) Eval(m *mpool.MPool) (*vector.Vector, error) {
   236  	defer func() {
   237  		a.Free(m)
   238  		a.da = nil
   239  		a.vs = nil
   240  		a.es = nil
   241  	}()
   242  	nsp := nulls.NewWithSize(len(a.es))
   243  	if !a.isCount {
   244  		for i, e := range a.es {
   245  			if e {
   246  				nsp.Set(uint64(i))
   247  			}
   248  		}
   249  	}
   250  	if a.otyp.IsString() {
   251  		vec := vector.New(a.otyp)
   252  		vec.Nsp = nsp
   253  		a.vs = a.eval(a.vs)
   254  		vs := (any)(a.vs).([][]byte)
   255  		for _, v := range vs {
   256  			if err := vec.Append(v, false, m); err != nil {
   257  				vec.Free(m)
   258  				return nil, err
   259  			}
   260  		}
   261  		return vec, nil
   262  	}
   263  	return vector.NewWithFixed(a.otyp, a.eval(a.vs), nsp, m), nil
   264  }
   265  
   266  func (a *UnaryAgg[T1, T2]) WildAggReAlloc(m *mpool.MPool) error {
   267  	d, err := m.Alloc(len(a.da))
   268  	if err != nil {
   269  		return err
   270  	}
   271  	copy(d, a.da)
   272  	a.da = d
   273  	setAggValues[T1, T2](a, a.otyp)
   274  	return nil
   275  }
   276  
   277  func (a *UnaryAgg[T1, T2]) IsDistinct() bool {
   278  	return false
   279  }
   280  
   281  func (a *UnaryAgg[T1, T2]) GetOperatorId() int {
   282  	return a.op
   283  }
   284  
   285  func (a *UnaryAgg[T1, T2]) GetInputTypes() []types.Type {
   286  	return a.ityps
   287  }
   288  
   289  func (a *UnaryAgg[T1, T2]) MarshalBinary() ([]byte, error) {
   290  	pData, err := a.priv.MarshalBinary()
   291  	if err != nil {
   292  		return nil, err
   293  	}
   294  	// encode the input types.
   295  	source := &EncodeAgg{
   296  		Op:         a.op,
   297  		Private:    pData,
   298  		Es:         a.es,
   299  		InputTypes: types.EncodeSlice(a.ityps),
   300  		OutputType: types.EncodeType(&a.otyp),
   301  		IsCount:    a.isCount,
   302  	}
   303  	switch {
   304  	case types.IsString(a.otyp.Oid):
   305  		source.Da = types.EncodeStringSlice(getUnaryAggStrVs(a))
   306  	default:
   307  		source.Da = a.da
   308  	}
   309  
   310  	return types.Encode(source)
   311  }
   312  
   313  func getUnaryAggStrVs(strUnaryAgg any) []string {
   314  	agg := strUnaryAgg.(*UnaryAgg[[]byte, []byte])
   315  	result := make([]string, len(agg.vs))
   316  	for i := range result {
   317  		result[i] = string(agg.vs[i])
   318  	}
   319  	return result
   320  }
   321  
   322  func (a *UnaryAgg[T1, T2]) UnmarshalBinary(data []byte) error {
   323  	// avoid resulting errors caused by morpc overusing memory
   324  	copyData := make([]byte, len(data))
   325  	copy(copyData, data)
   326  	decoded := new(EncodeAgg)
   327  	if err := types.Decode(copyData, decoded); err != nil {
   328  		return err
   329  	}
   330  
   331  	// Recover data
   332  	a.ityps = types.DecodeSlice[types.Type](decoded.InputTypes)
   333  	a.otyp = types.DecodeType(decoded.OutputType)
   334  	a.isCount = decoded.IsCount
   335  	a.es = decoded.Es
   336  	data = make([]byte, len(decoded.Da))
   337  	copy(data, decoded.Da)
   338  	a.da = data
   339  
   340  	setAggValues[T1, T2](a, a.otyp)
   341  
   342  	return a.priv.UnmarshalBinary(decoded.Private)
   343  }
   344  
   345  func setAggValues[T1, T2 any](agg any, typ types.Type) {
   346  	switch {
   347  	case types.IsString(typ.Oid):
   348  		a := agg.(*UnaryAgg[[]byte, []byte])
   349  		values := types.DecodeStringSlice(a.da)
   350  		a.vs = make([][]byte, len(values))
   351  		for i := range a.vs {
   352  			a.vs[i] = []byte(values[i])
   353  		}
   354  	default:
   355  		a := agg.(*UnaryAgg[T1, T2])
   356  		a.vs = types.DecodeSlice[T2](a.da)
   357  	}
   358  }