github.com/kamalshkeir/kencoding@v0.0.2-0.20230409043843-44b609a0475a/proto/custom.go (about)

     1  package proto
     2  
     3  import (
     4  	"io"
     5  	"reflect"
     6  	"unsafe"
     7  )
     8  
     9  func customCodecOf(t reflect.Type) *codec {
    10  	return &codec{
    11  		wire:   varlen,
    12  		size:   customSizeFuncOf(t),
    13  		encode: customEncodeFuncOf(t),
    14  		decode: customDecodeFuncOf(t),
    15  	}
    16  }
    17  
    18  func customSizeFuncOf(t reflect.Type) sizeFunc {
    19  	return func(p unsafe.Pointer, flags flags) int {
    20  		if p != nil {
    21  			if m := reflect.NewAt(t, p).Interface().(customMessage); m != nil {
    22  				size := m.Size()
    23  				if flags.has(toplevel) {
    24  					return size
    25  				}
    26  				return sizeOfVarlen(size)
    27  			}
    28  		}
    29  		return 0
    30  	}
    31  }
    32  
    33  func customEncodeFuncOf(t reflect.Type) encodeFunc {
    34  	return func(b []byte, p unsafe.Pointer, flags flags) (int, error) {
    35  		if p != nil {
    36  			if m := reflect.NewAt(t, p).Interface().(customMessage); m != nil {
    37  				size := m.Size()
    38  
    39  				if flags.has(toplevel) {
    40  					if len(b) < size {
    41  						return 0, io.ErrShortBuffer
    42  					}
    43  					return m.MarshalTo(b)
    44  				}
    45  
    46  				vlen := sizeOfVarlen(size)
    47  				if len(b) < vlen {
    48  					return 0, io.ErrShortBuffer
    49  				}
    50  
    51  				n1, err := encodeVarint(b, uint64(size))
    52  				if err != nil {
    53  					return n1, err
    54  				}
    55  
    56  				n2, err := m.MarshalTo(b[n1:])
    57  				return n1 + n2, err
    58  			}
    59  		}
    60  		return 0, nil
    61  	}
    62  }
    63  
    64  func customDecodeFuncOf(t reflect.Type) decodeFunc {
    65  	return func(b []byte, p unsafe.Pointer, flags flags) (int, error) {
    66  		m := reflect.NewAt(t, p).Interface().(customMessage)
    67  
    68  		if flags.has(toplevel) {
    69  			return len(b), m.Unmarshal(b)
    70  		}
    71  
    72  		v, n, err := decodeVarlen(b)
    73  		if err != nil {
    74  			return n, err
    75  		}
    76  
    77  		return n + len(v), m.Unmarshal(v)
    78  	}
    79  }