github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/go.mongodb.org/mongo-driver/bson/bsoncodec/map_codec.go (about)

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package bsoncodec
     8  
     9  import (
    10  	"encoding"
    11  	"fmt"
    12  	"reflect"
    13  	"strconv"
    14  
    15  	"go.mongodb.org/mongo-driver/bson/bsonoptions"
    16  	"go.mongodb.org/mongo-driver/bson/bsonrw"
    17  	"go.mongodb.org/mongo-driver/bson/bsontype"
    18  )
    19  
// defaultMapCodec is a MapCodec created by NewMapCodec with no options.
// NOTE(review): its users are outside this view — presumably it backs the
// default registry's map handling; confirm before relying on that.
var defaultMapCodec = NewMapCodec()
    21  
// MapCodec is the Codec used for map values.
//
// Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the
// MapCodec registered.
type MapCodec struct {
	// DecodeZerosMap causes DecodeValue to delete any existing values from Go maps in the destination
	// value passed to Decode before unmarshaling BSON documents into them. The same behavior is
	// also enabled by the DecodeContext's zeroMaps setting.
	//
	// Deprecated: Use bson.Decoder.ZeroMaps instead.
	DecodeZerosMap bool

	// EncodeNilAsEmpty causes EncodeValue to marshal nil Go maps as empty BSON documents instead of
	// BSON null. The same behavior is also enabled by the EncodeContext's nilMapAsEmpty setting.
	//
	// Deprecated: Use bson.Encoder.NilMapAsEmpty instead.
	EncodeNilAsEmpty bool

	// EncodeKeysWithStringer causes the Encoder to convert Go map keys to BSON document field name
	// strings using fmt.Sprint instead of the default string conversion logic. Key encoding also
	// honors the EncodeContext's stringifyMapKeysWithFmt setting.
	//
	// Key decoding consults only this field: when it is true, KeyUnmarshaler is not used for map
	// keys and float32/float64 keys are parsed from their text form; when it is false, float keys
	// are rejected as unsupported.
	//
	// Deprecated: Use bson.Encoder.StringifyMapKeysWithFmt instead.
	EncodeKeysWithStringer bool
}
    45  
// KeyMarshaler is the interface implemented by an object that can marshal itself into a string key.
// This applies to types used as map keys and is similar to encoding.TextMarshaler.
//
// MapCodec consults it only when fmt-based key stringification is off and the key's kind is not
// string (string-kinded keys are always used verbatim). It takes precedence over
// encoding.TextMarshaler. A nil pointer key that implements it is encoded as the empty string
// without calling MarshalKey.
type KeyMarshaler interface {
	MarshalKey() (key string, err error)
}
    51  
