github.com/thiagoyeds/go-cloud@v0.26.0/docstore/gcpfirestore/query.go (about)

     1  // Copyright 2019 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // TODO(jba): figure out how to get filters with uints to work: since they are represented as
    16  // int64s, the sign is wrong.
    17  
    18  package gcpfirestore
    19  
    20  import (
    21  	"context"
    22  	"fmt"
    23  	"math"
    24  	"path"
    25  	"reflect"
    26  	"strings"
    27  	"time"
    28  
    29  	"gocloud.dev/docstore/driver"
    30  	"gocloud.dev/internal/gcerr"
    31  	pb "google.golang.org/genproto/googleapis/firestore/v1"
    32  	"google.golang.org/protobuf/types/known/wrapperspb"
    33  )
    34  
    35  func (c *collection) RunGetQuery(ctx context.Context, q *driver.Query) (driver.DocumentIterator, error) {
    36  	return c.newDocIterator(ctx, q)
    37  }
    38  
    39  func (c *collection) newDocIterator(ctx context.Context, q *driver.Query) (*docIterator, error) {
    40  	sq, localFilters, err := c.queryToProto(q)
    41  	if err != nil {
    42  		return nil, err
    43  	}
    44  	req := &pb.RunQueryRequest{
    45  		Parent:    path.Dir(c.collPath),
    46  		QueryType: &pb.RunQueryRequest_StructuredQuery{sq},
    47  	}
    48  	if q.BeforeQuery != nil {
    49  		if err := q.BeforeQuery(driver.AsFunc(req)); err != nil {
    50  			return nil, err
    51  		}
    52  	}
    53  	ctx, cancel := context.WithCancel(ctx)
    54  	sc, err := c.client.RunQuery(ctx, req)
    55  	if err != nil {
    56  		cancel()
    57  		return nil, err
    58  	}
    59  	return &docIterator{
    60  		streamClient: sc,
    61  		nameField:    c.nameField,
    62  		revField:     c.opts.RevisionField,
    63  		localFilters: localFilters,
    64  		cancel:       cancel,
    65  	}, nil
    66  }
    67  
    68  ////////////////////////////////////////////////////////////////
    69  // The code below is adapted from cloud.google.com/go/firestore.
    70  
    71  type docIterator struct {
    72  	streamClient        pb.Firestore_RunQueryClient
    73  	nameField, revField string
    74  	localFilters        []driver.Filter
    75  	// We call cancel to make sure the stream client doesn't leak resources.
    76  	// We don't need to call it if Recv() returns a non-nil error.
    77  	// See https://godoc.org/google.golang.org/grpc#ClientConn.NewStream.
    78  	cancel func()
    79  }
    80  
    81  func (it *docIterator) Next(ctx context.Context, doc driver.Document) error {
    82  	res, err := it.nextResponse(ctx)
    83  	if err != nil {
    84  		return err
    85  	}
    86  	return decodeDoc(res.Document, doc, it.nameField, it.revField)
    87  }
    88  
    89  func (it *docIterator) nextResponse(ctx context.Context) (*pb.RunQueryResponse, error) {
    90  	for {
    91  		res, err := it.streamClient.Recv()
    92  		if err != nil {
    93  			return nil, err
    94  		}
    95  		// No document => partial progress; keep receiving.
    96  		if res.Document == nil {
    97  			continue
    98  		}
    99  		match, err := it.evaluateLocalFilters(res.Document)
   100  		if err != nil {
   101  			return nil, err
   102  		}
   103  		if match {
   104  			return res, nil
   105  		}
   106  	}
   107  }
   108  
   109  // Report whether the filters are true of the document.
   110  func (it *docIterator) evaluateLocalFilters(pdoc *pb.Document) (bool, error) {
   111  	if len(it.localFilters) == 0 {
   112  		return true, nil
   113  	}
   114  	// TODO(jba): optimization: evaluate the filter directly on the proto document, without decoding.
   115  	m := map[string]interface{}{}
   116  	doc, err := driver.NewDocument(m)
   117  	if err != nil {
   118  		return false, err
   119  	}
   120  	if err := decodeDoc(pdoc, doc, it.nameField, it.revField); err != nil {
   121  		return false, err
   122  	}
   123  	for _, f := range it.localFilters {
   124  		if !evaluateFilter(f, doc) {
   125  			return false, nil
   126  		}
   127  	}
   128  	return true, nil
   129  }
   130  
   131  func evaluateFilter(f driver.Filter, doc driver.Document) bool {
   132  	val, err := doc.Get(f.FieldPath)
   133  	if err != nil {
   134  		// Treat a missing field as false.
   135  		return false
   136  	}
   137  	// Compare times.
   138  	if t1, ok := val.(time.Time); ok {
   139  		if t2, ok := f.Value.(time.Time); ok {
   140  			return applyComparison(f.Op, driver.CompareTimes(t1, t2))
   141  		} else {
   142  			return false
   143  		}
   144  	}
   145  	lhs := reflect.ValueOf(val)
   146  	rhs := reflect.ValueOf(f.Value)
   147  	if lhs.Kind() == reflect.String {
   148  		if rhs.Kind() != reflect.String {
   149  			return false
   150  		}
   151  		return applyComparison(f.Op, strings.Compare(lhs.String(), rhs.String()))
   152  	}
   153  
   154  	cmp, err := driver.CompareNumbers(lhs, rhs)
   155  	if err != nil {
   156  		return false
   157  	}
   158  	return applyComparison(f.Op, cmp)
   159  }
   160  
   161  // op is one of the five permitted docstore operators ("=", "<", etc.)
   162  // c is the result of strings.Compare or the like.
   163  func applyComparison(op string, c int) bool {
   164  	switch op {
   165  	case driver.EqualOp:
   166  		return c == 0
   167  	case ">":
   168  		return c > 0
   169  	case "<":
   170  		return c < 0
   171  	case ">=":
   172  		return c >= 0
   173  	case "<=":
   174  		return c <= 0
   175  	default:
   176  		panic("bad op")
   177  	}
   178  }
   179  
   180  func (it *docIterator) Stop() { it.cancel() }
   181  
   182  func (it *docIterator) As(i interface{}) bool {
   183  	p, ok := i.(*pb.Firestore_RunQueryClient)
   184  	if !ok {
   185  		return false
   186  	}
   187  	*p = it.streamClient
   188  	return true
   189  }
   190  
   191  // Converts the query to a Firestore proto. Also returns filters that need to be
   192  // evaluated on the client.
   193  func (c *collection) queryToProto(q *driver.Query) (*pb.StructuredQuery, []driver.Filter, error) {
   194  	// The collection ID is the last component of the collection path.
   195  	collID := path.Base(c.collPath)
   196  	p := &pb.StructuredQuery{
   197  		From: []*pb.StructuredQuery_CollectionSelector{{CollectionId: collID}},
   198  	}
   199  	if len(q.FieldPaths) > 0 {
   200  		p.Select = &pb.StructuredQuery_Projection{}
   201  		for _, fp := range q.FieldPaths {
   202  			p.Select.Fields = append(p.Select.Fields, fieldRef(fp))
   203  		}
   204  	}
   205  	if q.Limit > 0 {
   206  		p.Limit = &wrapperspb.Int32Value{Value: int32(q.Limit)}
   207  	}
   208  
   209  	// TODO(jba): make sure we retrieve the fields needed for local filters.
   210  	sendFilters, localFilters := splitFilters(q.Filters)
   211  	if len(localFilters) > 0 && !c.opts.AllowLocalFilters {
   212  		return nil, nil, gcerr.Newf(gcerr.InvalidArgument, nil, "query requires local filters; set Options.AllowLocalFilters to true to enable")
   213  	}
   214  
   215  	// If there is only one filter, use it directly. Otherwise, construct
   216  	// a CompositeFilter.
   217  	var pfs []*pb.StructuredQuery_Filter
   218  	for _, f := range sendFilters {
   219  		pf, err := c.filterToProto(f)
   220  		if err != nil {
   221  			return nil, nil, err
   222  		}
   223  		pfs = append(pfs, pf)
   224  	}
   225  	if len(pfs) == 1 {
   226  		p.Where = pfs[0]
   227  	} else if len(pfs) > 1 {
   228  		p.Where = &pb.StructuredQuery_Filter{
   229  			FilterType: &pb.StructuredQuery_Filter_CompositeFilter{&pb.StructuredQuery_CompositeFilter{
   230  				Op:      pb.StructuredQuery_CompositeFilter_AND,
   231  				Filters: pfs,
   232  			}},
   233  		}
   234  	}
   235  
   236  	if q.OrderByField != "" {
   237  		// TODO(jba): reorder filters so order-by one is first of inequalities?
   238  		// TODO(jba): see if it's OK if filter inequality direction differs from sort direction.
   239  		fref := []string{q.OrderByField}
   240  		if q.OrderByField == c.nameField {
   241  			fref[0] = "__name__"
   242  		}
   243  		var dir pb.StructuredQuery_Direction
   244  		if q.OrderAscending {
   245  			dir = pb.StructuredQuery_ASCENDING
   246  		} else {
   247  			dir = pb.StructuredQuery_DESCENDING
   248  		}
   249  		p.OrderBy = []*pb.StructuredQuery_Order{{Field: fieldRef(fref), Direction: dir}}
   250  	}
   251  
   252  	// TODO(jba): cursors (start/end)
   253  	return p, localFilters, nil
   254  }
   255  
   256  // splitFilters separates the list of query filters into those we can send to the Firestore service,
   257  // and those we must evaluate here on the client.
   258  func splitFilters(fs []driver.Filter) (sendToFirestore, evaluateLocally []driver.Filter) {
   259  	// Enforce that only one field can have an inequality.
   260  	var rangeFP []string
   261  	for _, f := range fs {
   262  		if f.Op == driver.EqualOp {
   263  			sendToFirestore = append(sendToFirestore, f)
   264  		} else {
   265  			if rangeFP == nil || driver.FieldPathsEqual(rangeFP, f.FieldPath) {
   266  				// Multiple inequality filters on the same field are OK.
   267  				rangeFP = f.FieldPath
   268  				sendToFirestore = append(sendToFirestore, f)
   269  			} else {
   270  				evaluateLocally = append(evaluateLocally, f)
   271  			}
   272  		}
   273  	}
   274  	return sendToFirestore, evaluateLocally
   275  }
   276  
   277  func (c *collection) filterToProto(f driver.Filter) (*pb.StructuredQuery_Filter, error) {
   278  	// Treat filters on the name field specially.
   279  	if c.nameField != "" && driver.FieldPathEqualsField(f.FieldPath, c.nameField) {
   280  		v := reflect.ValueOf(f.Value)
   281  		if v.Kind() != reflect.String {
   282  			return nil, gcerr.Newf(gcerr.InvalidArgument, nil,
   283  				"name field filter value %v of type %[1]T is not a string", f.Value)
   284  		}
   285  		return newFieldFilter([]string{"__name__"}, f.Op,
   286  			&pb.Value{ValueType: &pb.Value_ReferenceValue{c.collPath + "/" + v.String()}})
   287  	}
   288  	// "= nil" and "= NaN" are handled specially.
   289  	if uop, ok := unaryOpFor(f.Value); ok {
   290  		if f.Op != driver.EqualOp {
   291  			return nil, fmt.Errorf("firestore: must use '=' when comparing %v", f.Value)
   292  		}
   293  		return &pb.StructuredQuery_Filter{
   294  			FilterType: &pb.StructuredQuery_Filter_UnaryFilter{
   295  				UnaryFilter: &pb.StructuredQuery_UnaryFilter{
   296  					OperandType: &pb.StructuredQuery_UnaryFilter_Field{
   297  						Field: fieldRef(f.FieldPath),
   298  					},
   299  					Op: uop,
   300  				},
   301  			},
   302  		}, nil
   303  	}
   304  	pv, err := encodeValue(f.Value)
   305  	if err != nil {
   306  		return nil, err
   307  	}
   308  	return newFieldFilter(f.FieldPath, f.Op, pv)
   309  }
   310  
   311  func unaryOpFor(value interface{}) (pb.StructuredQuery_UnaryFilter_Operator, bool) {
   312  	switch {
   313  	case value == nil:
   314  		return pb.StructuredQuery_UnaryFilter_IS_NULL, true
   315  	case isNaN(value):
   316  		return pb.StructuredQuery_UnaryFilter_IS_NAN, true
   317  	default:
   318  		return pb.StructuredQuery_UnaryFilter_OPERATOR_UNSPECIFIED, false
   319  	}
   320  }
   321  
   322  func isNaN(x interface{}) bool {
   323  	switch x := x.(type) {
   324  	case float32:
   325  		return math.IsNaN(float64(x))
   326  	case float64:
   327  		return math.IsNaN(x)
   328  	default:
   329  		return false
   330  	}
   331  }
   332  
   333  func fieldRef(fp []string) *pb.StructuredQuery_FieldReference {
   334  	return &pb.StructuredQuery_FieldReference{FieldPath: toServiceFieldPath(fp)}
   335  }
   336  
   337  func newFieldFilter(fp []string, op string, val *pb.Value) (*pb.StructuredQuery_Filter, error) {
   338  	var fop pb.StructuredQuery_FieldFilter_Operator
   339  	switch op {
   340  	case "<":
   341  		fop = pb.StructuredQuery_FieldFilter_LESS_THAN
   342  	case "<=":
   343  		fop = pb.StructuredQuery_FieldFilter_LESS_THAN_OR_EQUAL
   344  	case ">":
   345  		fop = pb.StructuredQuery_FieldFilter_GREATER_THAN
   346  	case ">=":
   347  		fop = pb.StructuredQuery_FieldFilter_GREATER_THAN_OR_EQUAL
   348  	case driver.EqualOp:
   349  		fop = pb.StructuredQuery_FieldFilter_EQUAL
   350  	// TODO(jba): can we support array-contains portably?
   351  	// case "array-contains":
   352  	// 	fop = pb.StructuredQuery_FieldFilter_ARRAY_CONTAINS
   353  	default:
   354  		return nil, gcerr.Newf(gcerr.InvalidArgument, nil, "invalid operator: %q", op)
   355  	}
   356  	return &pb.StructuredQuery_Filter{
   357  		FilterType: &pb.StructuredQuery_Filter_FieldFilter{
   358  			FieldFilter: &pb.StructuredQuery_FieldFilter{
   359  				Field: fieldRef(fp),
   360  				Op:    fop,
   361  				Value: val,
   362  			},
   363  		},
   364  	}, nil
   365  }
   366  
   367  func (c *collection) QueryPlan(q *driver.Query) (string, error) {
   368  	return "unknown", nil
   369  }