github.com/philpearl/plenc@v0.0.15/plenccodec/map.go (about)

     1  package plenccodec
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"sync"
     7  	"unsafe"
     8  
     9  	"github.com/philpearl/plenc/plenccore"
    10  )
    11  
    12  // MapCodec is a codec for maps. We treat it as a slice of structs with the key
    13  // and value as the fields in the structs.
    14  type MapCodec struct {
    15  	keyCodec   Codec
    16  	valueCodec Codec
    17  	rtype      reflect.Type
    18  	keyTag     []byte
    19  	valueTag   []byte
    20  	kPool      sync.Pool
    21  	kZero      unsafe.Pointer
    22  	vZero      unsafe.Pointer
    23  }
    24  
    25  func BuildMapCodec(p CodecBuilder, registry CodecRegistry, typ reflect.Type, tag string) (Codec, error) {
    26  	if typ.Kind() != reflect.Map {
    27  		return nil, fmt.Errorf("type must be a map to build a map codec")
    28  	}
    29  
    30  	keyCodec, err := p.CodecForTypeRegistry(registry, typ.Key(), "")
    31  	if err != nil {
    32  		return nil, fmt.Errorf("failed to find codec for map key %s. %w", typ.Key().Name(), err)
    33  	}
    34  	valueCodec, err := p.CodecForTypeRegistry(registry, typ.Elem(), "")
    35  	if err != nil {
    36  		return nil, fmt.Errorf("failed to find codec for map value %s. %w", typ.Elem().Name(), err)
    37  	}
    38  
    39  	c := MapCodec{
    40  		keyCodec:   keyCodec,
    41  		valueCodec: valueCodec,
    42  		rtype:      typ,
    43  		keyTag:     plenccore.AppendTag(nil, keyCodec.WireType(), 1),
    44  		valueTag:   plenccore.AppendTag(nil, valueCodec.WireType(), 2),
    45  	}
    46  
    47  	c.kPool.New = c.newKey
    48  	if l := int(typ.Key().Size()); l <= len(zero) {
    49  		c.kZero = unsafe.Pointer(&zero)
    50  	} else {
    51  		z := make([]byte, l)
    52  		c.kZero = unsafe.Pointer(&z[0])
    53  	}
    54  
    55  	if l := int(typ.Elem().Size()); l <= len(zero) {
    56  		c.vZero = unsafe.Pointer(&zero)
    57  	} else {
    58  		z := make([]byte, l)
    59  		c.vZero = unsafe.Pointer(&z[0])
    60  	}
    61  
    62  	if tag == "proto" {
    63  		return ProtoMapCodec{&c}, nil
    64  	}
    65  
    66  	return &c, nil
    67  }
    68  
    69  func (c *MapCodec) newKey() interface{} {
    70  	return c.keyCodec.New()
    71  }
    72  
