
     1  // Copyright 2019 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package gcpfirestore
    17  // Encoding and decoding between supported docstore types and Firestore protos.
    19  import (
    20  	"errors"
    21  	"fmt"
    22  	"path"
    23  	"reflect"
    24  	"time"
    26  	""
    27  	pb ""
    28  	""
    29  	tspb ""
    30  )
    32  // encodeDoc encodes a driver.Document into Firestore's representation.
    33  // A Firestore document (*pb.Document) is just a Go map from strings to *pb.Values.
    34  func encodeDoc(doc driver.Document, nameField string) (*pb.Document, error) {
    35  	var e encoder
    36  	if err := doc.Encode(&e); err != nil {
    37  		return nil, err
    38  	}
    39  	fields := e.pv.GetMapValue().Fields
    40  	// Do not put the name field in the document itself.
    41  	if nameField != "" {
    42  		delete(fields, nameField)
    43  	}
    44  	return &pb.Document{Fields: fields}, nil
    45  }
    47  // encodeValue encodes a Go value as a Firestore Value.
    48  // The Firestore proto definition for Value is a oneof of various types,
    49  // including basic types like string as well as lists and maps.
    50  func encodeValue(x interface{}) (*pb.Value, error) {
    51  	var e encoder
    52  	if err := driver.Encode(reflect.ValueOf(x), &e); err != nil {
    53  		return nil, err
    54  	}
    55  	return e.pv, nil
    56  }
// encoder implements driver.Encoder. Its job is to encode a single Firestore value.
// Every Encode* method overwrites pv, so after a call pv holds the result of
// the most recent encoding.
type encoder struct {
	// pv is the encoded value; it is nil until an Encode* method has run.
	pv *pb.Value
}
// nullValue is the Firestore null value. It is a single shared instance:
// EncodeNil, and EncodeSpecial for nil pointers, store this same pointer
// instead of building a new value each time.
var nullValue = &pb.Value{ValueType: &pb.Value_NullValue{}}
    65  func (e *encoder) EncodeNil()            { e.pv = nullValue }
    66  func (e *encoder) EncodeBool(x bool)     { e.pv = &pb.Value{ValueType: &pb.Value_BooleanValue{x}} }
    67  func (e *encoder) EncodeInt(x int64)     { e.pv = &pb.Value{ValueType: &pb.Value_IntegerValue{x}} }
    68  func (e *encoder) EncodeUint(x uint64)   { e.pv = &pb.Value{ValueType: &pb.Value_IntegerValue{int64(x)}} }
    69  func (e *encoder) EncodeBytes(x []byte)  { e.pv = &pb.Value{ValueType: &pb.Value_BytesValue{x}} }
    70  func (e *encoder) EncodeFloat(x float64) { e.pv = floatval(x) }
    71  func (e *encoder) EncodeString(x string) { e.pv = &pb.Value{ValueType: &pb.Value_StringValue{x}} }
    73  func (e *encoder) ListIndex(int) { panic("impossible") }
    74  func (e *encoder) MapKey(string) { panic("impossible") }
    76  func (e *encoder) EncodeList(n int) driver.Encoder {
    77  	s := make([]*pb.Value, n)
    78  	e.pv = &pb.Value{ValueType: &pb.Value_ArrayValue{&pb.ArrayValue{Values: s}}}
    79  	return &listEncoder{s: s}
    80  }
    82  func (e *encoder) EncodeMap(n int) driver.Encoder {
    83  	m := make(map[string]*pb.Value, n)
    84  	e.pv = &pb.Value{ValueType: &pb.Value_MapValue{&pb.MapValue{Fields: m}}}
    85  	return &mapEncoder{m: m}
    86  }
