go.mercari.io/datastore@v1.8.2/load.go (about)

     1  // Copyright 2014 Google LLC
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package datastore
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"reflect"
    21  	"strings"
    22  	"time"
    23  
    24  	"go.mercari.io/datastore/internal/c/fields"
    25  )
    26  
    27  var (
    28  	typeOfByteSlice = reflect.TypeOf([]byte(nil))
    29  	typeOfTime      = reflect.TypeOf(time.Time{})
    30  	typeOfGeoPoint  = reflect.TypeOf(GeoPoint{})
    31  	typeOfKey       = reflect.TypeOf((*Key)(nil)).Elem()
    32  )
    33  
    34  // typeMismatchReason returns a string explaining why the property p could not
    35  // be stored in an entity field of type v.Type().
    36  func typeMismatchReason(p Property, v reflect.Value) string {
    37  	entityType := "empty"
    38  	switch p.Value.(type) {
    39  	case int64:
    40  		entityType = "int"
    41  	case bool:
    42  		entityType = "bool"
    43  	case string:
    44  		entityType = "string"
    45  	case float64:
    46  		entityType = "float"
    47  	case Key:
    48  		entityType = "datastore.Key"
    49  	case *Entity:
    50  		entityType = "*datastore.Entity"
    51  	case GeoPoint:
    52  		entityType = "GeoPoint"
    53  	case time.Time:
    54  		entityType = "time.Time"
    55  	case []byte:
    56  		entityType = "[]byte"
    57  	}
    58  
    59  	return fmt.Sprintf("type mismatch: %s versus %v", entityType, v.Type())
    60  }
    61  
    62  func overflowReason(x interface{}, v reflect.Value) string {
    63  	return fmt.Sprintf("value %v overflows struct field of type %v", x, v.Type())
    64  }
    65  
    66  type propertyLoader struct {
    67  	// m holds the number of times a substruct field like "Foo.Bar.Baz" has
    68  	// been seen so far. The map is constructed lazily.
    69  	m map[string]int
    70  }
    71  
    72  func (l *propertyLoader) load(ctx context.Context, codec fields.List, structValue reflect.Value, p Property, prev map[string]struct{}) string {
    73  	sl, ok := p.Value.([]interface{})
    74  	if !ok {
    75  		return l.loadOneElement(ctx, codec, structValue, p, prev)
    76  	}
    77  	for _, val := range sl {
    78  		p.Value = val
    79  		if errStr := l.loadOneElement(ctx, codec, structValue, p, prev); errStr != "" {
    80  			return errStr
    81  		}
    82  	}
    83  	return ""
    84  }
    85  
