github.com/segmentio/encoding@v0.4.0/proto/map.go (about)

     1  package proto
     2  
     3  import (
     4  	"io"
     5  	"reflect"
     6  	"sync"
     7  	"unsafe"
     8  
     9  	. "github.com/segmentio/encoding/internal/runtime_reflect"
    10  )
    11  
    12  const (
    13  	zeroSize = 1 // sizeOfVarint(0)
    14  )
    15  
    16  type mapField struct {
    17  	number   uint16
    18  	keyFlags uint8
    19  	valFlags uint8
    20  	keyCodec *codec
    21  	valCodec *codec
    22  }
    23  
    24  func mapCodecOf(t reflect.Type, f *mapField, seen map[reflect.Type]*codec) *codec {
    25  	m := new(codec)
    26  	seen[t] = m
    27  
    28  	m.wire = varlen
    29  	m.size = mapSizeFuncOf(t, f)
    30  	m.encode = mapEncodeFuncOf(t, f)
    31  	m.decode = mapDecodeFuncOf(t, f, seen)
    32  	return m
    33  }
    34  
    35  func mapSizeFuncOf(t reflect.Type, f *mapField) sizeFunc {
    36  	mapTagSize := sizeOfTag(fieldNumber(f.number), varlen)
    37  	keyTagSize := sizeOfTag(1, wireType(f.keyCodec.wire))
    38  	valTagSize := sizeOfTag(2, wireType(f.valCodec.wire))
    39  	return func(p unsafe.Pointer, flags flags) int {
    40  		if p == nil {
    41  			return 0
    42  		}
    43  
    44  		if !flags.has(inline) {
    45  			p = *(*unsafe.Pointer)(p)
    46  		}
    47  
    48  		n := 0
    49  		m := MapIter{}
    50  		defer m.Done()
    51  
    52  		for m.Init(pointer(t), p); m.HasNext(); m.Next() {
    53  			keySize := f.keyCodec.size(m.Key(), wantzero)
    54  			valSize := f.valCodec.size(m.Value(), wantzero)
    55  
    56  			if keySize > 0 {
    57  				n += keyTagSize + keySize
    58  				if (f.keyFlags & embedded) != 0 {
    59  					n += sizeOfVarint(uint64(keySize))
    60  				}
    61  			}
    62  
    63  			if valSize > 0 {
    64  				n += valTagSize + valSize
    65  				if (f.valFlags & embedded) != 0 {
    66  					n += sizeOfVarint(uint64(valSize))
    67  				}
    68  			}
    69  
    70  			n += mapTagSize + sizeOfVarint(uint64(keySize+valSize))
    71  		}
    72  
    73  		if n == 0 {
    74  			n = mapTagSize + zeroSize
    75  		}
    76  
    77  		return n
    78  	}
    79  }
    80  
    81  func mapEncodeFuncOf(t reflect.Type, f *mapField) encodeFunc {
    82  	keyTag := [1]byte{}
    83  	valTag := [1]byte{}
    84  	encodeTag(keyTag[:], 1, f.keyCodec.wire)
    85  	encodeTag(valTag[:], 2, f.valCodec.wire)
    86  
    87  	number := fieldNumber(f.number)
    88  	mapTag := make([]byte, sizeOfTag(number, varlen)+zeroSize)
    89  	encodeTag(mapTag, number, varlen)
    90  
    91  	zero := mapTag
    92  	mapTag = mapTag[:len(mapTag)-1]
    93  
    94  	return func(b []byte, p unsafe.Pointer, flags flags) (int, error) {
    95  		if p == nil {
    96  			return 0, nil
    97  		}
    98  
    99  		if !flags.has(inline) {
   100  			p = *(*unsafe.Pointer)(p)
   101  		}
   102  
   103  		offset := 0
   104  		m := MapIter{}
   105  		defer m.Done()
   106  
   107  		for m.Init(pointer(t), p); m.HasNext(); m.Next() {
   108  			key := m.Key()
   109  			val := m.Value()
   110  
   111  			keySize := f.keyCodec.size(key, wantzero)
   112  			valSize := f.valCodec.size(val, wantzero)
   113  			elemSize := keySize + valSize
   114  
   115  			if keySize > 0 {
   116  				elemSize += len(keyTag)
   117  				if (f.keyFlags & embedded) != 0 {
   118  					elemSize += sizeOfVarint(uint64(keySize))
   119  				}
   120  			}
   121  
   122  			if valSize > 0 {
   123  				elemSize += len(valTag)
   124  				if (f.valFlags & embedded) != 0 {
   125  					elemSize += sizeOfVarint(uint64(valSize))
   126  				}
   127  			}
   128  
   129  			n := copy(b[offset:], mapTag)
   130  			offset += n
   131  			if n < len(mapTag) {
   132  				return offset, io.ErrShortBuffer
   133  			}
   134  			n, err := encodeVarint(b[offset:], uint64(elemSize))
   135  			offset += n
   136  			if err != nil {
   137  				return offset, err
   138  			}
   139  
   140  			if keySize > 0 {
   141  				n := copy(b[offset:], keyTag[:])
   142  				offset += n
   143  				if n < len(keyTag) {
   144  					return offset, io.ErrShortBuffer
   145  				}
   146  
   147  				if (f.keyFlags & embedded) != 0 {
   148  					n, err := encodeVarint(b[offset:], uint64(keySize))
   149  					offset += n
   150  					if err != nil {
   151  						return offset, err
   152  					}
   153  				}
   154  
   155  				if (len(b) - offset) < keySize {
   156  					return len(b), io.ErrShortBuffer
   157  				}
   158  
   159  				n, err := f.keyCodec.encode(b[offset:offset+keySize], key, wantzero)
   160  				offset += n
   161  				if err != nil {
   162  					return offset, err
   163  				}
   164  			}
   165  
   166  			if valSize > 0 {
   167  				n := copy(b[offset:], valTag[:])
   168  				offset += n
   169  				if n < len(valTag) {
   170  					return n, io.ErrShortBuffer
   171  				}
   172  
   173  				if (f.valFlags & embedded) != 0 {
   174  					n, err := encodeVarint(b[offset:], uint64(valSize))
   175  					offset += n
   176  					if err != nil {
   177  						return offset, err
   178  					}
   179  				}
   180  
   181  				if (len(b) - offset) < valSize {
   182  					return len(b), io.ErrShortBuffer
   183  				}
   184  
   185  				n, err := f.valCodec.encode(b[offset:offset+valSize], val, wantzero)
   186  				offset += n
   187  				if err != nil {
   188  					return offset, err
   189  				}
   190  			}
   191  		}
   192  
   193  		if offset == 0 {
   194  			if offset = copy(b, zero); offset < len(zero) {
   195  				return offset, io.ErrShortBuffer
   196  			}
   197  		}
   198  
   199  		return offset, nil
   200  	}
   201  }
   202  
   203  func mapDecodeFuncOf(t reflect.Type, f *mapField, seen map[reflect.Type]*codec) decodeFunc {
   204  	structType := reflect.StructOf([]reflect.StructField{
   205  		{Name: "Key", Type: t.Key()},
   206  		{Name: "Elem", Type: t.Elem()},
   207  	})
   208  
   209  	structCodec := codecOf(structType, seen)
   210  	structPool := new(sync.Pool)
   211  	structZero := pointer(reflect.Zero(structType).Interface())
   212  
   213  	valueType := t.Elem()
   214  	valueOffset := structType.Field(1).Offset
   215  
   216  	mtype := pointer(t)
   217  	stype := pointer(structType)
   218  	vtype := pointer(valueType)
   219  
   220  	return func(b []byte, p unsafe.Pointer, _ flags) (int, error) {
   221  		m := (*unsafe.Pointer)(p)
   222  		if *m == nil {
   223  			*m = MakeMap(mtype, 10)
   224  		}
   225  		if len(b) == 0 {
   226  			return 0, nil
   227  		}
   228  
   229  		s := pointer(structPool.Get())
   230  		if s == nil {
   231  			s = unsafe.Pointer(reflect.New(structType).Pointer())
   232  		}
   233  
   234  		n, err := structCodec.decode(b, s, noflags)
   235  		if err == nil {
   236  			v := MapAssign(mtype, *m, s)
   237  			Assign(vtype, v, unsafe.Pointer(uintptr(s)+valueOffset))
   238  		}
   239  
   240  		Assign(stype, s, structZero)
   241  		structPool.Put(s)
   242  		return n, err
   243  	}
   244  }