// reflect.Types that EncodeSpecial and AsSpecial compare against to recognize
// the specially handled Go types: the time.Time struct value, and the
// *tspb.Timestamp and *latlng.LatLng pointer types.
var (
	typeOfGoTime         = reflect.TypeOf(time.Time{})
	typeOfProtoTimestamp = reflect.TypeOf((*tspb.Timestamp)(nil))
	typeOfLatLng         = reflect.TypeOf((*latlng.LatLng)(nil))
)
    94  // Encode time.Time, latlng.LatLng, and ts.Timestamp specially, because the Go Firestore
    95  // client does.
    96  func (e *encoder) EncodeSpecial(v reflect.Value) (bool, error) {
    97  	switch v.Type() {
    98  	case typeOfGoTime:
    99  		ts := tspb.New(v.Interface().(time.Time))
   100  		e.pv = &pb.Value{ValueType: &pb.Value_TimestampValue{ts}}
   101  		return true, nil
   102  	case typeOfProtoTimestamp:
   103  		if v.IsNil() {
   104  			e.pv = nullValue
   105  		} else {
   106  			e.pv = &pb.Value{ValueType: &pb.Value_TimestampValue{v.Interface().(*tspb.Timestamp)}}
   107  		}
   108  		return true, nil
   109  	case typeOfLatLng:
   110  		if v.IsNil() {
   111  			e.pv = nullValue
   112  		} else {
   113  			e.pv = &pb.Value{ValueType: &pb.Value_GeoPointValue{v.Interface().(*latlng.LatLng)}}
   114  		}
   115  		return true, nil
   116  	default:
   117  		return false, nil
   118  	}
   119  }
   121  type listEncoder struct {
   122  	s []*pb.Value
   123  	encoder
   124  }
   126  func (e *listEncoder) ListIndex(i int) { e.s[i] = e.pv }
   128  type mapEncoder struct {
   129  	m map[string]*pb.Value
   130  	encoder
   131  }
   133  func (e *mapEncoder) MapKey(k string) { e.m[k] = e.pv }
   135  func floatval(x float64) *pb.Value { return &pb.Value{ValueType: &pb.Value_DoubleValue{x}} }
   137  ////////////////////////////////////////////////////////////////
   139  // decodeDoc decodes a Firestore document into a driver.Document.
   140  func decodeDoc(pdoc *pb.Document, ddoc driver.Document, nameField, revField string) error {
   141  	if pdoc.Fields == nil {
   142  		pdoc.Fields = map[string]*pb.Value{}
   143  	}
   144  	if nameField != "" {
   145  		pdoc.Fields[nameField] = &pb.Value{ValueType: &pb.Value_StringValue{StringValue: path.Base(pdoc.Name)}}
   146  	}
   147  	mv := &pb.Value{ValueType: &pb.Value_MapValue{&pb.MapValue{Fields: pdoc.Fields}}}
   148  	if err := ddoc.Decode(decoder{mv}); err != nil {
   149  		return err
   150  	}
   151  	// Set the revision field in the document, if it exists, to the update time.
   152  	if ddoc.HasField(revField) && pdoc.UpdateTime != nil {
   153  		return ddoc.SetField(revField, pdoc.UpdateTime)
   154  	}
   155  	return nil
   156  }