// When writing (Omit, Size, Append), ptr is the map pointer itself. When
// reading (Read), ptr is a pointer to the map pointer.
    75  
    76  func (c *MapCodec) Omit(ptr unsafe.Pointer) bool {
    77  	return ptr == nil
    78  }
    79  
    80  func (c *MapCodec) size(ptr unsafe.Pointer) (size int) {
    81  	size = plenccore.SizeVarUint(uint64(maplen(ptr)))
    82  
    83  	var iterM mapiter
    84  	iter := (unsafe.Pointer)(&iterM)
    85  	mapiterinit(unpackEFace(c.rtype).data, ptr, iter)
    86  	for {
    87  		k := mapiterkey(iter)
    88  		if k == nil {
    89  			break
    90  		}
    91  		v := mapiterelem(iter)
    92  
    93  		s := c.sizeForEntry(k, v)
    94  		size += plenccore.SizeVarUint(uint64(s)) + s
    95  
    96  		mapiternext(iter)
    97  	}
    98  	return size
    99  }
   100  
   101  func (c *MapCodec) sizeForEntry(k, v unsafe.Pointer) int {
   102  	s := c.sizeFor(c.keyCodec, k, c.keyTag)
   103  	return s + c.sizeFor(c.valueCodec, v, c.valueTag)
   104  }
   105  
   106  func (*MapCodec) sizeFor(underlying Codec, ptr unsafe.Pointer, tag []byte) int {
   107  	if underlying.Omit(ptr) {
   108  		return 0
   109  	}
   110  	return underlying.Size(ptr, tag)
   111  }
   112  
   113  func (c *MapCodec) append(data []byte, ptr unsafe.Pointer) []byte {
   114  	add := func(underlying Codec, ptr unsafe.Pointer, tag []byte) {
   115  		if !underlying.Omit(ptr) {
   116  			data = underlying.Append(data, ptr, tag)
   117  		}
   118  	}
   119  
   120  	// First add the count of entries
   121  	data = plenccore.AppendVarUint(data, uint64(maplen(ptr)))
   122  
   123  	var iterM mapiter
   124  	iter := (unsafe.Pointer)(&iterM)
   125  	mapiterinit(unpackEFace(c.rtype).data, ptr, iter)
   126  	for {
   127  		k := mapiterkey(iter)
   128  		if k == nil {
   129  			break
   130  		}
   131  		v := mapiterelem(iter)
   132  
   133  		// Add the length of each entry, then the key and value
   134  		data = plenccore.AppendVarUint(data, uint64(c.sizeForEntry(k, v)))
   135  		add(c.keyCodec, k, c.keyTag)
   136  		add(c.valueCodec, v, c.valueTag)
   137  
   138  		mapiternext(iter)
   139  	}
   140  
   141  	return data
   142  }
   143  
   144  var zero [1024]byte
   145  
   146  func (c *MapCodec) Read(data []byte, ptr unsafe.Pointer, wt plenccore.WireType) (n int, err error) {
   147  	if len(data) == 0 {
   148  		return 0, nil
   149  	}
   150  
   151  	// We start with a count of entries
   152  	count, n := plenccore.ReadVarUint(data)
   153  	if n <= 0 {
   154  		return 0, fmt.Errorf("failed to read map size")
   155  	}
   156  
   157  	// ptr is a pointer to a map pointer
   158  	if *(*unsafe.Pointer)(ptr) == nil {
   159  		*(*unsafe.Pointer)(ptr) = unsafe.Pointer(reflect.MakeMapWithSize(c.rtype, int(count)).Pointer())
   160  	}
   161  	mp := *(*unsafe.Pointer)(ptr)
   162  
   163  	// We need some space to hold keys and values as we read them out. We can
   164  	// re-use the space on each iteration as the data is copied into the map
   165  	// We also save some memory & time if we cache them in some pools
   166  	k := c.kPool.Get().(unsafe.Pointer)
   167  	defer c.kPool.Put(k)
   168  	offset := int(n)
   169  	for count > 0 {
   170  		// Each entry starts with a length
   171  		entryLength, n := plenccore.ReadVarUint(data[offset:])
   172  		if n <= 0 {
   173  			return 0, fmt.Errorf("failed to read map entry length")
   174  		}
   175  		offset += n
   176  		n, err := c.readMapEntry(mp, k, data[offset:offset+int(entryLength)])
   177  		if err != nil {
   178  			return 0, err
   179  		}
   180  		offset += n
   181  		count--
   182  	}
   183  
   184  	return offset, nil
   185  }
   186  
   187  // readMapEntry reads out a single map entry. mp is the map pointer. k is an
   188  // area to read key values into. data is the raw data for this map entry
   189  func (c *MapCodec) readMapEntry(mp, k unsafe.Pointer, data []byte) (int, error) {
   190  	offset, fieldEnd, index, wt, err := c.readTagAndLength(data, 0)
   191  	if err != nil {
   192  		return 0, err
   193  	}
   194  
   195  	if index == 1 {
   196  		// Key is present - read it
   197  		n, err := c.keyCodec.Read(data[offset:fieldEnd], k, wt)
   198  		if err != nil {
   199  			return 0, fmt.Errorf("failed reading key field of %s. %w", c.rtype.Name(), err)
   200  		}
   201  		offset += n
   202  	} else {
   203  		k = c.kZero
   204  	}
   205  
   206  	// Assign/find a place in the map for this key. Val is a pointer to where
   207  	// the value should be. We're going to unmarshal into this directly
   208  	val := mapassign(unpackEFace(c.rtype).data, mp, k)
   209  
   210  	if offset < len(data) {
   211  		if index == 1 {
   212  			offset, fieldEnd, _, wt, err = c.readTagAndLength(data, offset)
   213  			if err != nil {
   214  				return 0, err
   215  			}
   216  		}
   217  
   218  		n, err := c.valueCodec.Read(data[offset:fieldEnd], val, wt)
   219  		if err != nil {
   220  			return 0, fmt.Errorf("failed reading value field of %s. %w", c.rtype.Name(), err)
   221  		}
   222  		offset += n
   223  	} else {
   224  		// No value - use the nil value.
   225  		typedmemmove(unpackEFace(c.rtype.Elem()).data, val, c.vZero)
   226  	}
   227  
   228  	return offset, nil
   229  }
   230  
   231  func (c *MapCodec) readTagAndLength(data []byte, offset int) (offset2, fieldEnd, index int, wt plenccore.WireType, err error) {
   232  	wt, index, n := plenccore.ReadTag(data[offset:])
   233  	offset += n
   234  	fieldEnd = len(data)
   235  	if wt == plenccore.WTLength {
   236  		// For WTLength types we read out the length and ensure the data we
   237  		// read the field from is the right length
   238  		fieldLen, n := plenccore.ReadVarUint(data[offset:])
   239  		if n <= 0 {
   240  			return 0, 0, 0, wt, fmt.Errorf("varuint overflow reading %d of %s", index, c.rtype.Name())
   241  		}
   242  		offset += n
   243  		fieldEnd = int(fieldLen) + offset
   244  		if fieldEnd > len(data) {
   245  			return 0, 0, 0, wt, fmt.Errorf("length %d of field %d of %s exceeds data length %d", fieldLen, index, c.rtype.Name(), len(data)-offset)
   246  		}
   247  	}
   248  
   249  	return offset, fieldEnd, index, wt, nil
   250  }
   251  
   252  func (c *MapCodec) New() unsafe.Pointer {
   253  	return unsafe.Pointer(reflect.MakeMap(c.rtype).Pointer())
   254  }
   255  
   256  func (c *MapCodec) WireType() plenccore.WireType {
   257  	return plenccore.WTSlice
   258  }
   259  
   260  func (c *MapCodec) Descriptor() Descriptor {
   261  	// We treat this as a slice of structs? Perhaps need to define a map descriptor!
   262  	kDesc := c.keyCodec.Descriptor()
   263  	vDesc := c.valueCodec.Descriptor()
   264  
   265  	kDesc.Index = 1
   266  	kDesc.Name = "key"
   267  	vDesc.Index = 2
   268  	vDesc.Name = "value"
   269  
   270  	kTypeName, vTypeName := kDesc.TypeName, vDesc.TypeName
   271  	if kTypeName == "" {
   272  		kTypeName = kDesc.Type.String()
   273  	}
   274  	if vTypeName == "" {
   275  		vTypeName = vDesc.Type.String()
   276  	}
   277  
   278  	return Descriptor{
   279  		Type:        FieldTypeSlice,
   280  		LogicalType: LogicalTypeMap,
   281  		Elements: []Descriptor{
   282  			{
   283  				Type:        FieldTypeStruct,
   284  				LogicalType: LogicalTypeMapEntry,
   285  				TypeName:    fmt.Sprintf("map_%s_%s", kTypeName, vTypeName),
   286  				Elements: []Descriptor{
   287  					kDesc,
   288  					vDesc,
   289  				},
   290  			},
   291  		},
   292  	}
   293  }
   294  
   295  func (c *MapCodec) Size(ptr unsafe.Pointer, tag []byte) int {
   296  	return c.size(ptr) + len(tag)
   297  }
   298  
   299  func (c *MapCodec) Append(data []byte, ptr unsafe.Pointer, tag []byte) []byte {
   300  	data = append(data, tag...)
   301  	return c.append(data, ptr)
   302  }
   303  
   304  type ProtoMapCodec struct {
   305  	*MapCodec
   306  }
   307  
   308  func (c ProtoMapCodec) Size(ptr unsafe.Pointer, tag []byte) (size int) {
   309  	// Treat as an array of structs. Each entry carries its own tag
   310  	var iterM mapiter
   311  	iter := (unsafe.Pointer)(&iterM)
   312  	mapiterinit(unpackEFace(c.rtype).data, ptr, iter)
   313  	for {
   314  		k := mapiterkey(iter)
   315  		if k == nil {
   316  			break
   317  		}
   318  		v := mapiterelem(iter)
   319  
   320  		s := c.sizeForEntry(k, v)
   321  		size += len(tag) + plenccore.SizeVarUint(uint64(s)) + s
   322  
   323  		mapiternext(iter)
   324  	}
   325  	return size
   326  }
   327  
   328  func (c ProtoMapCodec) Append(data []byte, ptr unsafe.Pointer, tag []byte) []byte {
   329  	// Each entry is appended separately as if a struct of key & value
   330  
   331  	add := func(underlying Codec, ptr unsafe.Pointer, tag []byte) {
   332  		if !underlying.Omit(ptr) {
   333  			data = underlying.Append(data, ptr, tag)
   334  		}
   335  	}
   336  
   337  	var iterM mapiter
   338  	iter := (unsafe.Pointer)(&iterM)
   339  	mapiterinit(unpackEFace(c.rtype).data, ptr, iter)
   340  	for {
   341  		k := mapiterkey(iter)
   342  		if k == nil {
   343  			break
   344  		}
   345  		v := mapiterelem(iter)
   346  
   347  		data = append(data, tag...)
   348  		data = plenccore.AppendVarUint(data, uint64(c.sizeForEntry(k, v)))
   349  		add(c.keyCodec, k, c.keyTag)
   350  		add(c.valueCodec, v, c.valueTag)
   351  
   352  		mapiternext(iter)
   353  	}
   354  
   355  	return data
   356  }
   357  
   358  func (c ProtoMapCodec) Read(data []byte, ptr unsafe.Pointer, wt plenccore.WireType) (n int, err error) {
   359  	if len(data) == 0 {
   360  		return 0, nil
   361  	}
   362  
   363  	// ptr is a pointer to a map pointer
   364  	if *(*unsafe.Pointer)(ptr) == nil {
   365  		*(*unsafe.Pointer)(ptr) = unsafe.Pointer(reflect.MakeMap(c.rtype).Pointer())
   366  	}
   367  	mp := *(*unsafe.Pointer)(ptr)
   368  
   369  	// We need some space to hold keys and values as we read them out. We can
   370  	// re-use the space on each iteration as the data is copied into the map
   371  	// We also save some memory & time if we cache them in some pools
   372  	k := c.kPool.Get().(unsafe.Pointer)
   373  	defer c.kPool.Put(k)
   374  	return c.readMapEntry(mp, k, data)
   375  }
   376  
   377  func (c ProtoMapCodec) WireType() plenccore.WireType {
   378  	return plenccore.WTLength
   379  }