github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/query/groupby.go (about)

     1  /*
     2   * Copyright 2017-2018 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package query
    18  
    19  import (
    20  	"fmt"
    21  	"sort"
    22  	"strconv"
    23  
    24  	"github.com/dgraph-io/dgraph/algo"
    25  	"github.com/dgraph-io/dgraph/protos/pb"
    26  	"github.com/dgraph-io/dgraph/types"
    27  	"github.com/pkg/errors"
    28  )
    29  
// groupPair associates a value with the name it is reported under. It is used
// both for the grouping keys of a group and for the group's aggregate results.
type groupPair struct {
	key  types.Val // grouping value, or the computed aggregate value
	attr string    // groupby attribute (or alias), or the aggregate's output name
}
    34  
// groupResult is one group produced by formGroups: the entities that belong to
// it, the key chosen for each groupby attribute, and the aggregates computed
// for it by aggregateChild.
type groupResult struct {
	keys       []groupPair // one key per groupby attribute, in dedup.groups order
	aggregates []groupPair // appended by aggregateChild, at most one per child
	uids       []uint64    // uids of the entities in this group
}
    40  
    41  func (grp *groupResult) aggregateChild(child *SubGraph) error {
    42  	fieldName := child.Params.Alias
    43  	if child.Params.DoCount {
    44  		if child.Attr != "uid" {
    45  			return errors.Errorf("Only uid predicate is allowed in count within groupby")
    46  		}
    47  		if fieldName == "" {
    48  			fieldName = "count"
    49  		}
    50  		grp.aggregates = append(grp.aggregates, groupPair{
    51  			attr: fieldName,
    52  			key: types.Val{
    53  				Tid:   types.IntID,
    54  				Value: int64(len(grp.uids)),
    55  			},
    56  		})
    57  		return nil
    58  	}
    59  	if child.SrcFunc != nil && isAggregatorFn(child.SrcFunc.Name) {
    60  		if fieldName == "" {
    61  			fieldName = fmt.Sprintf("%s(%s)", child.SrcFunc.Name, child.Attr)
    62  		}
    63  		finalVal, err := aggregateGroup(grp, child)
    64  		if err != nil {
    65  			return err
    66  		}
    67  		grp.aggregates = append(grp.aggregates, groupPair{
    68  			attr: fieldName,
    69  			key:  finalVal,
    70  		})
    71  	}
    72  	return nil
    73  }
    74  
// groupResults holds every group formed for one uid list. formResult sorts
// the groups with groupLess before returning them.
type groupResults struct {
	group []*groupResult
}
    78  
// groupElements describes one distinct value of a groupby attribute: the value
// itself and the uids of the entities that carry it. entities is a pointer so
// that addValue can append to the uid list through the copy it reads out of
// the uniq.elements map.
type groupElements struct {
	entities *pb.List  // uids carrying key, in the order addValue received them
	key      types.Val // the value as first passed to addValue
}
    83  
// uniq buckets the values seen for a single groupby attribute by their string
// form, so that equal values land in the same groupElements.
type uniq struct {
	elements map[string]groupElements // keyed by the value's string representation
	attr     string                   // attribute name (or its alias)
}
    88  
// dedup holds one uniq bucket per groupby attribute, in the order getGroup
// first created them. formGroups branches on the attributes in this order.
type dedup struct {
	groups []*uniq
}
    92  
    93  func (d *dedup) getGroup(attr string) *uniq {
    94  	var res *uniq
    95  	// Looping last to first is better in this case.
    96  	for i := len(d.groups) - 1; i >= 0; i-- {
    97  		it := d.groups[i]
    98  		if attr == it.attr {
    99  			res = it
   100  			break
   101  		}
   102  	}
   103  	if res == nil {
   104  		// Create a new entry.
   105  		res = &uniq{
   106  			attr:     attr,
   107  			elements: make(map[string]groupElements),
   108  		}
   109  		d.groups = append(d.groups, res)
   110  	}
   111  	return res
   112  }
   113  
   114  func (d *dedup) addValue(attr string, value types.Val, uid uint64) {
   115  	cur := d.getGroup(attr)
   116  	// Create the string key.
   117  	var strKey string
   118  	if value.Tid == types.UidID {
   119  		strKey = strconv.FormatUint(value.Value.(uint64), 10)
   120  	} else {
   121  		valC := types.Val{Tid: types.StringID, Value: ""}
   122  		err := types.Marshal(value, &valC)
   123  		if err != nil {
   124  			return
   125  		}
   126  		strKey = valC.Value.(string)
   127  	}
   128  
   129  	if _, ok := cur.elements[strKey]; !ok {
   130  		// If this is the first element of the group.
   131  		cur.elements[strKey] = groupElements{
   132  			key:      value,
   133  			entities: &pb.List{Uids: []uint64{}},
   134  		}
   135  	}
   136  	curEntity := cur.elements[strKey].entities
   137  	curEntity.Uids = append(curEntity.Uids, uid)
   138  }
   139  
   140  func aggregateGroup(grp *groupResult, child *SubGraph) (types.Val, error) {
   141  	ag := aggregator{
   142  		name: child.SrcFunc.Name,
   143  	}
   144  	for _, uid := range grp.uids {
   145  		idx := sort.Search(len(child.SrcUIDs.Uids), func(i int) bool {
   146  			return child.SrcUIDs.Uids[i] >= uid
   147  		})
   148  		if idx == len(child.SrcUIDs.Uids) || child.SrcUIDs.Uids[idx] != uid {
   149  			continue
   150  		}
   151  
   152  		if len(child.valueMatrix[idx].Values) == 0 {
   153  			continue
   154  		}
   155  		v := child.valueMatrix[idx].Values[0]
   156  		val, err := convertWithBestEffort(v, child.Attr)
   157  		if err != nil {
   158  			continue
   159  		}
   160  		ag.Apply(val)
   161  	}
   162  	return ag.Value()
   163  }
   164  
   165  // formGroup creates all possible groups with the list of uids that belong to that
   166  // group.
   167  func (res *groupResults) formGroups(dedupMap dedup, cur *pb.List, groupVal []groupPair) {
   168  	l := len(groupVal)
   169  	if len(dedupMap.groups) == 0 || (l != 0 && len(cur.Uids) == 0) {
   170  		// This group is already empty or no group can be formed. So stop.
   171  		return
   172  	}
   173  
   174  	if l == len(dedupMap.groups) {
   175  		a := make([]uint64, len(cur.Uids))
   176  		b := make([]groupPair, len(groupVal))
   177  		copy(a, cur.Uids)
   178  		copy(b, groupVal)
   179  		res.group = append(res.group, &groupResult{
   180  			uids: a,
   181  			keys: b,
   182  		})
   183  		return
   184  	}
   185  
   186  	for _, v := range dedupMap.groups[l].elements {
   187  		temp := new(pb.List)
   188  		groupVal = append(groupVal, groupPair{
   189  			key:  v.key,
   190  			attr: dedupMap.groups[l].attr,
   191  		})
   192  		if l != 0 {
   193  			algo.IntersectWith(cur, v.entities, temp)
   194  		} else {
   195  			temp.Uids = make([]uint64, len(v.entities.Uids))
   196  			copy(temp.Uids, v.entities.Uids)
   197  		}
   198  		res.formGroups(dedupMap, temp, groupVal)
   199  		groupVal = groupVal[:len(groupVal)-1]
   200  	}
   201  }
   202  
   203  func (sg *SubGraph) formResult(ul *pb.List) (*groupResults, error) {
   204  	var dedupMap dedup
   205  	res := new(groupResults)
   206  
   207  	for _, child := range sg.Children {
   208  		if !child.Params.ignoreResult {
   209  			continue
   210  		}
   211  
   212  		attr := child.Params.Alias
   213  		if attr == "" {
   214  			attr = child.Attr
   215  		}
   216  		if len(child.DestUIDs.GetUids()) > 0 {
   217  			// It's a UID node.
   218  			for i := 0; i < len(child.uidMatrix); i++ {
   219  				srcUid := child.SrcUIDs.Uids[i]
   220  				// Ignore uids which are not part of srcUid.
   221  				if algo.IndexOf(ul, srcUid) < 0 {
   222  					continue
   223  				}
   224  
   225  				ul := child.uidMatrix[i]
   226  				for _, uid := range ul.GetUids() {
   227  					dedupMap.addValue(attr, types.Val{Tid: types.UidID, Value: uid}, srcUid)
   228  				}
   229  			}
   230  		} else {
   231  			// It's a value node.
   232  			for i, v := range child.valueMatrix {
   233  				srcUid := child.SrcUIDs.Uids[i]
   234  				if len(v.Values) == 0 || algo.IndexOf(ul, srcUid) < 0 {
   235  					continue
   236  				}
   237  				val, err := convertTo(v.Values[0])
   238  				if err != nil {
   239  					continue
   240  				}
   241  				dedupMap.addValue(attr, val, srcUid)
   242  			}
   243  		}
   244  	}
   245  
   246  	// Create all the groups here.
   247  	res.formGroups(dedupMap, &pb.List{}, []groupPair{})
   248  
   249  	// Go over the groups and aggregate the values.
   250  	for _, child := range sg.Children {
   251  		if child.Params.ignoreResult {
   252  			continue
   253  		}
   254  		// This is a aggregation node.
   255  		for _, grp := range res.group {
   256  			err := grp.aggregateChild(child)
   257  			if err != nil && err != ErrEmptyVal {
   258  				return res, err
   259  			}
   260  		}
   261  	}
   262  	// Sort to order the groups for determinism.
   263  	sort.Slice(res.group, func(i, j int) bool {
   264  		return groupLess(res.group[i], res.group[j])
   265  	})
   266  
   267  	return res, nil
   268  }
   269  
   270  // This function is to use the fillVars. It is similar to formResult, the only difference being
   271  // that it considers the whole uidMatrix to do the grouping before assigning the variable.
   272  // TODO - Check if we can reduce this duplication.
   273  func (sg *SubGraph) fillGroupedVars(doneVars map[string]varValue, path []*SubGraph) error {
   274  	var childHasVar bool
   275  	for _, child := range sg.Children {
   276  		if child.Params.Var != "" {
   277  			childHasVar = true
   278  			break
   279  		}
   280  	}
   281  
   282  	if !childHasVar {
   283  		return nil
   284  	}
   285  
   286  	var pathNode *SubGraph
   287  	var dedupMap dedup
   288  
   289  	for _, child := range sg.Children {
   290  		if !child.Params.ignoreResult {
   291  			continue
   292  		}
   293  
   294  		attr := child.Params.Alias
   295  		if attr == "" {
   296  			attr = child.Attr
   297  		}
   298  		if len(child.DestUIDs.GetUids()) > 0 {
   299  			// It's a UID node.
   300  			for i := 0; i < len(child.uidMatrix); i++ {
   301  				srcUid := child.SrcUIDs.Uids[i]
   302  				ul := child.uidMatrix[i]
   303  				for _, uid := range ul.Uids {
   304  					dedupMap.addValue(attr, types.Val{Tid: types.UidID, Value: uid}, srcUid)
   305  				}
   306  			}
   307  			pathNode = child
   308  		} else {
   309  			// It's a value node.
   310  			for i, v := range child.valueMatrix {
   311  				srcUid := child.SrcUIDs.Uids[i]
   312  				if len(v.Values) == 0 {
   313  					continue
   314  				}
   315  				val, err := convertTo(v.Values[0])
   316  				if err != nil {
   317  					continue
   318  				}
   319  				dedupMap.addValue(attr, val, srcUid)
   320  			}
   321  		}
   322  	}
   323  
   324  	// Create all the groups here.
   325  	res := new(groupResults)
   326  	res.formGroups(dedupMap, &pb.List{}, []groupPair{})
   327  
   328  	// Go over the groups and aggregate the values.
   329  	for _, child := range sg.Children {
   330  		if child.Params.ignoreResult {
   331  			continue
   332  		}
   333  		// This is a aggregation node.
   334  		for _, grp := range res.group {
   335  			err := grp.aggregateChild(child)
   336  			if err != nil && err != ErrEmptyVal {
   337  				return err
   338  			}
   339  		}
   340  		if child.Params.Var == "" {
   341  			continue
   342  		}
   343  		chVar := child.Params.Var
   344  
   345  		tempMap := make(map[uint64]types.Val)
   346  		for _, grp := range res.group {
   347  			if len(grp.keys) == 0 {
   348  				continue
   349  			}
   350  			if len(grp.keys) > 1 {
   351  				return errors.Errorf("Expected one UID for var in groupby but got: %d", len(grp.keys))
   352  			}
   353  			uidVal := grp.keys[0].key.Value
   354  			uid, ok := uidVal.(uint64)
   355  			if !ok {
   356  				return errors.Errorf("Vars can be assigned only when grouped by UID attribute")
   357  			}
   358  			// grp.aggregates could be empty if schema conversion failed during aggregation
   359  			if len(grp.aggregates) > 0 {
   360  				tempMap[uid] = grp.aggregates[len(grp.aggregates)-1].key
   361  			}
   362  		}
   363  		doneVars[chVar] = varValue{
   364  			Vals: tempMap,
   365  			path: append(path, pathNode),
   366  		}
   367  	}
   368  	return nil
   369  }
   370  
   371  func (sg *SubGraph) processGroupBy(doneVars map[string]varValue, path []*SubGraph) error {
   372  	for _, ul := range sg.uidMatrix {
   373  		// We need to process groupby for each list as grouping needs to happen for each path of the
   374  		// tree.
   375  
   376  		r, err := sg.formResult(ul)
   377  		if err != nil {
   378  			return err
   379  		}
   380  		sg.GroupbyRes = append(sg.GroupbyRes, r)
   381  	}
   382  
   383  	if err := sg.fillGroupedVars(doneVars, path); err != nil {
   384  		return err
   385  	}
   386  
   387  	// All the result that we want to return is in sg.GroupbyRes
   388  	sg.Children = sg.Children[:0]
   389  
   390  	return nil
   391  }
   392  
   393  func groupLess(a, b *groupResult) bool {
   394  	if len(a.uids) < len(b.uids) {
   395  		return true
   396  	} else if len(a.uids) != len(b.uids) {
   397  		return false
   398  	}
   399  	if len(a.keys) < len(b.keys) {
   400  		return true
   401  	} else if len(a.keys) != len(b.keys) {
   402  		return false
   403  	}
   404  	if len(a.aggregates) < len(b.aggregates) {
   405  		return true
   406  	} else if len(a.aggregates) != len(b.aggregates) {
   407  		return false
   408  	}
   409  
   410  	for i := range a.keys {
   411  		l, err := types.Less(a.keys[i].key, b.keys[i].key)
   412  		if err == nil {
   413  			if l {
   414  				return l
   415  			}
   416  			l, _ = types.Less(b.keys[i].key, a.keys[i].key)
   417  			if l {
   418  				return !l
   419  			}
   420  		}
   421  	}
   422  
   423  	for i := range a.aggregates {
   424  		if l, err := types.Less(a.aggregates[i].key, b.aggregates[i].key); err == nil {
   425  			if l {
   426  				return l
   427  			}
   428  			l, _ = types.Less(b.aggregates[i].key, a.aggregates[i].key)
   429  			if l {
   430  				return !l
   431  			}
   432  		}
   433  	}
   434  	return false
   435  }