github.com/MontFerret/ferret@v0.18.0/pkg/runtime/expressions/clauses/collect_iterator.go (about)

     1  package clauses
     2  
     3  import (
     4  	"context"
     5  
     6  	"github.com/MontFerret/ferret/pkg/runtime/collections"
     7  	"github.com/MontFerret/ferret/pkg/runtime/core"
     8  	"github.com/MontFerret/ferret/pkg/runtime/values"
     9  	"github.com/MontFerret/ferret/pkg/runtime/values/types"
    10  )
    11  
    12  type CollectIterator struct {
    13  	ready      bool
    14  	values     []*core.Scope
    15  	pos        int
    16  	src        core.SourceMap
    17  	params     *Collect
    18  	dataSource collections.Iterator
    19  }
    20  
    21  func NewCollectIterator(
    22  	src core.SourceMap,
    23  	params *Collect,
    24  	dataSource collections.Iterator,
    25  ) (*CollectIterator, error) {
    26  	if params.group != nil {
    27  		if params.group.selectors != nil {
    28  			var err error
    29  			sorters := make([]*collections.Sorter, len(params.group.selectors))
    30  
    31  			for i, selector := range params.group.selectors {
    32  				sorter, err := newGroupSorter(selector)
    33  
    34  				if err != nil {
    35  					return nil, err
    36  				}
    37  
    38  				sorters[i] = sorter
    39  			}
    40  
    41  			dataSource, err = collections.NewSortIterator(dataSource, sorters...)
    42  
    43  			if err != nil {
    44  				return nil, err
    45  			}
    46  		}
    47  
    48  		if params.group.count != nil && params.group.projection != nil {
    49  			return nil, core.Error(core.ErrInvalidArgumentNumber, "counter and projection cannot be used together")
    50  		}
    51  	}
    52  
    53  	return &CollectIterator{
    54  		false,
    55  		nil,
    56  		0,
    57  		src,
    58  		params,
    59  		dataSource,
    60  	}, nil
    61  }
    62  
    63  func newGroupSorter(selector *CollectSelector) (*collections.Sorter, error) {
    64  	return collections.NewSorter(func(ctx context.Context, first, second *core.Scope) (int64, error) {
    65  		f, err := selector.expression.Exec(ctx, first)
    66  
    67  		if err != nil {
    68  			return -1, err
    69  		}
    70  
    71  		s, err := selector.expression.Exec(ctx, second)
    72  
    73  		if err != nil {
    74  			return -1, err
    75  		}
    76  
    77  		return f.Compare(s), nil
    78  	}, collections.SortDirectionAsc)
    79  }
    80  
    81  func (iterator *CollectIterator) Next(ctx context.Context, scope *core.Scope) (*core.Scope, error) {
    82  	if !iterator.ready {
    83  		iterator.ready = true
    84  		groups, err := iterator.init(ctx, scope)
    85  
    86  		if err != nil {
    87  			return nil, err
    88  		}
    89  
    90  		iterator.values = groups
    91  	}
    92  
    93  	if len(iterator.values) > iterator.pos {
    94  		val := iterator.values[iterator.pos]
    95  		iterator.pos++
    96  
    97  		return val, nil
    98  	}
    99  
   100  	return nil, core.ErrNoMoreData
   101  }
   102  
   103  func (iterator *CollectIterator) init(ctx context.Context, scope *core.Scope) ([]*core.Scope, error) {
   104  	if iterator.params.group != nil {
   105  		return iterator.group(ctx, scope)
   106  	}
   107  
   108  	if iterator.params.count != nil {
   109  		return iterator.count(ctx, scope)
   110  	}
   111  
   112  	if iterator.params.aggregate != nil {
   113  		return iterator.aggregate(ctx, scope)
   114  	}
   115  
   116  	return nil, core.ErrInvalidOperation
   117  }
   118  
   119  func (iterator *CollectIterator) group(ctx context.Context, scope *core.Scope) ([]*core.Scope, error) {
   120  	// TODO: honestly, this code is ugly. it needs to be refactored in more chained way with much less if statements
   121  	// slice of groups
   122  	collected := make([]*core.Scope, 0, 10)
   123  	// hash table of unique values
   124  	// key is a DataSet hash
   125  	// value is its index in result slice (collected)
   126  	hashTable := make(map[uint64]int)
   127  
   128  	groupSelectors := iterator.params.group.selectors
   129  	proj := iterator.params.group.projection
   130  	count := iterator.params.group.count
   131  	aggr := iterator.params.group.aggregate
   132  
   133  	// iterating over underlying data source
   134  	for {
   135  		// keep all defined variables in forked scopes
   136  		// all those variables should not be available for further executions
   137  		dataSourceScope, err := iterator.dataSource.Next(ctx, scope.Fork())
   138  
   139  		if err != nil {
   140  			if core.IsNoMoreData(err) {
   141  				break
   142  			}
   143  
   144  			return nil, err
   145  		}
   146  
   147  		// this data dataSourceScope represents a data of a given iteration with values retrieved by selectors
   148  		collectScope := scope.Fork()
   149  
   150  		// map for calculating a hash value
   151  		vals := make(map[string]core.Value)
   152  
   153  		// iterate over each selector for a current data
   154  		for _, selector := range groupSelectors {
   155  			// execute a selector and get a value
   156  			// e.g. COLLECT age = u.age
   157  			value, err := selector.expression.Exec(ctx, dataSourceScope)
   158  
   159  			if err != nil {
   160  				return nil, err
   161  			}
   162  
   163  			if err := collectScope.SetVariable(selector.variable, value); err != nil {
   164  				return nil, err
   165  			}
   166  
   167  			vals[selector.variable] = value
   168  		}
   169  
   170  		// it important to get hash value before projection and counting
   171  		// otherwise hash value will be inaccurate
   172  		h := values.MapHash(vals)
   173  
   174  		_, exists := hashTable[h]
   175  
   176  		if !exists {
   177  			collected = append(collected, collectScope)
   178  			hashTable[h] = len(collected) - 1
   179  
   180  			if proj != nil {
   181  				// create a new variable for keeping projection
   182  				if err := collectScope.SetVariable(proj.selector.variable, values.NewArray(10)); err != nil {
   183  					return nil, err
   184  				}
   185  			} else if count != nil {
   186  				// create a new variable for keeping counter
   187  				if err := collectScope.SetVariable(count.variable, values.ZeroInt); err != nil {
   188  					return nil, err
   189  				}
   190  			} else if aggr != nil {
   191  				// create a new variable for keeping aggregated values
   192  				for _, selector := range aggr.selectors {
   193  					arr := values.NewArray(len(selector.aggregators))
   194  
   195  					for range selector.aggregators {
   196  						arr.Push(values.None)
   197  					}
   198  
   199  					if err := collectScope.SetVariable(selector.variable, arr); err != nil {
   200  						return nil, err
   201  					}
   202  				}
   203  			}
   204  		}
   205  
   206  		if proj != nil {
   207  			idx := hashTable[h]
   208  			collectedScope := collected[idx]
   209  			groupValue, err := collectedScope.GetVariable(proj.selector.variable)
   210  
   211  			if err != nil {
   212  				return nil, err
   213  			}
   214  
   215  			arr, ok := groupValue.(*values.Array)
   216  
   217  			if !ok {
   218  				return nil, core.TypeError(groupValue.Type(), types.Int)
   219  			}
   220  
   221  			value, err := proj.selector.expression.Exec(ctx, dataSourceScope)
   222  
   223  			if err != nil {
   224  				return nil, err
   225  			}
   226  
   227  			arr.Push(value)
   228  		} else if count != nil {
   229  			idx := hashTable[h]
   230  			ds := collected[idx]
   231  			groupValue, err := ds.GetVariable(count.variable)
   232  
   233  			if err != nil {
   234  				return nil, err
   235  			}
   236  
   237  			counter, ok := groupValue.(values.Int)
   238  
   239  			if !ok {
   240  				return nil, core.TypeError(groupValue.Type(), types.Int)
   241  			}
   242  
   243  			groupValue = counter + 1
   244  			// dataSourceScope a new value
   245  			if err := ds.UpdateVariable(count.variable, groupValue); err != nil {
   246  				return nil, err
   247  			}
   248  		} else if aggr != nil {
   249  			idx := hashTable[h]
   250  			ds := collected[idx]
   251  
   252  			// iterate over each selector for a current data dataSourceScope
   253  			for _, selector := range aggr.selectors {
   254  				sv, err := ds.GetVariable(selector.variable)
   255  
   256  				if err != nil {
   257  					return nil, err
   258  				}
   259  
   260  				vv := sv.(*values.Array)
   261  
   262  				// execute a selector and get a value
   263  				// e.g. AGGREGATE age = CONCAT(u.age, u.dob)
   264  				// u.age and u.dob get executed
   265  				for idx, exp := range selector.aggregators {
   266  					arg, err := exp.Exec(ctx, dataSourceScope)
   267  
   268  					if err != nil {
   269  						return nil, err
   270  					}
   271  
   272  					var args *values.Array
   273  					idx := values.NewInt(idx)
   274  
   275  					if vv.Get(idx) == values.None {
   276  						args = values.NewArray(10)
   277  						vv.Set(idx, args)
   278  					} else {
   279  						args = vv.Get(idx).(*values.Array)
   280  					}
   281  
   282  					args.Push(arg)
   283  				}
   284  			}
   285  		}
   286  	}
   287  
   288  	if aggr != nil {
   289  		for _, iterScope := range collected {
   290  			for _, selector := range aggr.selectors {
   291  				sv, err := iterScope.GetVariable(selector.variable)
   292  
   293  				if err != nil {
   294  					return nil, err
   295  				}
   296  
   297  				arr := sv.(*values.Array)
   298  
   299  				matrix := make([]core.Value, arr.Length())
   300  
   301  				arr.ForEach(func(value core.Value, idx int) bool {
   302  					matrix[idx] = value
   303  
   304  					return true
   305  				})
   306  
   307  				reduced, err := selector.reducer(ctx, matrix...)
   308  
   309  				if err != nil {
   310  					return nil, err
   311  				}
   312  
   313  				// replace value with calculated one
   314  				if err := iterScope.UpdateVariable(selector.variable, reduced); err != nil {
   315  					return nil, err
   316  				}
   317  			}
   318  		}
   319  	}
   320  
   321  	return collected, nil
   322  }
   323  
   324  func (iterator *CollectIterator) count(ctx context.Context, scope *core.Scope) ([]*core.Scope, error) {
   325  	var counter int
   326  
   327  	// iterating over underlying data source
   328  	for {
   329  		// keep all defined variables in forked scopes
   330  		// all those variables should not be available for further executions
   331  		_, err := iterator.dataSource.Next(ctx, scope.Fork())
   332  
   333  		if err != nil {
   334  			if core.IsNoMoreData(err) {
   335  				break
   336  			}
   337  
   338  			return nil, err
   339  		}
   340  
   341  		counter++
   342  	}
   343  
   344  	cs := scope.Fork()
   345  
   346  	if err := cs.SetVariable(iterator.params.count.variable, values.NewInt(counter)); err != nil {
   347  		return nil, err
   348  	}
   349  
   350  	return []*core.Scope{cs}, nil
   351  }
   352  
   353  func (iterator *CollectIterator) aggregate(ctx context.Context, scope *core.Scope) ([]*core.Scope, error) {
   354  	cs := scope.Fork()
   355  
   356  	// matrix of aggregated expressions
   357  	// string key of the map is a selector variable
   358  	// value of the map is a matrix of arguments
   359  	// e.g. x = CONCAT(arg1, arg2, argN...)
   360  	// x is a string key where a nested array is an array of all values of argN expressions
   361  	aggregated := make(map[string][]core.Value)
   362  	selectors := iterator.params.aggregate.selectors
   363  
   364  	// iterating over underlying data source
   365  	for {
   366  		// keep all defined variables in forked scopes
   367  		// all those variables should not be available for further executions
   368  		os, err := iterator.dataSource.Next(ctx, scope.Fork())
   369  
   370  		if err != nil {
   371  			if core.IsNoMoreData(err) {
   372  				break
   373  			}
   374  
   375  			return nil, err
   376  		}
   377  
   378  		// iterate over each selector for a current data set
   379  		for _, selector := range selectors {
   380  			vv, exists := aggregated[selector.variable]
   381  
   382  			if !exists {
   383  				vv = make([]core.Value, len(selector.aggregators))
   384  				aggregated[selector.variable] = vv
   385  			}
   386  
   387  			// execute a selector and get a value
   388  			// e.g. AGGREGATE age = CONCAT(u.age, u.dob)
   389  			// u.age and u.dob get executed
   390  			for idx, exp := range selector.aggregators {
   391  				arg, err := exp.Exec(ctx, os)
   392  
   393  				if err != nil {
   394  					return nil, err
   395  				}
   396  
   397  				var args *values.Array
   398  
   399  				if vv[idx] == nil {
   400  					args = values.NewArray(10)
   401  					vv[idx] = args
   402  				} else {
   403  					args = vv[idx].(*values.Array)
   404  				}
   405  
   406  				args.Push(arg)
   407  			}
   408  		}
   409  	}
   410  
   411  	for _, selector := range selectors {
   412  		matrix := aggregated[selector.variable]
   413  
   414  		reduced, err := selector.reducer(ctx, matrix...)
   415  
   416  		if err != nil {
   417  			return nil, err
   418  		}
   419  
   420  		if err := cs.SetVariable(selector.variable, reduced); err != nil {
   421  			return nil, err
   422  		}
   423  	}
   424  
   425  	return []*core.Scope{cs}, nil
   426  }