github.com/kamalshkeir/kencoding@v0.0.2-0.20230409043843-44b609a0475a/proto/slice.go (about)

     1  package proto
     2  
     3  import (
     4  	"io"
     5  	"reflect"
     6  	"unsafe"
     7  
     8  	. "github.com/kamalshkeir/kencoding/internal/runtime_reflect"
     9  )
    10  
// repeatedField describes how every element of a repeated (slice) field is
// written: the codec used for a single element plus the tag information that
// precedes each element on the wire.
type repeatedField struct {
	codec       *codec      // codec applied to each individual element
	fieldNumber fieldNumber // field number encoded in every element's tag
	wireType    wireType    // wire type encoded in every element's tag
	embedded    bool        // when true, a varint holding the element's encoded size follows the tag
}
    17  
    18  func sliceCodecOf(t reflect.Type, f structField, seen map[reflect.Type]*codec) *codec {
    19  	s := new(codec)
    20  	seen[t] = s
    21  
    22  	r := &repeatedField{
    23  		codec:       f.codec,
    24  		fieldNumber: f.fieldNumber(),
    25  		wireType:    f.wireType(),
    26  		embedded:    f.embedded(),
    27  	}
    28  
    29  	s.wire = f.codec.wire
    30  	s.size = sliceSizeFuncOf(t, r)
    31  	s.encode = sliceEncodeFuncOf(t, r)
    32  	s.decode = sliceDecodeFuncOf(t, r)
    33  	return s
    34  }
    35  
    36  func sliceSizeFuncOf(t reflect.Type, r *repeatedField) sizeFunc {
    37  	elemSize := alignedSize(t.Elem())
    38  	tagSize := sizeOfTag(r.fieldNumber, r.wireType)
    39  	return func(p unsafe.Pointer, _ flags) int {
    40  		n := 0
    41  
    42  		if v := (*Slice)(p); v != nil {
    43  			for i := 0; i < v.Len(); i++ {
    44  				elem := v.Index(i, elemSize)
    45  				size := r.codec.size(elem, wantzero)
    46  				n += tagSize + size
    47  				if r.embedded {
    48  					n += sizeOfVarint(uint64(size))
    49  				}
    50  			}
    51  		}
    52  
    53  		return n
    54  	}
    55  }
    56  
    57  func sliceEncodeFuncOf(t reflect.Type, r *repeatedField) encodeFunc {
    58  	elemSize := alignedSize(t.Elem())
    59  	tagSize := sizeOfTag(r.fieldNumber, r.wireType)
    60  	tagData := make([]byte, tagSize)
    61  	encodeTag(tagData, r.fieldNumber, r.wireType)
    62  	return func(b []byte, p unsafe.Pointer, _ flags) (int, error) {
    63  		offset := 0
    64  
    65  		if s := (*Slice)(p); s != nil {
    66  			for i := 0; i < s.Len(); i++ {
    67  				elem := s.Index(i, elemSize)
    68  				size := r.codec.size(elem, wantzero)
    69  
    70  				n := copy(b[offset:], tagData)
    71  				offset += n
    72  				if n < len(tagData) {
    73  					return offset, io.ErrShortBuffer
    74  				}
    75  
    76  				if r.embedded {
    77  					n, err := encodeVarint(b[offset:], uint64(size))
    78  					offset += n
    79  					if err != nil {
    80  						return offset, err
    81  					}
    82  				}
    83  
    84  				if (len(b) - offset) < size {
    85  					return len(b), io.ErrShortBuffer
    86  				}
    87  
    88  				n, err := r.codec.encode(b[offset:offset+size], elem, wantzero)
    89  				offset += n
    90  				if err != nil {
    91  					return offset, err
    92  				}
    93  			}
    94  		}
    95  
    96  		return offset, nil
    97  	}
    98  }
    99  
   100  func sliceDecodeFuncOf(t reflect.Type, r *repeatedField) decodeFunc {
   101  	elemType := t.Elem()
   102  	elemSize := alignedSize(elemType)
   103  	return func(b []byte, p unsafe.Pointer, _ flags) (int, error) {
   104  		s := (*Slice)(p)
   105  		i := s.Len()
   106  
   107  		if i == s.Cap() {
   108  			*s = growSlice(elemType, s)
   109  		}
   110  
   111  		n, err := r.codec.decode(b, s.Index(i, elemSize), noflags)
   112  		if err == nil {
   113  			s.SetLen(i + 1)
   114  		}
   115  		return n, err
   116  	}
   117  }
   118  
   119  func alignedSize(t reflect.Type) uintptr {
   120  	a := t.Align()
   121  	s := t.Size()
   122  	return align(uintptr(a), uintptr(s))
   123  }
   124  
   125  func align(align, size uintptr) uintptr {
   126  	if align != 0 && (size%align) != 0 {
   127  		size = ((size / align) + 1) * align
   128  	}
   129  	return size
   130  }
   131  
   132  func growSlice(t reflect.Type, s *Slice) Slice {
   133  	cap := 2 * s.Cap()
   134  	if cap == 0 {
   135  		cap = 10
   136  	}
   137  	p := pointer(t)
   138  	d := MakeSlice(p, s.Len(), cap)
   139  	CopySlice(p, d, *s)
   140  	return d
   141  }