// decoder wraps a single Firestore value and exposes it through its As*,
// ListLen/DecodeList and MapLen/DecodeMap methods. decodeDoc passes one to
// driver.Document.Decode.
type decoder struct {
	// pv is the Firestore value being decoded.
	pv *pb.Value
}
   162  func (d decoder) String() string { // for debugging
   163  	return fmt.Sprint(d.pv)
   164  }
   166  func (d decoder) AsNull() bool {
   167  	_, ok := d.pv.ValueType.(*pb.Value_NullValue)
   168  	return ok
   169  }
   171  func (d decoder) AsBool() (bool, bool) {
   172  	if b, ok := d.pv.ValueType.(*pb.Value_BooleanValue); ok {
   173  		return b.BooleanValue, true
   174  	}
   175  	return false, false
   176  }
   178  func (d decoder) AsString() (string, bool) {
   179  	if s, ok := d.pv.ValueType.(*pb.Value_StringValue); ok {
   180  		return s.StringValue, true
   181  	}
   182  	return "", false
   183  }
   185  func (d decoder) AsInt() (int64, bool) {
   186  	if i, ok := d.pv.ValueType.(*pb.Value_IntegerValue); ok {
   187  		return i.IntegerValue, true
   188  	}
   189  	return 0, false
   190  }
   192  func (d decoder) AsUint() (uint64, bool) {
   193  	if i, ok := d.pv.ValueType.(*pb.Value_IntegerValue); ok {
   194  		return uint64(i.IntegerValue), true
   195  	}
   196  	return 0, false
   197  }
   199  func (d decoder) AsFloat() (float64, bool) {
   200  	if f, ok := d.pv.ValueType.(*pb.Value_DoubleValue); ok {
   201  		return f.DoubleValue, true
   202  	}
   203  	return 0, false
   204  }
   206  func (d decoder) AsBytes() ([]byte, bool) {
   207  	if bs, ok := d.pv.ValueType.(*pb.Value_BytesValue); ok {
   208  		return bs.BytesValue, true
   209  	}
   210  	return nil, false
   211  }
   213  // AsInterface decodes the value in d into the most appropriate Go type.
   214  func (d decoder) AsInterface() (interface{}, error) {
   215  	return decodeValue(d.pv)
   216  }
   218  func decodeValue(v *pb.Value) (interface{}, error) {
   219  	switch v := v.ValueType.(type) {
   220  	case *pb.Value_NullValue:
   221  		return nil, nil
   222  	case *pb.Value_BooleanValue:
   223  		return v.BooleanValue, nil
   224  	case *pb.Value_IntegerValue:
   225  		return v.IntegerValue, nil
   226  	case *pb.Value_DoubleValue:
   227  		return v.DoubleValue, nil
   228  	case *pb.Value_StringValue:
   229  		return v.StringValue, nil
   230  	case *pb.Value_BytesValue:
   231  		return v.BytesValue, nil
   232  	case *pb.Value_TimestampValue:
   233  		// Return TimestampValue as time.Time.
   234  		return v.TimestampValue.AsTime(), nil
   235  	case *pb.Value_ReferenceValue:
   236  		// TODO(jba): support references
   237  		return nil, errors.New("references are not currently supported")
   238  	case *pb.Value_GeoPointValue:
   239  		// Return GeoPointValue as *latlng.LatLng.
   240  		return v.GeoPointValue, nil
   241  	case *pb.Value_ArrayValue:
   242  		s := make([]interface{}, len(v.ArrayValue.Values))
   243  		for i, pv := range v.ArrayValue.Values {
   244  			e, err := decodeValue(pv)
   245  			if err != nil {
   246  				return nil, err
   247  			}
   248  			s[i] = e
   249  		}
   250  		return s, nil
   251  	case *pb.Value_MapValue:
   252  		m := make(map[string]interface{}, len(v.MapValue.Fields))
   253  		for k, pv := range v.MapValue.Fields {
   254  			e, err := decodeValue(pv)
   255  			if err != nil {
   256  				return nil, err
   257  			}
   258  			m[k] = e
   259  		}
   260  		return m, nil
   261  	}
   262  	return nil, fmt.Errorf("unknown firestore value type %T", v)
   263  }
   265  func (d decoder) ListLen() (int, bool) {
   266  	a := d.pv.GetArrayValue()
   267  	if a == nil {
   268  		return 0, false
   269  	}
   270  	return len(a.Values), true
   271  }
   273  func (d decoder) DecodeList(f func(int, driver.Decoder) bool) {
   274  	for i, e := range d.pv.GetArrayValue().Values {
   275  		if !f(i, decoder{e}) {
   276  			return
   277  		}
   278  	}
   279  }
   280  func (d decoder) MapLen() (int, bool) {
   281  	m := d.pv.GetMapValue()
   282  	if m == nil {
   283  		return 0, false
   284  	}
   285  	return len(m.Fields), true
   286  }
   287  func (d decoder) DecodeMap(f func(string, driver.Decoder, bool) bool) {
   288  	for k, v := range d.pv.GetMapValue().Fields {
   289  		if !f(k, decoder{v}, true) {
   290  			return
   291  		}
   292  	}
   293  }
   295  func (d decoder) AsSpecial(v reflect.Value) (bool, interface{}, error) {
   296  	switch v.Type() {
   297  	case typeOfGoTime:
   298  		if ts, ok := d.pv.ValueType.(*pb.Value_TimestampValue); ok {
   299  			if ts.TimestampValue == nil {
   300  				return true, time.Time{}, nil
   301  			}
   302  			return true, ts.TimestampValue.AsTime(), nil
   303  		}
   304  		return true, nil, fmt.Errorf("expected TimestampValue for time.Time, got %+v", d.pv.ValueType)
   305  	case typeOfProtoTimestamp:
   306  		if ts, ok := d.pv.ValueType.(*pb.Value_TimestampValue); ok {
   307  			return true, ts.TimestampValue, nil
   308  		}
   309  		return true, nil, fmt.Errorf("expected TimestampValue for *ts.Timestamp, got %+v", d.pv.ValueType)
   311  	case typeOfLatLng:
   312  		if ll, ok := d.pv.ValueType.(*pb.Value_GeoPointValue); ok {
   313  			return true, ll.GeoPointValue, nil
   314  		}
   315  		return true, nil, fmt.Errorf("expected GeoPointValue for *latlng.LatLng, got %+v", d.pv.ValueType)
   317  	default:
   318  		return false, nil, nil
   319  	}
   320  }