github.com/unigraph-dev/dgraph@v1.1.1-0.20200923154953-8b52b426f765/query/aggregator.go (about)

     1  /*
     2   * Copyright 2017-2018 Dgraph Labs, Inc. and Contributors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package query
    18  
    19  import (
    20  	"bytes"
    21  	"math"
    22  	"time"
    23  
    24  	"github.com/dgraph-io/dgraph/protos/pb"
    25  	"github.com/dgraph-io/dgraph/types"
    26  	"github.com/dgraph-io/dgraph/x"
    27  	"github.com/pkg/errors"
    28  )
    29  
// aggregator accumulates the result of a single aggregation or math function
// (selected by name) over the values fed to it through Apply or ApplyVal.
type aggregator struct {
	// name is the function being evaluated, e.g. "min", "sum", "avg", "+" or "ln".
	name string
	// result holds the running value. A nil result.Value means that nothing
	// has been applied yet.
	result types.Val
	// count is incremented on every Apply call. In this file it is only read
	// by divideByCount, to compute the average for "avg".
	count int
}
    35  
    36  func isUnary(f string) bool {
    37  	return f == "ln" || f == "exp" || f == "u-" || f == "sqrt" ||
    38  		f == "floor" || f == "ceil" || f == "since"
    39  }
    40  
    41  func isBinaryBoolean(f string) bool {
    42  	return f == "<" || f == ">" || f == "<=" || f == ">=" ||
    43  		f == "==" || f == "!="
    44  }
    45  
    46  func isTernary(f string) bool {
    47  	return f == "cond"
    48  }
    49  
    50  func isBinary(f string) bool {
    51  	return f == "+" || f == "*" || f == "-" || f == "/" || f == "%" ||
    52  		f == "max" || f == "min" || f == "logbase" || f == "pow"
    53  }
    54  
    55  func convertTo(from *pb.TaskValue) (types.Val, error) {
    56  	vh, _ := getValue(from)
    57  	if bytes.Equal(from.Val, x.Nilbyte) {
    58  		return vh, ErrEmptyVal
    59  	}
    60  	va, err := types.Convert(vh, vh.Tid)
    61  	if err != nil {
    62  		return vh, errors.Wrapf(err, "Fail to convert from api.Value to types.Val")
    63  	}
    64  	return va, err
    65  }
    66  
    67  func compareValues(ag string, va, vb types.Val) (bool, error) {
    68  	if !isBinaryBoolean(ag) {
    69  		x.Fatalf("Function %v is not binary boolean", ag)
    70  	}
    71  
    72  	_, err := types.Less(va, vb)
    73  	if err != nil {
    74  		//Try to convert values.
    75  		if va.Tid == types.IntID {
    76  			va.Tid = types.FloatID
    77  			va.Value = float64(va.Value.(int64))
    78  		} else if vb.Tid == types.IntID {
    79  			vb.Tid = types.FloatID
    80  			vb.Value = float64(vb.Value.(int64))
    81  		} else {
    82  			return false, err
    83  		}
    84  	}
    85  	isLess, err := types.Less(va, vb)
    86  	if err != nil {
    87  		return false, err
    88  	}
    89  	isMore, err := types.Less(vb, va)
    90  	if err != nil {
    91  		return false, err
    92  	}
    93  	isEqual, err := types.Equal(va, vb)
    94  	if err != nil {
    95  		return false, err
    96  	}
    97  	switch ag {
    98  	case "<":
    99  		return isLess, nil
   100  	case ">":
   101  		return isMore, nil
   102  	case "<=":
   103  		return isLess || isEqual, nil
   104  	case ">=":
   105  		return isMore || isEqual, nil
   106  	case "==":
   107  		return isEqual, nil
   108  	case "!=":
   109  		return !isEqual, nil
   110  	}
   111  	return false, errors.Errorf("Invalid compare function %q", ag)
   112  }
   113  
   114  func (ag *aggregator) ApplyVal(v types.Val) error {
   115  	if v.Value == nil {
   116  		// If the value is missing, treat it as 0.
   117  		v.Value = int64(0)
   118  		v.Tid = types.IntID
   119  	}
   120  
   121  	var isIntOrFloat bool
   122  	var l float64
   123  	if v.Tid == types.IntID {
   124  		l = float64(v.Value.(int64))
   125  		v.Value = l
   126  		v.Tid = types.FloatID
   127  		isIntOrFloat = true
   128  	} else if v.Tid == types.FloatID {
   129  		l = v.Value.(float64)
   130  		isIntOrFloat = true
   131  	}
   132  	// If its not int or float, keep the type.
   133  
   134  	var res types.Val
   135  	if isUnary(ag.name) {
   136  		switch ag.name {
   137  		case "ln":
   138  			if !isIntOrFloat {
   139  				return errors.Errorf("Wrong type encountered for func %q", ag.name)
   140  			}
   141  			v.Value = math.Log(l)
   142  			res = v
   143  		case "exp":
   144  			if !isIntOrFloat {
   145  				return errors.Errorf("Wrong type encountered for func %q", ag.name)
   146  			}
   147  			v.Value = math.Exp(l)
   148  			res = v
   149  		case "u-":
   150  			if !isIntOrFloat {
   151  				return errors.Errorf("Wrong type encountered for func %q", ag.name)
   152  			}
   153  			v.Value = -l
   154  			res = v
   155  		case "sqrt":
   156  			if !isIntOrFloat {
   157  				return errors.Errorf("Wrong type encountered for func %q", ag.name)
   158  			}
   159  			v.Value = math.Sqrt(l)
   160  			res = v
   161  		case "floor":
   162  			if !isIntOrFloat {
   163  				return errors.Errorf("Wrong type encountered for func %q", ag.name)
   164  			}
   165  			v.Value = math.Floor(l)
   166  			res = v
   167  		case "ceil":
   168  			if !isIntOrFloat {
   169  				return errors.Errorf("Wrong type encountered for func %q", ag.name)
   170  			}
   171  			v.Value = math.Ceil(l)
   172  			res = v
   173  		case "since":
   174  			if v.Tid == types.DateTimeID {
   175  				v.Value = float64(time.Since(v.Value.(time.Time))) / 1000000000.0
   176  				v.Tid = types.FloatID
   177  			} else {
   178  				return errors.Errorf("Wrong type encountered for func %q", ag.name)
   179  			}
   180  			res = v
   181  		}
   182  		ag.result = res
   183  		return nil
   184  	}
   185  
   186  	if ag.result.Value == nil {
   187  		ag.result = v
   188  		return nil
   189  	}
   190  
   191  	va := ag.result
   192  	if va.Tid != types.IntID && va.Tid != types.FloatID {
   193  		isIntOrFloat = false
   194  	}
   195  	switch ag.name {
   196  	case "+":
   197  		if !isIntOrFloat {
   198  			return errors.Errorf("Wrong type encountered for func %q", ag.name)
   199  		}
   200  		va.Value = va.Value.(float64) + l
   201  		res = va
   202  	case "-":
   203  		if !isIntOrFloat {
   204  			return errors.Errorf("Wrong type encountered for func %q", ag.name)
   205  		}
   206  		va.Value = va.Value.(float64) - l
   207  		res = va
   208  	case "*":
   209  		if !isIntOrFloat {
   210  			return errors.Errorf("Wrong type encountered for func %q", ag.name)
   211  		}
   212  		va.Value = va.Value.(float64) * l
   213  		res = va
   214  	case "/":
   215  		if !isIntOrFloat {
   216  			return errors.Errorf("Wrong type encountered for func %q %q %q", ag.name, va.Tid, v.Tid)
   217  		}
   218  		if l == 0 {
   219  			return errors.Errorf("Division by zero")
   220  		}
   221  		va.Value = va.Value.(float64) / l
   222  		res = va
   223  	case "%":
   224  		if !isIntOrFloat {
   225  			return errors.Errorf("Wrong type encountered for func %q", ag.name)
   226  		}
   227  		if l == 0 {
   228  			return errors.Errorf("Division by zero")
   229  		}
   230  		va.Value = math.Mod(va.Value.(float64), l)
   231  		res = va
   232  	case "pow":
   233  		if !isIntOrFloat {
   234  			return errors.Errorf("Wrong type encountered for func %q", ag.name)
   235  		}
   236  		va.Value = math.Pow(va.Value.(float64), l)
   237  		res = va
   238  	case "logbase":
   239  		if l == 1 {
   240  			return nil
   241  		}
   242  		if !isIntOrFloat {
   243  			return errors.Errorf("Wrong type encountered for func %q", ag.name)
   244  		}
   245  		va.Value = math.Log(va.Value.(float64)) / math.Log(l)
   246  		res = va
   247  	case "min":
   248  		r, err := types.Less(va, v)
   249  		if err == nil && !r {
   250  			res = v
   251  		} else {
   252  			res = va
   253  		}
   254  	case "max":
   255  		r, err := types.Less(va, v)
   256  		if err == nil && r {
   257  			res = v
   258  		} else {
   259  			res = va
   260  		}
   261  	default:
   262  		return errors.Errorf("Unhandled aggregator function %q", ag.name)
   263  	}
   264  	ag.result = res
   265  	return nil
   266  }
   267  
   268  func (ag *aggregator) Apply(val types.Val) {
   269  	if ag.result.Value == nil {
   270  		ag.result = val
   271  		ag.count++
   272  		return
   273  	}
   274  
   275  	va := ag.result
   276  	vb := val
   277  	var res types.Val
   278  	switch ag.name {
   279  	case "min":
   280  		r, err := types.Less(va, vb)
   281  		if err == nil && !r {
   282  			res = vb
   283  		} else {
   284  			res = va
   285  		}
   286  	case "max":
   287  		r, err := types.Less(va, vb)
   288  		if err == nil && r {
   289  			res = vb
   290  		} else {
   291  			res = va
   292  		}
   293  	case "sum", "avg":
   294  		if va.Tid == types.IntID && vb.Tid == types.IntID {
   295  			va.Value = va.Value.(int64) + vb.Value.(int64)
   296  		} else if va.Tid == types.FloatID && vb.Tid == types.FloatID {
   297  			va.Value = va.Value.(float64) + vb.Value.(float64)
   298  		}
   299  		// Skipping the else case since that means the pair cannot be summed.
   300  		res = va
   301  	default:
   302  		x.Fatalf("Unhandled aggregator function %v", ag.name)
   303  	}
   304  	ag.count++
   305  	ag.result = res
   306  }
   307  
   308  func (ag *aggregator) ValueMarshalled() (*pb.TaskValue, error) {
   309  	data := types.ValueForType(types.BinaryID)
   310  	ag.divideByCount()
   311  	res := &pb.TaskValue{ValType: ag.result.Tid.Enum(), Val: x.Nilbyte}
   312  	if ag.result.Value == nil {
   313  		return res, nil
   314  	}
   315  	// We'll divide it by the count if it's an avg aggregator.
   316  	err := types.Marshal(ag.result, &data)
   317  	if err != nil {
   318  		return res, err
   319  	}
   320  	res.Val = data.Value.([]byte)
   321  	return res, nil
   322  }
   323  
   324  func (ag *aggregator) divideByCount() {
   325  	if ag.name != "avg" || ag.count == 0 || ag.result.Value == nil {
   326  		return
   327  	}
   328  	var v float64
   329  	if ag.result.Tid == types.IntID {
   330  		v = float64(ag.result.Value.(int64))
   331  	} else if ag.result.Tid == types.FloatID {
   332  		v = ag.result.Value.(float64)
   333  	}
   334  
   335  	ag.result.Tid = types.FloatID
   336  	ag.result.Value = v / float64(ag.count)
   337  }
   338  
   339  func (ag *aggregator) Value() (types.Val, error) {
   340  	if ag.result.Value == nil {
   341  		return ag.result, ErrEmptyVal
   342  	}
   343  	ag.divideByCount()
   344  	if ag.result.Tid == types.FloatID {
   345  		if math.IsInf(ag.result.Value.(float64), 1) {
   346  			ag.result.Value = math.MaxFloat64
   347  		} else if math.IsInf(ag.result.Value.(float64), -1) {
   348  			ag.result.Value = -1 * math.MaxFloat64
   349  		} else if math.IsNaN(ag.result.Value.(float64)) {
   350  			ag.result.Value = 0.0
   351  		}
   352  	}
   353  	return ag.result, nil
   354  }