github.com/SaurabhDubey-Groww/go-cloud@v0.0.0-20221124105541-b26c29285fd8/docstore/gcpfirestore/query.go (about)

     1  // Copyright 2019 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // TODO(jba): figure out how to get filters with uints to work: since they are represented as
    16  // int64s, the sign is wrong.
    17  
    18  package gcpfirestore
    19  
    20  import (
    21  	"context"
    22  	"fmt"
    23  	"math"
    24  	"path"
    25  	"reflect"
    26  	"strings"
    27  	"time"
    28  
    29  	"gocloud.dev/docstore/driver"
    30  	"gocloud.dev/internal/gcerr"
    31  	pb "google.golang.org/genproto/googleapis/firestore/v1"
    32  	"google.golang.org/protobuf/types/known/wrapperspb"
    33  )
    34  
    35  func (c *collection) RunGetQuery(ctx context.Context, q *driver.Query) (driver.DocumentIterator, error) {
    36  	return c.newDocIterator(ctx, q)
    37  }
    38  
    39  func (c *collection) newDocIterator(ctx context.Context, q *driver.Query) (*docIterator, error) {
    40  	sq, localFilters, err := c.queryToProto(q)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  	req := &pb.RunQueryRequest{
    45  		Parent:    path.Dir(c.collPath),
    46  		QueryType: &pb.RunQueryRequest_StructuredQuery{sq},
    47  	}
    48  	if q.BeforeQuery != nil {
    49  		if err := q.BeforeQuery(driver.AsFunc(req)); err != nil {
    50  			return nil, err
    51  		}
    52  	}
    53  	ctx, cancel := context.WithCancel(ctx)
    54  	sc, err := c.client.RunQuery(ctx, req)
    55  	if err != nil {
    56  		cancel()
    57  		return nil, err
    58  	}
    59  	return &docIterator{
    60  		streamClient: sc,
    61  		nameField:    c.nameField,
    62  		revField:     c.opts.RevisionField,
    63  		localFilters: localFilters,
    64  		cancel:       cancel,
    65  	}, nil
    66  }
    67  
    68  ////////////////////////////////////////////////////////////////
    69  // The code below is adapted from cloud.google.com/go/firestore.
    70  
    71  type docIterator struct {
    72  	streamClient        pb.Firestore_RunQueryClient
    73  	nameField, revField string
    74  	localFilters        []driver.Filter
    75  	// We call cancel to make sure the stream client doesn't leak resources.
    76  	// We don't need to call it if Recv() returns a non-nil error.
    77  	// See https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
    78  	cancel func()
    79  }
    80  
    81  func (it *docIterator) Next(ctx context.Context, doc driver.Document) error {
    82  	res, err := it.nextResponse(ctx)
    83  	if err != nil {
    84  		return err
    85  	}
    86  	return decodeDoc(res.Document, doc, it.nameField, it.revField)
    87  }
    88  
    89  func (it *docIterator) nextResponse(ctx context.Context) (*pb.RunQueryResponse, error) {
    90  	for {
    91  		res, err := it.streamClient.Recv()
    92  		if err != nil {
    93  			return nil, err
    94  		}
    95  		// No document => partial progress; keep receiving.
    96  		if res.Document == nil {
    97  			continue
    98  		}
    99  		match, err := it.evaluateLocalFilters(res.Document)
   100  		if err != nil {
   101  			return nil, err
   102  		}
   103  		if match {
   104  			return res, nil
   105  		}
   106  	}
   107  }
   108  
   109  // Report whether the filters are true of the document.
   110  func (it *docIterator) evaluateLocalFilters(pdoc *pb.Document) (bool, error) {
   111  	if len(it.localFilters) == 0 {
   112  		return true, nil
   113  	}
   114  	// TODO(jba): optimization: evaluate the filter directly on the proto document, without decoding.
   115  	m := map[string]interface{}{}
   116  	doc, err := driver.NewDocument(m)
   117  	if err != nil {
   118  		return false, err
   119  	}
   120  	if err := decodeDoc(pdoc, doc, it.nameField, it.revField); err != nil {
   121  		return false, err
   122  	}
   123  	for _, f := range it.localFilters {
   124  		if !evaluateFilter(f, doc) {
   125  			return false, nil
   126  		}
   127  	}
   128  	return true, nil
   129  }
   130  
   131  func evaluateFilter(f driver.Filter, doc driver.Document) bool {
   132  	val, err := doc.Get(f.FieldPath)
   133  	if err != nil {
   134  		// Treat a missing field as false.
   135  		return false
   136  	}
   137  	// Compare times.
   138  	if t1, ok := val.(time.Time); ok {
   139  		if t2, ok := f.Value.(time.Time); ok {
   140  			return applyComparison(f.Op, driver.CompareTimes(t1, t2))
   141  		}
   142  		return false
   143  	}
   144  	lhs := reflect.ValueOf(val)
   145  	rhs := reflect.ValueOf(f.Value)
   146  	if lhs.Kind() == reflect.String {
   147  		if rhs.Kind() != reflect.String {
   148  			return false
   149  		}
   150  		return applyComparison(f.Op, strings.Compare(lhs.String(), rhs.String()))
   151  	}
   152  
   153  	cmp, err := driver.CompareNumbers(lhs, rhs)
   154  	if err != nil {
   155  		return false
   156  	}
   157  	return applyComparison(f.Op, cmp)
   158  }
   159  
   160  // op is one of the five permitted docstore operators ("=", "<", etc.)
   161  // c is the result of strings.Compare or the like.
   162  func applyComparison(op string, c int) bool {
   163  	switch op {
   164  	case driver.EqualOp:
   165  		return c == 0
   166  	case ">":
   167  		return c > 0
   168  	case "<":
   169  		return c < 0
   170  	case ">=":
   171  		return c >= 0
   172  	case "<=":
   173  		return c <= 0
   174  	default:
   175  		panic("bad op")
   176  	}
   177  }
   178  
   179  func (it *docIterator) Stop() { it.cancel() }
   180  
   181  func (it *docIterator) As(i interface{}) bool {
   182  	p, ok := i.(*pb.Firestore_RunQueryClient)
   183  	if !ok {
   184  		return false
   185  	}
   186  	*p = it.streamClient
   187  	return true
   188  }
   189  
   190  // Converts the query to a Firestore proto. Also returns filters that need to be
   191  // evaluated on the client.
   192  func (c *collection) queryToProto(q *driver.Query) (*pb.StructuredQuery, []driver.Filter, error) {
   193  	// The collection ID is the last component of the collection path.
   194  	collID := path.Base(c.collPath)
   195  	p := &pb.StructuredQuery{
   196  		From: []*pb.StructuredQuery_CollectionSelector{{CollectionId: collID}},
   197  	}
   198  	if len(q.FieldPaths) > 0 {
   199  		p.Select = &pb.StructuredQuery_Projection{}
   200  		for _, fp := range q.FieldPaths {
   201  			p.Select.Fields = append(p.Select.Fields, fieldRef(fp))
   202  		}
   203  	}
   204  	if q.Limit > 0 {
   205  		p.Limit = &wrapperspb.Int32Value{Value: int32(q.Limit)}
   206  	}
   207  
   208  	// TODO(jba): make sure we retrieve the fields needed for local filters.
   209  	sendFilters, localFilters := splitFilters(q.Filters)
   210  	if len(localFilters) > 0 && !c.opts.AllowLocalFilters {
   211  		return nil, nil, gcerr.Newf(gcerr.InvalidArgument, nil, "query requires local filters; set Options.AllowLocalFilters to true to enable")
   212  	}
   213  
   214  	// If there is only one filter, use it directly. Otherwise, construct
   215  	// a CompositeFilter.
   216  	var pfs []*pb.StructuredQuery_Filter
   217  	for _, f := range sendFilters {
   218  		pf, err := c.filterToProto(f)
   219  		if err != nil {
   220  			return nil, nil, err
   221  		}
   222  		pfs = append(pfs, pf)
   223  	}
   224  	if len(pfs) == 1 {
   225  		p.Where = pfs[0]
   226  	} else if len(pfs) > 1 {
   227  		p.Where = &pb.StructuredQuery_Filter{
   228  			FilterType: &pb.StructuredQuery_Filter_CompositeFilter{&pb.StructuredQuery_CompositeFilter{
   229  				Op:      pb.StructuredQuery_CompositeFilter_AND,
   230  				Filters: pfs,
   231  			}},
   232  		}
   233  	}
   234  
   235  	if q.OrderByField != "" {
   236  		// TODO(jba): reorder filters so order-by one is first of inequalities?
   237  		// TODO(jba): see if it's OK if filter inequality direction differs from sort direction.
   238  		fref := []string{q.OrderByField}
   239  		if q.OrderByField == c.nameField {
   240  			fref[0] = "__name__"
   241  		}
   242  		var dir pb.StructuredQuery_Direction
   243  		if q.OrderAscending {
   244  			dir = pb.StructuredQuery_ASCENDING
   245  		} else {
   246  			dir = pb.StructuredQuery_DESCENDING
   247  		}
   248  		p.OrderBy = []*pb.StructuredQuery_Order{{Field: fieldRef(fref), Direction: dir}}
   249  	}
   250  
   251  	// TODO(jba): cursors (start/end)
   252  	return p, localFilters, nil
   253  }
   254  
   255  // splitFilters separates the list of query filters into those we can send to the Firestore service,
   256  // and those we must evaluate here on the client.
   257  func splitFilters(fs []driver.Filter) (sendToFirestore, evaluateLocally []driver.Filter) {
   258  	// Enforce that only one field can have an inequality.
   259  	var rangeFP []string
   260  	for _, f := range fs {
   261  		if f.Op == driver.EqualOp {
   262  			sendToFirestore = append(sendToFirestore, f)
   263  		} else {
   264  			if rangeFP == nil || driver.FieldPathsEqual(rangeFP, f.FieldPath) {
   265  				// Multiple inequality filters on the same field are OK.
   266  				rangeFP = f.FieldPath
   267  				sendToFirestore = append(sendToFirestore, f)
   268  			} else {
   269  				evaluateLocally = append(evaluateLocally, f)
   270  			}
   271  		}
   272  	}
   273  	return sendToFirestore, evaluateLocally
   274  }
   275  
   276  func (c *collection) filterToProto(f driver.Filter) (*pb.StructuredQuery_Filter, error) {
   277  	// Treat filters on the name field specially.
   278  	if c.nameField != "" && driver.FieldPathEqualsField(f.FieldPath, c.nameField) {
   279  		v := reflect.ValueOf(f.Value)
   280  		if v.Kind() != reflect.String {
   281  			return nil, gcerr.Newf(gcerr.InvalidArgument, nil,
   282  				"name field filter value %v of type %[1]T is not a string", f.Value)
   283  		}
   284  		return newFieldFilter([]string{"__name__"}, f.Op,
   285  			&pb.Value{ValueType: &pb.Value_ReferenceValue{c.collPath + "/" + v.String()}})
   286  	}
   287  	// "= nil" and "= NaN" are handled specially.
   288  	if uop, ok := unaryOpFor(f.Value); ok {
   289  		if f.Op != driver.EqualOp {
   290  			return nil, fmt.Errorf("firestore: must use '=' when comparing %v", f.Value)
   291  		}
   292  		return &pb.StructuredQuery_Filter{
   293  			FilterType: &pb.StructuredQuery_Filter_UnaryFilter{
   294  				UnaryFilter: &pb.StructuredQuery_UnaryFilter{
   295  					OperandType: &pb.StructuredQuery_UnaryFilter_Field{
   296  						Field: fieldRef(f.FieldPath),
   297  					},
   298  					Op: uop,
   299  				},
   300  			},
   301  		}, nil
   302  	}
   303  	pv, err := encodeValue(f.Value)
   304  	if err != nil {
   305  		return nil, err
   306  	}
   307  	return newFieldFilter(f.FieldPath, f.Op, pv)
   308  }
   309  
   310  func unaryOpFor(value interface{}) (pb.StructuredQuery_UnaryFilter_Operator, bool) {
   311  	switch {
   312  	case value == nil:
   313  		return pb.StructuredQuery_UnaryFilter_IS_NULL, true
   314  	case isNaN(value):
   315  		return pb.StructuredQuery_UnaryFilter_IS_NAN, true
   316  	default:
   317  		return pb.StructuredQuery_UnaryFilter_OPERATOR_UNSPECIFIED, false
   318  	}
   319  }
   320  
   321  func isNaN(x interface{}) bool {
   322  	switch x := x.(type) {
   323  	case float32:
   324  		return math.IsNaN(float64(x))
   325  	case float64:
   326  		return math.IsNaN(x)
   327  	default:
   328  		return false
   329  	}
   330  }
   331  
   332  func fieldRef(fp []string) *pb.StructuredQuery_FieldReference {
   333  	return &pb.StructuredQuery_FieldReference{FieldPath: toServiceFieldPath(fp)}
   334  }
   335  
   336  func newFieldFilter(fp []string, op string, val *pb.Value) (*pb.StructuredQuery_Filter, error) {
   337  	var fop pb.StructuredQuery_FieldFilter_Operator
   338  	switch op {
   339  	case "<":
   340  		fop = pb.StructuredQuery_FieldFilter_LESS_THAN
   341  	case "<=":
   342  		fop = pb.StructuredQuery_FieldFilter_LESS_THAN_OR_EQUAL
   343  	case ">":
   344  		fop = pb.StructuredQuery_FieldFilter_GREATER_THAN
   345  	case ">=":
   346  		fop = pb.StructuredQuery_FieldFilter_GREATER_THAN_OR_EQUAL
   347  	case driver.EqualOp:
   348  		fop = pb.StructuredQuery_FieldFilter_EQUAL
   349  	// TODO(jba): can we support array-contains portably?
   350  	// case "array-contains":
   351  	// 	fop = pb.StructuredQuery_FieldFilter_ARRAY_CONTAINS
   352  	default:
   353  		return nil, gcerr.Newf(gcerr.InvalidArgument, nil, "invalid operator: %q", op)
   354  	}
   355  	return &pb.StructuredQuery_Filter{
   356  		FilterType: &pb.StructuredQuery_Filter_FieldFilter{
   357  			FieldFilter: &pb.StructuredQuery_FieldFilter{
   358  				Field: fieldRef(fp),
   359  				Op:    fop,
   360  				Value: val,
   361  			},
   362  		},
   363  	}, nil
   364  }
   365  
   366  func (c *collection) QueryPlan(q *driver.Query) (string, error) {
   367  	return "unknown", nil
   368  }