// loadOneElement loads the value of Property p into structValue based on the provided
// codec. codec is used to find the field in structValue into which p should be loaded.
// prev is the set of property names already seen for structValue.
//
// p.Name may be a dotted path such as "A.B.C". Each pass of the outer loop
// matches the longest prefix of the remaining path against codec and
// initialises that field. The property is then either handed to the field's
// PropertyTranslator or PropertyLoadSaver, which finishes loading. Otherwise
// the loop descends into the field (a struct, a pointer to a struct, or the
// current element of a slice) using that type's codec. Once the whole path
// has been consumed, the value itself is stored with setVal.
//
// It returns "" on success and otherwise the reason p could not be loaded.
func (l *propertyLoader) loadOneElement(ctx context.Context, codec fields.List, structValue reflect.Value, p Property, prev map[string]struct{}) string {
	// sliceOk is set once the path has passed through a slice field; repeated
	// values for the same property name are then expected rather than an error.
	var sliceOk bool
	// sliceIndex is the element of the most recently visited slice field that
	// this value belongs to.
	var sliceIndex int
	// v is the field most recently resolved from the path.
	var v reflect.Value

	name := p.Name
	fieldNames := strings.Split(name, ".")

	for len(fieldNames) > 0 {
		var field *fields.Field

		// Start by trying to find a field with name. If none found,
		// cut off the last field (delimited by ".") and find its parent
		// in the codec.
		// eg. for name "A.B.C.D", split off "A.B.C" and try to
		// find a field in the codec with this name.
		// Loop again with "A.B", etc.
		for i := len(fieldNames); i > 0; i-- {
			parent := strings.Join(fieldNames[:i], ".")
			field = codec.Match(parent)
			if field != nil {
				fieldNames = fieldNames[i:]
				break
			}
		}

		// If we never found a matching field in the codec, return
		// error message.
		if field == nil {
			return "no such struct field"
		}

		// initField allocates any nil pointers on the way to the field.
		v = initField(structValue, field.Index)
		if !v.IsValid() {
			return "no such struct field"
		}
		if !v.CanSet() {
			return "cannot set struct field"
		}

		// If ptFieldLoad handles the property through a PropertyTranslator,
		// loading is finished.
		ok, err := ptFieldLoad(ctx, v, p, fieldNames)
		if err != nil {
			return err.Error()
		}
		if ok {
			return ""
		}

		// If field implements PLS, we delegate loading to the PLS's Load early,
		// and stop iterating through fields.
		ok, err = plsFieldLoad(ctx, v, p, fieldNames)
		if err != nil {
			return err.Error()
		}
		if ok {
			return ""
		}

		// Pointer-to-struct field: the rest of the path is resolved against
		// the pointee's codec.
		if field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct {
			codec, err = structCache.Fields(field.Type.Elem())
			if err != nil {
				return err.Error()
			}

			// Init value if its nil
			if v.IsNil() {
				v.Set(reflect.New(field.Type.Elem()))
			}
			structValue = v.Elem()
		}

		// Struct field: the rest of the path is resolved against its codec.
		if field.Type.Kind() == reflect.Struct {
			codec, err = structCache.Fields(field.Type)
			if err != nil {
				return err.Error()
			}
			structValue = v
		}

		// If the element is a slice, we need to accommodate it.
		if v.Kind() == reflect.Slice && v.Type() != typeOfByteSlice {
			// l.m counts the values already loaded for this property name.
			// That count selects the slice element, and the slice is grown
			// with zero values until that element exists.
			if l.m == nil {
				l.m = make(map[string]int)
			}
			sliceIndex = l.m[p.Name]
			l.m[p.Name] = sliceIndex + 1
			for v.Len() <= sliceIndex {
				v.Set(reflect.Append(v, reflect.New(v.Type().Elem()).Elem()))
			}
			structValue = v.Index(sliceIndex)

			ok, err := ptFieldLoad(ctx, structValue, p, fieldNames)
			if err != nil {
				return err.Error()
			}
			if ok {
				return ""
			}

			// If structValue implements PLS, we delegate loading to the PLS's
			// Load early, and stop iterating through fields.
			ok, err = plsFieldLoad(ctx, structValue, p, fieldNames)
			if err != nil {
				return err.Error()
			}
			if ok {
				return ""
			}

			if structValue.Type().Kind() == reflect.Struct {
				codec, err = structCache.Fields(structValue.Type())
				if err != nil {
					return err.Error()
				}
			}
			sliceOk = true
		}

		// NOTE(review): when the field is neither a struct, a pointer to a
		// struct, nor a slice of structs, codec is left unchanged. Any
		// remaining path segments are then matched against the current codec
		// on the next pass. Confirm this is intended.
	}

	// For a slice field of anything other than bytes, the value is decoded into
	// a fresh element. It is stored at sliceIndex only if setVal succeeds; on
	// failure the whole slice is reset to its zero value.
	var slice reflect.Value
	if v.Kind() == reflect.Slice && v.Type().Elem().Kind() != reflect.Uint8 {
		slice = v
		v = reflect.New(v.Type().Elem()).Elem()
	} else if _, ok := prev[p.Name]; ok && !sliceOk {
		// Zero the field back out that was set previously, turns out
		// it's a slice and we don't know what to do with it
		v.Set(reflect.Zero(v.Type()))
		return "multiple-valued property requires a slice field type"
	}

	prev[p.Name] = struct{}{}

	if errReason := setVal(ctx, v, p); errReason != "" {
		// Set the slice back to its zero value.
		if slice.IsValid() {
			slice.Set(reflect.Zero(slice.Type()))
		}
		return errReason
	}

	if slice.IsValid() {
		slice.Index(sliceIndex).Set(v)
	}

	return ""
}
   235  
   236  // plsFieldLoad first tries to converts v's value to a PLS, then v's addressed
   237  // value to a PLS. If neither succeeds, plsFieldLoad returns false for first return
   238  // value. Otherwise, the first return value will be true.
   239  // If v is successfully converted to a PLS, plsFieldLoad will then try to Load
   240  // the property p into v (by way of the PLS's Load method).
   241  //
   242  // If the field v has been flattened, the Property's name must be altered
   243  // before calling Load to reflect the field v.
   244  // For example, if our original field name was "A.B.C.D",
   245  // and at this point in iteration we had initialized the field
   246  // corresponding to "A" and have moved into the struct, so that now
   247  // v corresponds to the field named "B", then we want to let the
   248  // PLS handle this field (B)'s subfields ("C", "D"),
   249  // so we send the property to the PLS's Load, renamed to "C.D".
   250  //
   251  // If subfields are present, the field v has been flattened.
   252  func plsFieldLoad(ctx context.Context, v reflect.Value, p Property, subfields []string) (ok bool, err error) {
   253  	vpls, err := plsForLoad(v)
   254  	if err != nil {
   255  		return false, err
   256  	}
   257  
   258  	if vpls == nil {
   259  		return false, nil
   260  	}
   261  
   262  	// If Entity, load properties as well as key.
   263  	if e, ok := p.Value.(*Entity); ok {
   264  		err = loadEntity(ctx, vpls, e)
   265  		return true, err
   266  	}
   267  
   268  	// If flattened, we must alter the property's name to reflect
   269  	// the field v.
   270  	if len(subfields) > 0 {
   271  		p.Name = strings.Join(subfields, ".")
   272  	}
   273  
   274  	return true, vpls.Load(ctx, []Property{p})
   275  }
   276  
   277  // ptFieldLoad try to load value by PropertyTranslator.
   278  // this function is peculiar to mercari/datastore.
   279  func ptFieldLoad(ctx context.Context, v reflect.Value, p Property, subfields []string) (ok bool, err error) {
   280  	vpt, err := ptForLoad(v)
   281  	if err != nil {
   282  		return false, err
   283  	}
   284  
   285  	if vpt == nil {
   286  		return false, nil
   287  	}
   288  
   289  	// If Entity, load properties as well as key.
   290  	if e, ok := p.Value.(*Entity); ok {
   291  		err = loadEntity(ctx, vpt, e)
   292  		return true, err
   293  	}
   294  
   295  	// If flattened, we must alter the property's name to reflect
   296  	// the field v.
   297  	if len(subfields) > 0 {
   298  		p.Name = strings.Join(subfields, ".")
   299  	}
   300  
   301  	dst, err := vpt.FromPropertyValue(ctx, p)
   302  	if err != nil {
   303  		return false, err
   304  	}
   305  
   306  	v.Set(reflect.ValueOf(dst))
   307  
   308  	return true, nil
   309  }
   310  
   311  // setVal sets 'v' to the value of the Property 'p'.
   312  func setVal(ctx context.Context, v reflect.Value, p Property) (s string) {
   313  	pValue := p.Value
   314  	switch v.Kind() {
   315  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   316  		x, ok := pValue.(int64)
   317  		if !ok && pValue != nil {
   318  			return typeMismatchReason(p, v)
   319  		}
   320  		if v.OverflowInt(x) {
   321  			return overflowReason(x, v)
   322  		}
   323  		v.SetInt(x)
   324  	case reflect.Bool:
   325  		x, ok := pValue.(bool)
   326  		if !ok && pValue != nil {
   327  			return typeMismatchReason(p, v)
   328  		}
   329  		v.SetBool(x)
   330  	case reflect.String:
   331  		x, ok := pValue.(string)
   332  		if !ok && pValue != nil {
   333  			return typeMismatchReason(p, v)
   334  		}
   335  		v.SetString(x)
   336  	case reflect.Float32, reflect.Float64:
   337  		x, ok := pValue.(float64)
   338  		if !ok && pValue != nil {
   339  			return typeMismatchReason(p, v)
   340  		}
   341  		if v.OverflowFloat(x) {
   342  			return overflowReason(x, v)
   343  		}
   344  		v.SetFloat(x)
   345  	case reflect.Ptr:
   346  		// v must be a pointer to either a Key, an Entity, or one of the supported basic types.
   347  		if v.Type() != typeOfKey && v.Type().Elem().Kind() != reflect.Struct && !isValidPointerType(v.Type().Elem()) {
   348  			return typeMismatchReason(p, v)
   349  		}
   350  
   351  		if pValue == nil {
   352  			// If v is populated already, set it to nil.
   353  			if !v.IsNil() {
   354  				v.Set(reflect.New(v.Type()).Elem())
   355  			}
   356  			return ""
   357  		}
   358  
   359  		//if x, ok := pValue.(Key); ok {
   360  		//	if _, ok := v.Interface().(Key); !ok {
   361  		//		return typeMismatchReason(p, v)
   362  		//	}
   363  		//	v.Set(reflect.ValueOf(x))
   364  		//	return ""
   365  		//}
   366  		if v.IsNil() {
   367  			v.Set(reflect.New(v.Type().Elem()))
   368  		}
   369  		switch x := pValue.(type) {
   370  		case *Entity:
   371  			err := loadEntity(ctx, v.Interface(), x)
   372  			if err != nil {
   373  				return err.Error()
   374  			}
   375  		case int64:
   376  			if v.Elem().OverflowInt(x) {
   377  				return overflowReason(x, v.Elem())
   378  			}
   379  			v.Elem().SetInt(x)
   380  		case float64:
   381  			if v.Elem().OverflowFloat(x) {
   382  				return overflowReason(x, v.Elem())
   383  			}
   384  			v.Elem().SetFloat(x)
   385  		case bool:
   386  			v.Elem().SetBool(x)
   387  		case string:
   388  			v.Elem().SetString(x)
   389  		case GeoPoint, time.Time:
   390  			v.Elem().Set(reflect.ValueOf(x))
   391  		default:
   392  			return typeMismatchReason(p, v)
   393  		}
   394  	case reflect.Interface:
   395  		switch v.Type() {
   396  		case typeOfKey:
   397  			x, ok := pValue.(Key)
   398  			if !ok && pValue != nil {
   399  				return typeMismatchReason(p, v)
   400  			}
   401  			if x == nil {
   402  				if !v.IsNil() {
   403  					v.Set(reflect.ValueOf(nil))
   404  				}
   405  			} else {
   406  				v.Set(reflect.ValueOf(x))
   407  			}
   408  		default:
   409  			return typeMismatchReason(p, v)
   410  		}
   411  	case reflect.Struct:
   412  		switch v.Type() {
   413  		case typeOfTime:
   414  			x, ok := pValue.(time.Time)
   415  			if !ok && pValue != nil {
   416  				return typeMismatchReason(p, v)
   417  			}
   418  			v.Set(reflect.ValueOf(x))
   419  		case typeOfGeoPoint:
   420  			x, ok := pValue.(GeoPoint)
   421  			if !ok && pValue != nil {
   422  				return typeMismatchReason(p, v)
   423  			}
   424  			v.Set(reflect.ValueOf(x))
   425  		default:
   426  			ent, ok := pValue.(*Entity)
   427  			if !ok {
   428  				return typeMismatchReason(p, v)
   429  			}
   430  			err := loadEntity(ctx, v.Addr().Interface(), ent)
   431  			if err != nil {
   432  				return err.Error()
   433  			}
   434  		}
   435  	case reflect.Slice:
   436  		x, ok := pValue.([]byte)
   437  		if !ok && pValue != nil {
   438  			return typeMismatchReason(p, v)
   439  		}
   440  		if v.Type().Elem().Kind() != reflect.Uint8 {
   441  			return typeMismatchReason(p, v)
   442  		}
   443  		v.SetBytes(x)
   444  	default:
   445  		return typeMismatchReason(p, v)
   446  	}
   447  	return ""
   448  }
   449  
   450  // initField is similar to reflect's Value.FieldByIndex, in that it
   451  // returns the nested struct field corresponding to index, but it
   452  // initialises any nil pointers encountered when traversing the structure.
   453  func initField(val reflect.Value, index []int) reflect.Value {
   454  	for _, i := range index[:len(index)-1] {
   455  		val = val.Field(i)
   456  		if val.Kind() == reflect.Ptr {
   457  			if val.IsNil() {
   458  				val.Set(reflect.New(val.Type().Elem()))
   459  			}
   460  			val = val.Elem()
   461  		}
   462  	}
   463  	return val.Field(index[len(index)-1])
   464  }
   465  
   466  func loadEntity(ctx context.Context, dst interface{}, ent *Entity) error {
   467  	if pls, ok := dst.(PropertyLoadSaver); ok {
   468  		err := pls.Load(ctx, ent.Properties)
   469  		if err != nil {
   470  			return err
   471  		}
   472  		if e, ok := dst.(KeyLoader); ok {
   473  			err = e.LoadKey(ctx, ent.Key)
   474  		}
   475  		return err
   476  	}
   477  	return loadEntityToStruct(ctx, dst, ent)
   478  }
   479  
   480  func loadEntityToStruct(ctx context.Context, dst interface{}, ent *Entity) error {
   481  	pls, err := newStructPLS(dst)
   482  	if err != nil {
   483  		return err
   484  	}
   485  
   486  	// Try and load key.
   487  	keyField := pls.codec.Match(keyFieldName)
   488  	if keyField != nil && ent.Key != nil {
   489  		pls.v.FieldByIndex(keyField.Index).Set(reflect.ValueOf(ent.Key))
   490  	}
   491  
   492  	// Load properties.
   493  	return pls.Load(ctx, ent.Properties)
   494  }
   495  
   496  func (s structPLS) Load(ctx context.Context, props []Property) error {
   497  	var fieldName, errReason string
   498  	var l propertyLoader
   499  
   500  	prev := make(map[string]struct{})
   501  	for _, p := range props {
   502  		if errStr := l.load(ctx, s.codec, s.v, p, prev); errStr != "" {
   503  			// We don't return early, as we try to load as many properties as possible.
   504  			// It is valid to load an entity into a struct that cannot fully represent it.
   505  			// That case returns an error, but the caller is free to ignore it.
   506  			fieldName, errReason = p.Name, errStr
   507  		}
   508  	}
   509  	if errReason != "" {
   510  		if !SuppressErrFieldMismatch {
   511  			return &ErrFieldMismatch{
   512  				StructType: s.v.Type(),
   513  				FieldName:  fieldName,
   514  				Reason:     errReason,
   515  			}
   516  		}
   517  	}
   518  	return nil
   519  }