// KeyUnmarshaler is the interface implemented by an object that can unmarshal a string representation
// of itself. This applies to types used as map keys and is similar to encoding.TextUnmarshaler.
//
// UnmarshalKey must be able to decode the form generated by MarshalKey.
// Because key is an immutable Go string, retaining it after returning is safe.
//
// MapCodec uses it when a pointer to the map's key type implements it and the codec's
// EncodeKeysWithStringer option is false; it takes precedence over encoding.TextUnmarshaler.
// UnmarshalKey is called on a freshly allocated zero value of the key type.
type KeyUnmarshaler interface {
	UnmarshalKey(key string) error
}
    61  
    62  // NewMapCodec returns a MapCodec with options opts.
    63  //
    64  // Deprecated: Use [go.mongodb.org/mongo-driver/bson.NewRegistry] to get a registry with the
    65  // MapCodec registered.
    66  func NewMapCodec(opts ...*bsonoptions.MapCodecOptions) *MapCodec {
    67  	mapOpt := bsonoptions.MergeMapCodecOptions(opts...)
    68  
    69  	codec := MapCodec{}
    70  	if mapOpt.DecodeZerosMap != nil {
    71  		codec.DecodeZerosMap = *mapOpt.DecodeZerosMap
    72  	}
    73  	if mapOpt.EncodeNilAsEmpty != nil {
    74  		codec.EncodeNilAsEmpty = *mapOpt.EncodeNilAsEmpty
    75  	}
    76  	if mapOpt.EncodeKeysWithStringer != nil {
    77  		codec.EncodeKeysWithStringer = *mapOpt.EncodeKeysWithStringer
    78  	}
    79  	return &codec
    80  }
    81  
    82  // EncodeValue is the ValueEncoder for map[*]* types.
    83  func (mc *MapCodec) EncodeValue(ec EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
    84  	if !val.IsValid() || val.Kind() != reflect.Map {
    85  		return ValueEncoderError{Name: "MapEncodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
    86  	}
    87  
    88  	if val.IsNil() && !mc.EncodeNilAsEmpty && !ec.nilMapAsEmpty {
    89  		// If we have a nil map but we can't WriteNull, that means we're probably trying to encode
    90  		// to a TopLevel document. We can't currently tell if this is what actually happened, but if
    91  		// there's a deeper underlying problem, the error will also be returned from WriteDocument,
    92  		// so just continue. The operations on a map reflection value are valid, so we can call
    93  		// MapKeys within mapEncodeValue without a problem.
    94  		err := vw.WriteNull()
    95  		if err == nil {
    96  			return nil
    97  		}
    98  	}
    99  
   100  	dw, err := vw.WriteDocument()
   101  	if err != nil {
   102  		return err
   103  	}
   104  
   105  	return mc.mapEncodeValue(ec, dw, val, nil)
   106  }
   107  
   108  // mapEncodeValue handles encoding of the values of a map. The collisionFn returns
   109  // true if the provided key exists, this is mainly used for inline maps in the
   110  // struct codec.
   111  func (mc *MapCodec) mapEncodeValue(ec EncodeContext, dw bsonrw.DocumentWriter, val reflect.Value, collisionFn func(string) bool) error {
   112  
   113  	elemType := val.Type().Elem()
   114  	encoder, err := ec.LookupEncoder(elemType)
   115  	if err != nil && elemType.Kind() != reflect.Interface {
   116  		return err
   117  	}
   118  
   119  	keys := val.MapKeys()
   120  	for _, key := range keys {
   121  		keyStr, err := mc.encodeKey(key, ec.stringifyMapKeysWithFmt)
   122  		if err != nil {
   123  			return err
   124  		}
   125  
   126  		if collisionFn != nil && collisionFn(keyStr) {
   127  			return fmt.Errorf("Key %s of inlined map conflicts with a struct field name", key)
   128  		}
   129  
   130  		currEncoder, currVal, lookupErr := defaultValueEncoders.lookupElementEncoder(ec, encoder, val.MapIndex(key))
   131  		if lookupErr != nil && lookupErr != errInvalidValue {
   132  			return lookupErr
   133  		}
   134  
   135  		vw, err := dw.WriteDocumentElement(keyStr)
   136  		if err != nil {
   137  			return err
   138  		}
   139  
   140  		if lookupErr == errInvalidValue {
   141  			err = vw.WriteNull()
   142  			if err != nil {
   143  				return err
   144  			}
   145  			continue
   146  		}
   147  
   148  		err = currEncoder.EncodeValue(ec, vw, currVal)
   149  		if err != nil {
   150  			return err
   151  		}
   152  	}
   153  
   154  	return dw.WriteDocumentEnd()
   155  }
   156  
   157  // DecodeValue is the ValueDecoder for map[string/decimal]* types.
   158  func (mc *MapCodec) DecodeValue(dc DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
   159  	if val.Kind() != reflect.Map || (!val.CanSet() && val.IsNil()) {
   160  		return ValueDecoderError{Name: "MapDecodeValue", Kinds: []reflect.Kind{reflect.Map}, Received: val}
   161  	}
   162  
   163  	switch vrType := vr.Type(); vrType {
   164  	case bsontype.Type(0), bsontype.EmbeddedDocument:
   165  	case bsontype.Null:
   166  		val.Set(reflect.Zero(val.Type()))
   167  		return vr.ReadNull()
   168  	case bsontype.Undefined:
   169  		val.Set(reflect.Zero(val.Type()))
   170  		return vr.ReadUndefined()
   171  	default:
   172  		return fmt.Errorf("cannot decode %v into a %s", vrType, val.Type())
   173  	}
   174  
   175  	dr, err := vr.ReadDocument()
   176  	if err != nil {
   177  		return err
   178  	}
   179  
   180  	if val.IsNil() {
   181  		val.Set(reflect.MakeMap(val.Type()))
   182  	}
   183  
   184  	if val.Len() > 0 && (mc.DecodeZerosMap || dc.zeroMaps) {
   185  		clearMap(val)
   186  	}
   187  
   188  	eType := val.Type().Elem()
   189  	decoder, err := dc.LookupDecoder(eType)
   190  	if err != nil {
   191  		return err
   192  	}
   193  	eTypeDecoder, _ := decoder.(typeDecoder)
   194  
   195  	if eType == tEmpty {
   196  		dc.Ancestor = val.Type()
   197  	}
   198  
   199  	keyType := val.Type().Key()
   200  
   201  	for {
   202  		key, vr, err := dr.ReadElement()
   203  		if err == bsonrw.ErrEOD {
   204  			break
   205  		}
   206  		if err != nil {
   207  			return err
   208  		}
   209  
   210  		k, err := mc.decodeKey(key, keyType)
   211  		if err != nil {
   212  			return err
   213  		}
   214  
   215  		elem, err := decodeTypeOrValueWithInfo(decoder, eTypeDecoder, dc, vr, eType, true)
   216  		if err != nil {
   217  			return newDecodeError(key, err)
   218  		}
   219  
   220  		val.SetMapIndex(k, elem)
   221  	}
   222  	return nil
   223  }
   224  
   225  func clearMap(m reflect.Value) {
   226  	var none reflect.Value
   227  	for _, k := range m.MapKeys() {
   228  		m.SetMapIndex(k, none)
   229  	}
   230  }
   231  
   232  func (mc *MapCodec) encodeKey(val reflect.Value, encodeKeysWithStringer bool) (string, error) {
   233  	if mc.EncodeKeysWithStringer || encodeKeysWithStringer {
   234  		return fmt.Sprint(val), nil
   235  	}
   236  
   237  	// keys of any string type are used directly
   238  	if val.Kind() == reflect.String {
   239  		return val.String(), nil
   240  	}
   241  	// KeyMarshalers are marshaled
   242  	if km, ok := val.Interface().(KeyMarshaler); ok {
   243  		if val.Kind() == reflect.Ptr && val.IsNil() {
   244  			return "", nil
   245  		}
   246  		buf, err := km.MarshalKey()
   247  		if err == nil {
   248  			return buf, nil
   249  		}
   250  		return "", err
   251  	}
   252  	// keys implement encoding.TextMarshaler are marshaled.
   253  	if km, ok := val.Interface().(encoding.TextMarshaler); ok {
   254  		if val.Kind() == reflect.Ptr && val.IsNil() {
   255  			return "", nil
   256  		}
   257  
   258  		buf, err := km.MarshalText()
   259  		if err != nil {
   260  			return "", err
   261  		}
   262  
   263  		return string(buf), nil
   264  	}
   265  
   266  	switch val.Kind() {
   267  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   268  		return strconv.FormatInt(val.Int(), 10), nil
   269  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   270  		return strconv.FormatUint(val.Uint(), 10), nil
   271  	}
   272  	return "", fmt.Errorf("unsupported key type: %v", val.Type())
   273  }
   274  
   275  var keyUnmarshalerType = reflect.TypeOf((*KeyUnmarshaler)(nil)).Elem()
   276  var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
   277  
   278  func (mc *MapCodec) decodeKey(key string, keyType reflect.Type) (reflect.Value, error) {
   279  	keyVal := reflect.ValueOf(key)
   280  	var err error
   281  	switch {
   282  	// First, if EncodeKeysWithStringer is not enabled, try to decode withKeyUnmarshaler
   283  	case !mc.EncodeKeysWithStringer && reflect.PtrTo(keyType).Implements(keyUnmarshalerType):
   284  		keyVal = reflect.New(keyType)
   285  		v := keyVal.Interface().(KeyUnmarshaler)
   286  		err = v.UnmarshalKey(key)
   287  		keyVal = keyVal.Elem()
   288  	// Try to decode encoding.TextUnmarshalers.
   289  	case reflect.PtrTo(keyType).Implements(textUnmarshalerType):
   290  		keyVal = reflect.New(keyType)
   291  		v := keyVal.Interface().(encoding.TextUnmarshaler)
   292  		err = v.UnmarshalText([]byte(key))
   293  		keyVal = keyVal.Elem()
   294  	// Otherwise, go to type specific behavior
   295  	default:
   296  		switch keyType.Kind() {
   297  		case reflect.String:
   298  			keyVal = reflect.ValueOf(key).Convert(keyType)
   299  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   300  			n, parseErr := strconv.ParseInt(key, 10, 64)
   301  			if parseErr != nil || reflect.Zero(keyType).OverflowInt(n) {
   302  				err = fmt.Errorf("failed to unmarshal number key %v", key)
   303  			}
   304  			keyVal = reflect.ValueOf(n).Convert(keyType)
   305  		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
   306  			n, parseErr := strconv.ParseUint(key, 10, 64)
   307  			if parseErr != nil || reflect.Zero(keyType).OverflowUint(n) {
   308  				err = fmt.Errorf("failed to unmarshal number key %v", key)
   309  				break
   310  			}
   311  			keyVal = reflect.ValueOf(n).Convert(keyType)
   312  		case reflect.Float32, reflect.Float64:
   313  			if mc.EncodeKeysWithStringer {
   314  				parsed, err := strconv.ParseFloat(key, 64)
   315  				if err != nil {
   316  					return keyVal, fmt.Errorf("Map key is defined to be a decimal type (%v) but got error %v", keyType.Kind(), err)
   317  				}
   318  				keyVal = reflect.ValueOf(parsed)
   319  				break
   320  			}
   321  			fallthrough
   322  		default:
   323  			return keyVal, fmt.Errorf("unsupported key type: %v", keyType)
   324  		}
   325  	}
   326  	return keyVal, err
   327  }