github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/query/aggregator.go (about) 1 /* 2 * Copyright 2017-2018 Dgraph Labs, Inc. and Contributors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package query 18 19 import ( 20 "bytes" 21 "math" 22 "time" 23 24 "github.com/dgraph-io/dgraph/protos/pb" 25 "github.com/dgraph-io/dgraph/types" 26 "github.com/dgraph-io/dgraph/x" 27 "github.com/pkg/errors" 28 ) 29 30 type aggregator struct { 31 name string 32 result types.Val 33 count int // used when we need avergae. 34 } 35 36 func isUnary(f string) bool { 37 return f == "ln" || f == "exp" || f == "u-" || f == "sqrt" || 38 f == "floor" || f == "ceil" || f == "since" 39 } 40 41 func isBinaryBoolean(f string) bool { 42 return f == "<" || f == ">" || f == "<=" || f == ">=" || 43 f == "==" || f == "!=" 44 } 45 46 func isTernary(f string) bool { 47 return f == "cond" 48 } 49 50 func isBinary(f string) bool { 51 return f == "+" || f == "*" || f == "-" || f == "/" || f == "%" || 52 f == "max" || f == "min" || f == "logbase" || f == "pow" 53 } 54 55 func convertTo(from *pb.TaskValue) (types.Val, error) { 56 vh, _ := getValue(from) 57 if bytes.Equal(from.Val, x.Nilbyte) { 58 return vh, ErrEmptyVal 59 } 60 va, err := types.Convert(vh, vh.Tid) 61 if err != nil { 62 return vh, errors.Wrapf(err, "Fail to convert from api.Value to types.Val") 63 } 64 return va, err 65 } 66 67 func compareValues(ag string, va, vb types.Val) (bool, error) { 68 if !isBinaryBoolean(ag) { 69 x.Fatalf("Function %v is not binary boolean", ag) 70 } 71 72 _, err := types.Less(va, vb) 73 if err != nil { 74 //Try to convert values. 75 if va.Tid == types.IntID { 76 va.Tid = types.FloatID 77 va.Value = float64(va.Value.(int64)) 78 } else if vb.Tid == types.IntID { 79 vb.Tid = types.FloatID 80 vb.Value = float64(vb.Value.(int64)) 81 } else { 82 return false, err 83 } 84 } 85 isLess, err := types.Less(va, vb) 86 if err != nil { 87 return false, err 88 } 89 isMore, err := types.Less(vb, va) 90 if err != nil { 91 return false, err 92 } 93 isEqual, err := types.Equal(va, vb) 94 if err != nil { 95 return false, err 96 } 97 switch ag { 98 case "<": 99 return isLess, nil 100 case ">": 101 return isMore, nil 102 case "<=": 103 return isLess || isEqual, nil 104 case ">=": 105 return isMore || isEqual, nil 106 case "==": 107 return isEqual, nil 108 case "!=": 109 return !isEqual, nil 110 } 111 return false, errors.Errorf("Invalid compare function %q", ag) 112 } 113 114 func (ag *aggregator) ApplyVal(v types.Val) error { 115 if v.Value == nil { 116 // If the value is missing, treat it as 0. 117 v.Value = int64(0) 118 v.Tid = types.IntID 119 } 120 121 var isIntOrFloat bool 122 var l float64 123 if v.Tid == types.IntID { 124 l = float64(v.Value.(int64)) 125 v.Value = l 126 v.Tid = types.FloatID 127 isIntOrFloat = true 128 } else if v.Tid == types.FloatID { 129 l = v.Value.(float64) 130 isIntOrFloat = true 131 } 132 // If its not int or float, keep the type. 133 134 var res types.Val 135 if isUnary(ag.name) { 136 switch ag.name { 137 case "ln": 138 if !isIntOrFloat { 139 return errors.Errorf("Wrong type encountered for func %q", ag.name) 140 } 141 v.Value = math.Log(l) 142 res = v 143 case "exp": 144 if !isIntOrFloat { 145 return errors.Errorf("Wrong type encountered for func %q", ag.name) 146 } 147 v.Value = math.Exp(l) 148 res = v 149 case "u-": 150 if !isIntOrFloat { 151 return errors.Errorf("Wrong type encountered for func %q", ag.name) 152 } 153 v.Value = -l 154 res = v 155 case "sqrt": 156 if !isIntOrFloat { 157 return errors.Errorf("Wrong type encountered for func %q", ag.name) 158 } 159 v.Value = math.Sqrt(l) 160 res = v 161 case "floor": 162 if !isIntOrFloat { 163 return errors.Errorf("Wrong type encountered for func %q", ag.name) 164 } 165 v.Value = math.Floor(l) 166 res = v 167 case "ceil": 168 if !isIntOrFloat { 169 return errors.Errorf("Wrong type encountered for func %q", ag.name) 170 } 171 v.Value = math.Ceil(l) 172 res = v 173 case "since": 174 if v.Tid == types.DateTimeID { 175 v.Value = float64(time.Since(v.Value.(time.Time))) / 1000000000.0 176 v.Tid = types.FloatID 177 } else { 178 return errors.Errorf("Wrong type encountered for func %q", ag.name) 179 } 180 res = v 181 } 182 ag.result = res 183 return nil 184 } 185 186 if ag.result.Value == nil { 187 ag.result = v 188 return nil 189 } 190 191 va := ag.result 192 if va.Tid != types.IntID && va.Tid != types.FloatID { 193 isIntOrFloat = false 194 } 195 switch ag.name { 196 case "+": 197 if !isIntOrFloat { 198 return errors.Errorf("Wrong type encountered for func %q", ag.name) 199 } 200 va.Value = va.Value.(float64) + l 201 res = va 202 case "-": 203 if !isIntOrFloat { 204 return errors.Errorf("Wrong type encountered for func %q", ag.name) 205 } 206 va.Value = va.Value.(float64) - l 207 res = va 208 case "*": 209 if !isIntOrFloat { 210 return errors.Errorf("Wrong type encountered for func %q", ag.name) 211 } 212 va.Value = va.Value.(float64) * l 213 res = va 214 case "/": 215 if !isIntOrFloat { 216 return errors.Errorf("Wrong type encountered for func %q %q %q", ag.name, va.Tid, v.Tid) 217 } 218 if l == 0 { 219 return errors.Errorf("Division by zero") 220 } 221 va.Value = va.Value.(float64) / l 222 res = va 223 case "%": 224 if !isIntOrFloat { 225 return errors.Errorf("Wrong type encountered for func %q", ag.name) 226 } 227 if l == 0 { 228 return errors.Errorf("Division by zero") 229 } 230 va.Value = math.Mod(va.Value.(float64), l) 231 res = va 232 case "pow": 233 if !isIntOrFloat { 234 return errors.Errorf("Wrong type encountered for func %q", ag.name) 235 } 236 va.Value = math.Pow(va.Value.(float64), l) 237 res = va 238 case "logbase": 239 if l == 1 { 240 return nil 241 } 242 if !isIntOrFloat { 243 return errors.Errorf("Wrong type encountered for func %q", ag.name) 244 } 245 va.Value = math.Log(va.Value.(float64)) / math.Log(l) 246 res = va 247 case "min": 248 r, err := types.Less(va, v) 249 if err == nil && !r { 250 res = v 251 } else { 252 res = va 253 } 254 case "max": 255 r, err := types.Less(va, v) 256 if err == nil && r { 257 res = v 258 } else { 259 res = va 260 } 261 default: 262 return errors.Errorf("Unhandled aggregator function %q", ag.name) 263 } 264 ag.result = res 265 return nil 266 } 267 268 func (ag *aggregator) Apply(val types.Val) { 269 if ag.result.Value == nil { 270 ag.result = val 271 ag.count++ 272 return 273 } 274 275 va := ag.result 276 vb := val 277 var res types.Val 278 switch ag.name { 279 case "min": 280 r, err := types.Less(va, vb) 281 if err == nil && !r { 282 res = vb 283 } else { 284 res = va 285 } 286 case "max": 287 r, err := types.Less(va, vb) 288 if err == nil && r { 289 res = vb 290 } else { 291 res = va 292 } 293 case "sum", "avg": 294 if va.Tid == types.IntID && vb.Tid == types.IntID { 295 va.Value = va.Value.(int64) + vb.Value.(int64) 296 } else if va.Tid == types.FloatID && vb.Tid == types.FloatID { 297 va.Value = va.Value.(float64) + vb.Value.(float64) 298 } 299 // Skipping the else case since that means the pair cannot be summed. 300 res = va 301 default: 302 x.Fatalf("Unhandled aggregator function %v", ag.name) 303 } 304 ag.count++ 305 ag.result = res 306 } 307 308 func (ag *aggregator) ValueMarshalled() (*pb.TaskValue, error) { 309 data := types.ValueForType(types.BinaryID) 310 ag.divideByCount() 311 res := &pb.TaskValue{ValType: ag.result.Tid.Enum(), Val: x.Nilbyte} 312 if ag.result.Value == nil { 313 return res, nil 314 } 315 // We'll divide it by the count if it's an avg aggregator. 316 err := types.Marshal(ag.result, &data) 317 if err != nil { 318 return res, err 319 } 320 res.Val = data.Value.([]byte) 321 return res, nil 322 } 323 324 func (ag *aggregator) divideByCount() { 325 if ag.name != "avg" || ag.count == 0 || ag.result.Value == nil { 326 return 327 } 328 var v float64 329 if ag.result.Tid == types.IntID { 330 v = float64(ag.result.Value.(int64)) 331 } else if ag.result.Tid == types.FloatID { 332 v = ag.result.Value.(float64) 333 } 334 335 ag.result.Tid = types.FloatID 336 ag.result.Value = v / float64(ag.count) 337 } 338 339 func (ag *aggregator) Value() (types.Val, error) { 340 if ag.result.Value == nil { 341 return ag.result, ErrEmptyVal 342 } 343 ag.divideByCount() 344 if ag.result.Tid == types.FloatID { 345 if math.IsInf(ag.result.Value.(float64), 1) { 346 ag.result.Value = math.MaxFloat64 347 } else if math.IsInf(ag.result.Value.(float64), -1) { 348 ag.result.Value = -1 * math.MaxFloat64 349 } else if math.IsNaN(ag.result.Value.(float64)) { 350 ag.result.Value = 0.0 351 } 352 } 353 return ag.result, nil 354 }