github.com/cornelk/go-cloud@v0.17.1/docstore/awsdynamodb/codec.go (about)

     1  // Copyright 2019 The Go Cloud Development Kit Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package awsdynamodb
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"reflect"
    21  	"strconv"
    22  	"time"
    23  
    24  	dyn "github.com/aws/aws-sdk-go/service/dynamodb"
    25  	"github.com/cornelk/go-cloud/docstore/driver"
    26  )
    27  
    28  var nullValue = new(dyn.AttributeValue).SetNULL(true)
    29  
    30  type encoder struct {
    31  	av *dyn.AttributeValue
    32  }
    33  
    34  func (e *encoder) EncodeNil()            { e.av = nullValue }
    35  func (e *encoder) EncodeBool(x bool)     { e.av = new(dyn.AttributeValue).SetBOOL(x) }
    36  func (e *encoder) EncodeInt(x int64)     { e.av = new(dyn.AttributeValue).SetN(strconv.FormatInt(x, 10)) }
    37  func (e *encoder) EncodeUint(x uint64)   { e.av = new(dyn.AttributeValue).SetN(strconv.FormatUint(x, 10)) }
    38  func (e *encoder) EncodeBytes(x []byte)  { e.av = new(dyn.AttributeValue).SetB(x) }
    39  func (e *encoder) EncodeFloat(x float64) { e.av = encodeFloat(x) }
    40  
    41  func (e *encoder) ListIndex(int) { panic("impossible") }
    42  func (e *encoder) MapKey(string) { panic("impossible") }
    43  
    44  func (e *encoder) EncodeString(x string) {
    45  	if len(x) == 0 {
    46  		e.av = nullValue
    47  	} else {
    48  		e.av = new(dyn.AttributeValue).SetS(x)
    49  	}
    50  }
    51  
    52  func (e *encoder) EncodeComplex(x complex128) {
    53  	e.av = new(dyn.AttributeValue).SetL([]*dyn.AttributeValue{encodeFloat(real(x)), encodeFloat(imag(x))})
    54  }
    55  
    56  func (e *encoder) EncodeList(n int) driver.Encoder {
    57  	s := make([]*dyn.AttributeValue, n)
    58  	e.av = new(dyn.AttributeValue).SetL(s)
    59  	return &listEncoder{s: s}
    60  }
    61  
    62  func (e *encoder) EncodeMap(n int) driver.Encoder {
    63  	m := make(map[string]*dyn.AttributeValue, n)
    64  	e.av = new(dyn.AttributeValue).SetM(m)
    65  	return &mapEncoder{m: m}
    66  }
    67  
    68  var typeOfGoTime = reflect.TypeOf(time.Time{})
    69  
    70  // EncodeSpecial encodes time.Time specially.
    71  func (e *encoder) EncodeSpecial(v reflect.Value) (bool, error) {
    72  	switch v.Type() {
    73  	case typeOfGoTime:
    74  		ts := v.Interface().(time.Time).Format(time.RFC3339Nano)
    75  		e.EncodeString(ts)
    76  	default:
    77  		return false, nil
    78  	}
    79  	return true, nil
    80  }
    81  
    82  type listEncoder struct {
    83  	s []*dyn.AttributeValue
    84  	encoder
    85  }
    86  
    87  func (e *listEncoder) ListIndex(i int) { e.s[i] = e.av }
    88  
    89  type mapEncoder struct {
    90  	m map[string]*dyn.AttributeValue
    91  	encoder
    92  }
    93  
    94  func (e *mapEncoder) MapKey(k string) { e.m[k] = e.av }
    95  
    96  func encodeDoc(doc driver.Document) (*dyn.AttributeValue, error) {
    97  	var e encoder
    98  	if err := doc.Encode(&e); err != nil {
    99  		return nil, err
   100  	}
   101  	return e.av, nil
   102  }
   103  
   104  // Encode the key fields of the given document into a map AttributeValue.
   105  // pkey and skey are the names of the partition key field and the sort key field.
   106  // pkey must always be non-empty, but skey may be empty if the collection has no sort key.
   107  func encodeDocKeyFields(doc driver.Document, pkey, skey string) (*dyn.AttributeValue, error) {
   108  	m := map[string]*dyn.AttributeValue{}
   109  
   110  	set := func(fieldName string) error {
   111  		fieldVal, err := doc.GetField(fieldName)
   112  		if err != nil {
   113  			return err
   114  		}
   115  		attrVal, err := encodeValue(fieldVal)
   116  		if err != nil {
   117  			return err
   118  		}
   119  		m[fieldName] = attrVal
   120  		return nil
   121  	}
   122  
   123  	if err := set(pkey); err != nil {
   124  		return nil, err
   125  	}
   126  	if skey != "" {
   127  		if err := set(skey); err != nil {
   128  			return nil, err
   129  		}
   130  	}
   131  	return new(dyn.AttributeValue).SetM(m), nil
   132  }
   133  
   134  func encodeValue(v interface{}) (*dyn.AttributeValue, error) {
   135  	var e encoder
   136  	if err := driver.Encode(reflect.ValueOf(v), &e); err != nil {
   137  		return nil, err
   138  	}
   139  	return e.av, nil
   140  }
   141  
   142  func encodeFloat(f float64) *dyn.AttributeValue {
   143  	return new(dyn.AttributeValue).SetN(strconv.FormatFloat(f, 'f', -1, 64))
   144  }
   145  
   146  ////////////////////////////////////////////////////////////////
   147  
   148  func decodeDoc(item *dyn.AttributeValue, doc driver.Document) error {
   149  	return doc.Decode(decoder{av: item})
   150  }
   151  
   152  type decoder struct {
   153  	av *dyn.AttributeValue
   154  }
   155  
   156  func (d decoder) String() string {
   157  	return d.av.String()
   158  }
   159  
   160  func (d decoder) AsBool() (bool, bool) {
   161  	if d.av.BOOL == nil {
   162  		return false, false
   163  	}
   164  	return *d.av.BOOL, true
   165  }
   166  
   167  func (d decoder) AsNull() bool {
   168  	return d.av.NULL != nil
   169  }
   170  
   171  func (d decoder) AsString() (string, bool) {
   172  	// Empty string is represented by NULL.
   173  	if d.av.NULL != nil {
   174  		return "", true
   175  	}
   176  	if d.av.S == nil {
   177  		return "", false
   178  	}
   179  	return *d.av.S, true
   180  }
   181  
   182  func (d decoder) AsInt() (int64, bool) {
   183  	if d.av.N == nil {
   184  		return 0, false
   185  	}
   186  	i, err := strconv.ParseInt(*d.av.N, 10, 64)
   187  	if err != nil {
   188  		return 0, false
   189  	}
   190  	return i, true
   191  }
   192  
   193  func (d decoder) AsUint() (uint64, bool) {
   194  	if d.av.N == nil {
   195  		return 0, false
   196  	}
   197  	u, err := strconv.ParseUint(*d.av.N, 10, 64)
   198  	if err != nil {
   199  		return 0, false
   200  	}
   201  	return u, true
   202  }
   203  
   204  func (d decoder) AsFloat() (float64, bool) {
   205  	if d.av.N == nil {
   206  		return 0, false
   207  	}
   208  	f, err := strconv.ParseFloat(*d.av.N, 64)
   209  	if err != nil {
   210  		return 0, false
   211  	}
   212  	return f, true
   213  
   214  }
   215  
   216  func (d decoder) AsComplex() (complex128, bool) {
   217  	if d.av.L == nil {
   218  		return 0, false
   219  	}
   220  	if len(d.av.L) != 2 {
   221  		return 0, false
   222  	}
   223  	r, ok := decoder{d.av.L[0]}.AsFloat()
   224  	if !ok {
   225  		return 0, false
   226  	}
   227  	i, ok := decoder{d.av.L[1]}.AsFloat()
   228  	if !ok {
   229  		return 0, false
   230  	}
   231  	return complex(r, i), true
   232  }
   233  
   234  func (d decoder) AsBytes() ([]byte, bool) {
   235  	if d.av.B == nil {
   236  		return nil, false
   237  	}
   238  	return d.av.B, true
   239  }
   240  
   241  func (d decoder) ListLen() (int, bool) {
   242  	if d.av.L == nil {
   243  		return 0, false
   244  	}
   245  	return len(d.av.L), true
   246  }
   247  
   248  func (d decoder) DecodeList(f func(i int, vd driver.Decoder) bool) {
   249  	for i, el := range d.av.L {
   250  		if !f(i, decoder{el}) {
   251  			break
   252  		}
   253  	}
   254  }
   255  
   256  func (d decoder) MapLen() (int, bool) {
   257  	if d.av.M == nil {
   258  		return 0, false
   259  	}
   260  	return len(d.av.M), true
   261  }
   262  
   263  func (d decoder) DecodeMap(f func(key string, vd driver.Decoder, exactMatch bool) bool) {
   264  	for k, av := range d.av.M {
   265  		if !f(k, decoder{av}, true) {
   266  			break
   267  		}
   268  	}
   269  }
   270  
   271  func (d decoder) AsInterface() (interface{}, error) {
   272  	return toGoValue(d.av)
   273  }
   274  
   275  func toGoValue(av *dyn.AttributeValue) (interface{}, error) {
   276  	switch {
   277  	case av.NULL != nil:
   278  		return nil, nil
   279  	case av.BOOL != nil:
   280  		return *av.BOOL, nil
   281  	case av.N != nil:
   282  		f, err := strconv.ParseFloat(*av.N, 64)
   283  		if err != nil {
   284  			return nil, err
   285  		}
   286  		i := int64(f)
   287  		if float64(i) == f {
   288  			return i, nil
   289  		}
   290  		u := uint64(f)
   291  		if float64(u) == f {
   292  			return u, nil
   293  		}
   294  		return f, nil
   295  
   296  	case av.B != nil:
   297  		return av.B, nil
   298  	case av.S != nil:
   299  		return *av.S, nil
   300  
   301  	case av.L != nil:
   302  		s := make([]interface{}, len(av.L))
   303  		for i, v := range av.L {
   304  			x, err := toGoValue(v)
   305  			if err != nil {
   306  				return nil, err
   307  			}
   308  			s[i] = x
   309  		}
   310  		return s, nil
   311  
   312  	case av.M != nil:
   313  		m := make(map[string]interface{}, len(av.M))
   314  		for k, v := range av.M {
   315  			x, err := toGoValue(v)
   316  			if err != nil {
   317  				return nil, err
   318  			}
   319  			m[k] = x
   320  		}
   321  		return m, nil
   322  
   323  	default:
   324  		return nil, fmt.Errorf("awsdynamodb: AttributeValue %s not supported", av)
   325  	}
   326  }
   327  
   328  func (d decoder) AsSpecial(v reflect.Value) (bool, interface{}, error) {
   329  	unsupportedTypes := `unsupported type, the docstore driver for DynamoDB does
   330  	not decode DynamoDB set types, such as string set, number set and binary set`
   331  	if d.av.SS != nil || d.av.NS != nil || d.av.BS != nil {
   332  		return true, nil, errors.New(unsupportedTypes)
   333  	}
   334  	switch v.Type() {
   335  	case typeOfGoTime:
   336  		if d.av.S == nil {
   337  			return false, nil, errors.New("expected string field for time.Time")
   338  		}
   339  		t, err := time.Parse(time.RFC3339Nano, *d.av.S)
   340  		return true, t, err
   341  	}
   342  	return false, nil, nil
   343  }