github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/query/groupby.go (about) 1 /* 2 * Copyright 2017-2018 Dgraph Labs, Inc. and Contributors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package query 18 19 import ( 20 "fmt" 21 "sort" 22 "strconv" 23 24 "github.com/dgraph-io/dgraph/algo" 25 "github.com/dgraph-io/dgraph/protos/pb" 26 "github.com/dgraph-io/dgraph/types" 27 "github.com/pkg/errors" 28 ) 29 30 type groupPair struct { 31 key types.Val 32 attr string 33 } 34 35 type groupResult struct { 36 keys []groupPair 37 aggregates []groupPair 38 uids []uint64 39 } 40 41 func (grp *groupResult) aggregateChild(child *SubGraph) error { 42 fieldName := child.Params.Alias 43 if child.Params.DoCount { 44 if child.Attr != "uid" { 45 return errors.Errorf("Only uid predicate is allowed in count within groupby") 46 } 47 if fieldName == "" { 48 fieldName = "count" 49 } 50 grp.aggregates = append(grp.aggregates, groupPair{ 51 attr: fieldName, 52 key: types.Val{ 53 Tid: types.IntID, 54 Value: int64(len(grp.uids)), 55 }, 56 }) 57 return nil 58 } 59 if child.SrcFunc != nil && isAggregatorFn(child.SrcFunc.Name) { 60 if fieldName == "" { 61 fieldName = fmt.Sprintf("%s(%s)", child.SrcFunc.Name, child.Attr) 62 } 63 finalVal, err := aggregateGroup(grp, child) 64 if err != nil { 65 return err 66 } 67 grp.aggregates = append(grp.aggregates, groupPair{ 68 attr: fieldName, 69 key: finalVal, 70 }) 71 } 72 return nil 73 } 74 75 type groupResults struct { 76 group []*groupResult 77 } 78 79 type groupElements struct { 80 entities *pb.List 81 key types.Val 82 } 83 84 type uniq struct { 85 elements map[string]groupElements 86 attr string 87 } 88 89 type dedup struct { 90 groups []*uniq 91 } 92 93 func (d *dedup) getGroup(attr string) *uniq { 94 var res *uniq 95 // Looping last to first is better in this case. 96 for i := len(d.groups) - 1; i >= 0; i-- { 97 it := d.groups[i] 98 if attr == it.attr { 99 res = it 100 break 101 } 102 } 103 if res == nil { 104 // Create a new entry. 105 res = &uniq{ 106 attr: attr, 107 elements: make(map[string]groupElements), 108 } 109 d.groups = append(d.groups, res) 110 } 111 return res 112 } 113 114 func (d *dedup) addValue(attr string, value types.Val, uid uint64) { 115 cur := d.getGroup(attr) 116 // Create the string key. 117 var strKey string 118 if value.Tid == types.UidID { 119 strKey = strconv.FormatUint(value.Value.(uint64), 10) 120 } else { 121 valC := types.Val{Tid: types.StringID, Value: ""} 122 err := types.Marshal(value, &valC) 123 if err != nil { 124 return 125 } 126 strKey = valC.Value.(string) 127 } 128 129 if _, ok := cur.elements[strKey]; !ok { 130 // If this is the first element of the group. 131 cur.elements[strKey] = groupElements{ 132 key: value, 133 entities: &pb.List{Uids: []uint64{}}, 134 } 135 } 136 curEntity := cur.elements[strKey].entities 137 curEntity.Uids = append(curEntity.Uids, uid) 138 } 139 140 func aggregateGroup(grp *groupResult, child *SubGraph) (types.Val, error) { 141 ag := aggregator{ 142 name: child.SrcFunc.Name, 143 } 144 for _, uid := range grp.uids { 145 idx := sort.Search(len(child.SrcUIDs.Uids), func(i int) bool { 146 return child.SrcUIDs.Uids[i] >= uid 147 }) 148 if idx == len(child.SrcUIDs.Uids) || child.SrcUIDs.Uids[idx] != uid { 149 continue 150 } 151 152 if len(child.valueMatrix[idx].Values) == 0 { 153 continue 154 } 155 v := child.valueMatrix[idx].Values[0] 156 val, err := convertWithBestEffort(v, child.Attr) 157 if err != nil { 158 continue 159 } 160 ag.Apply(val) 161 } 162 return ag.Value() 163 } 164 165 // formGroup creates all possible groups with the list of uids that belong to that 166 // group. 167 func (res *groupResults) formGroups(dedupMap dedup, cur *pb.List, groupVal []groupPair) { 168 l := len(groupVal) 169 if len(dedupMap.groups) == 0 || (l != 0 && len(cur.Uids) == 0) { 170 // This group is already empty or no group can be formed. So stop. 171 return 172 } 173 174 if l == len(dedupMap.groups) { 175 a := make([]uint64, len(cur.Uids)) 176 b := make([]groupPair, len(groupVal)) 177 copy(a, cur.Uids) 178 copy(b, groupVal) 179 res.group = append(res.group, &groupResult{ 180 uids: a, 181 keys: b, 182 }) 183 return 184 } 185 186 for _, v := range dedupMap.groups[l].elements { 187 temp := new(pb.List) 188 groupVal = append(groupVal, groupPair{ 189 key: v.key, 190 attr: dedupMap.groups[l].attr, 191 }) 192 if l != 0 { 193 algo.IntersectWith(cur, v.entities, temp) 194 } else { 195 temp.Uids = make([]uint64, len(v.entities.Uids)) 196 copy(temp.Uids, v.entities.Uids) 197 } 198 res.formGroups(dedupMap, temp, groupVal) 199 groupVal = groupVal[:len(groupVal)-1] 200 } 201 } 202 203 func (sg *SubGraph) formResult(ul *pb.List) (*groupResults, error) { 204 var dedupMap dedup 205 res := new(groupResults) 206 207 for _, child := range sg.Children { 208 if !child.Params.ignoreResult { 209 continue 210 } 211 212 attr := child.Params.Alias 213 if attr == "" { 214 attr = child.Attr 215 } 216 if len(child.DestUIDs.GetUids()) > 0 { 217 // It's a UID node. 218 for i := 0; i < len(child.uidMatrix); i++ { 219 srcUid := child.SrcUIDs.Uids[i] 220 // Ignore uids which are not part of srcUid. 221 if algo.IndexOf(ul, srcUid) < 0 { 222 continue 223 } 224 225 ul := child.uidMatrix[i] 226 for _, uid := range ul.GetUids() { 227 dedupMap.addValue(attr, types.Val{Tid: types.UidID, Value: uid}, srcUid) 228 } 229 } 230 } else { 231 // It's a value node. 232 for i, v := range child.valueMatrix { 233 srcUid := child.SrcUIDs.Uids[i] 234 if len(v.Values) == 0 || algo.IndexOf(ul, srcUid) < 0 { 235 continue 236 } 237 val, err := convertTo(v.Values[0]) 238 if err != nil { 239 continue 240 } 241 dedupMap.addValue(attr, val, srcUid) 242 } 243 } 244 } 245 246 // Create all the groups here. 247 res.formGroups(dedupMap, &pb.List{}, []groupPair{}) 248 249 // Go over the groups and aggregate the values. 250 for _, child := range sg.Children { 251 if child.Params.ignoreResult { 252 continue 253 } 254 // This is a aggregation node. 255 for _, grp := range res.group { 256 err := grp.aggregateChild(child) 257 if err != nil && err != ErrEmptyVal { 258 return res, err 259 } 260 } 261 } 262 // Sort to order the groups for determinism. 263 sort.Slice(res.group, func(i, j int) bool { 264 return groupLess(res.group[i], res.group[j]) 265 }) 266 267 return res, nil 268 } 269 270 // This function is to use the fillVars. It is similar to formResult, the only difference being 271 // that it considers the whole uidMatrix to do the grouping before assigning the variable. 272 // TODO - Check if we can reduce this duplication. 273 func (sg *SubGraph) fillGroupedVars(doneVars map[string]varValue, path []*SubGraph) error { 274 var childHasVar bool 275 for _, child := range sg.Children { 276 if child.Params.Var != "" { 277 childHasVar = true 278 break 279 } 280 } 281 282 if !childHasVar { 283 return nil 284 } 285 286 var pathNode *SubGraph 287 var dedupMap dedup 288 289 for _, child := range sg.Children { 290 if !child.Params.ignoreResult { 291 continue 292 } 293 294 attr := child.Params.Alias 295 if attr == "" { 296 attr = child.Attr 297 } 298 if len(child.DestUIDs.GetUids()) > 0 { 299 // It's a UID node. 300 for i := 0; i < len(child.uidMatrix); i++ { 301 srcUid := child.SrcUIDs.Uids[i] 302 ul := child.uidMatrix[i] 303 for _, uid := range ul.Uids { 304 dedupMap.addValue(attr, types.Val{Tid: types.UidID, Value: uid}, srcUid) 305 } 306 } 307 pathNode = child 308 } else { 309 // It's a value node. 310 for i, v := range child.valueMatrix { 311 srcUid := child.SrcUIDs.Uids[i] 312 if len(v.Values) == 0 { 313 continue 314 } 315 val, err := convertTo(v.Values[0]) 316 if err != nil { 317 continue 318 } 319 dedupMap.addValue(attr, val, srcUid) 320 } 321 } 322 } 323 324 // Create all the groups here. 325 res := new(groupResults) 326 res.formGroups(dedupMap, &pb.List{}, []groupPair{}) 327 328 // Go over the groups and aggregate the values. 329 for _, child := range sg.Children { 330 if child.Params.ignoreResult { 331 continue 332 } 333 // This is a aggregation node. 334 for _, grp := range res.group { 335 err := grp.aggregateChild(child) 336 if err != nil && err != ErrEmptyVal { 337 return err 338 } 339 } 340 if child.Params.Var == "" { 341 continue 342 } 343 chVar := child.Params.Var 344 345 tempMap := make(map[uint64]types.Val) 346 for _, grp := range res.group { 347 if len(grp.keys) == 0 { 348 continue 349 } 350 if len(grp.keys) > 1 { 351 return errors.Errorf("Expected one UID for var in groupby but got: %d", len(grp.keys)) 352 } 353 uidVal := grp.keys[0].key.Value 354 uid, ok := uidVal.(uint64) 355 if !ok { 356 return errors.Errorf("Vars can be assigned only when grouped by UID attribute") 357 } 358 // grp.aggregates could be empty if schema conversion failed during aggregation 359 if len(grp.aggregates) > 0 { 360 tempMap[uid] = grp.aggregates[len(grp.aggregates)-1].key 361 } 362 } 363 doneVars[chVar] = varValue{ 364 Vals: tempMap, 365 path: append(path, pathNode), 366 } 367 } 368 return nil 369 } 370 371 func (sg *SubGraph) processGroupBy(doneVars map[string]varValue, path []*SubGraph) error { 372 for _, ul := range sg.uidMatrix { 373 // We need to process groupby for each list as grouping needs to happen for each path of the 374 // tree. 375 376 r, err := sg.formResult(ul) 377 if err != nil { 378 return err 379 } 380 sg.GroupbyRes = append(sg.GroupbyRes, r) 381 } 382 383 if err := sg.fillGroupedVars(doneVars, path); err != nil { 384 return err 385 } 386 387 // All the result that we want to return is in sg.GroupbyRes 388 sg.Children = sg.Children[:0] 389 390 return nil 391 } 392 393 func groupLess(a, b *groupResult) bool { 394 if len(a.uids) < len(b.uids) { 395 return true 396 } else if len(a.uids) != len(b.uids) { 397 return false 398 } 399 if len(a.keys) < len(b.keys) { 400 return true 401 } else if len(a.keys) != len(b.keys) { 402 return false 403 } 404 if len(a.aggregates) < len(b.aggregates) { 405 return true 406 } else if len(a.aggregates) != len(b.aggregates) { 407 return false 408 } 409 410 for i := range a.keys { 411 l, err := types.Less(a.keys[i].key, b.keys[i].key) 412 if err == nil { 413 if l { 414 return l 415 } 416 l, _ = types.Less(b.keys[i].key, a.keys[i].key) 417 if l { 418 return !l 419 } 420 } 421 } 422 423 for i := range a.aggregates { 424 if l, err := types.Less(a.aggregates[i].key, b.aggregates[i].key); err == nil { 425 if l { 426 return l 427 } 428 l, _ = types.Less(b.aggregates[i].key, a.aggregates[i].key) 429 if l { 430 return !l 431 } 432 } 433 } 434 return false 435 }