github.com/philpearl/plenc@v0.0.15/plenccodec/struct.go (about)

     1  package plenccodec
     2  
     3  import (
     4  	"fmt"
     5  	"reflect"
     6  	"strconv"
     7  	"strings"
     8  	"unicode"
     9  	"unicode/utf8"
    10  	"unsafe"
    11  
    12  	"github.com/philpearl/plenc/plenccore"
    13  )
    14  
    15  type wrappedCodecRegistry struct {
    16  	CodecRegistry
    17  	typ   reflect.Type
    18  	tag   string
    19  	codec Codec
    20  }
    21  
    22  func (w wrappedCodecRegistry) Load(typ reflect.Type, tag string) Codec {
    23  	if typ == w.typ && tag == w.tag {
    24  		return w.codec
    25  	}
    26  	return w.CodecRegistry.Load(typ, tag)
    27  }
    28  
    29  func BuildStructCodec(p CodecBuilder, registry CodecRegistry, typ reflect.Type, tag string) (Codec, error) {
    30  	if typ.Kind() != reflect.Struct {
    31  		return nil, fmt.Errorf("type must be a struct to build a struct codec")
    32  	}
    33  
    34  	c := StructCodec{
    35  		rtype:  typ,
    36  		fields: make([]description, typ.NumField()),
    37  	}
    38  
    39  	registry = wrappedCodecRegistry{CodecRegistry: registry, typ: typ, tag: tag, codec: &c}
    40  
    41  	var maxIndex int
    42  	var count int
    43  	for i := range c.fields {
    44  		sf := typ.Field(i)
    45  
    46  		r, _ := utf8.DecodeRuneInString(sf.Name)
    47  		if unicode.IsLower(r) {
    48  			continue
    49  		}
    50  
    51  		tag := sf.Tag.Get("plenc")
    52  		if tag == "" {
    53  			return nil, fmt.Errorf("no plenc tag on field %d %s of %s", i, sf.Name, typ.Name())
    54  		}
    55  		if tag == "-" {
    56  			continue
    57  		}
    58  		var postfix string
    59  		if comma := strings.IndexByte(tag, ','); comma != -1 {
    60  			postfix = tag[comma+1:]
    61  			tag = tag[:comma]
    62  		}
    63  
    64  		index, err := strconv.Atoi(tag)
    65  		if err != nil {
    66  			return nil, fmt.Errorf("could not parse plenc tag on field %d %s of %s. %w", i, sf.Name, typ.Name(), err)
    67  		}
    68  
    69  		field := &c.fields[count]
    70  		count++
    71  		field.offset = sf.Offset
    72  		field.index = index
    73  		if field.index > maxIndex {
    74  			maxIndex = field.index
    75  		}
    76  
    77  		field.name = sf.Name
    78  		if jsonName, _, _ := strings.Cut(sf.Tag.Get("json"), ","); jsonName != "" {
    79  			field.name = jsonName
    80  		}
    81  
    82  		var wantIntern bool
    83  		if postfix == "intern" {
    84  			postfix = ""
    85  			wantIntern = true
    86  		}
    87  
    88  		fc, err := p.CodecForTypeRegistry(registry, sf.Type, postfix)
    89  		if err != nil {
    90  			return nil, fmt.Errorf("failed to find codec for field %d (%s, %q) of %s. %w", i, sf.Name, postfix, typ.Name(), err)
    91  		}
    92  
    93  		if wantIntern {
    94  			if in, ok := fc.(Interner); ok {
    95  				// Note we get an independent interner for each field
    96  				fc = in.WithInterning()
    97  			}
    98  		}
    99  
   100  		field.codec = fc
   101  		field.tag = plenccore.AppendTag(nil, fc.WireType(), field.index)
   102  		if sf.Type.Kind() == reflect.Map {
   103  			field.deref = true
   104  		}
   105  	}
   106  	c.fields = c.fields[:count]
   107  
   108  	c.fieldsByIndex = make([]shortDesc, maxIndex+1)
   109  	for _, f := range c.fields {
   110  		if c.fieldsByIndex[f.index].codec != nil {
   111  			return nil, fmt.Errorf("failed building codec for %s. Multiple fields have index %d", typ.Name(), f.index)
   112  		}
   113  		c.fieldsByIndex[f.index] = shortDesc{
   114  			codec:  f.codec,
   115  			offset: f.offset,
   116  		}
   117  	}
   118  
   119  	return &c, nil
   120  }
   121  
   122  type description struct {
   123  	offset uintptr
   124  	codec  Codec
   125  	index  int
   126  	tag    []byte
   127  	deref  bool
   128  	name   string
   129  }
   130  
   131  type shortDesc struct {
   132  	codec  Codec
   133  	offset uintptr
   134  }
   135  
   136  type StructCodec struct {
   137  	rtype         reflect.Type
   138  	fields        []description
   139  	fieldsByIndex []shortDesc
   140  }
   141  
   142  func (c *StructCodec) Omit(ptr unsafe.Pointer) bool {
   143  	return false
   144  }
   145  
   146  func (c *StructCodec) size(ptr unsafe.Pointer) (size int) {
   147  	for _, field := range c.fields {
   148  		// For most fields we have a pointer to the type, and this is the same
   149  		// when we call these functions for types within structs or when we
   150  		// pass an interface to Marshal. But maps are kind of pointers and
   151  		// kind of not. When passed to Marshal via interfaces we get passed
   152  		// the underlying map pointer. But when the map is in a struct, we
   153  		// have a pointer to the underlying map pointer
   154  		fptr := unsafe.Pointer(uintptr(ptr) + field.offset)
   155  		if field.deref {
   156  			fptr = *(*unsafe.Pointer)(fptr)
   157  		}
   158  		if !field.codec.Omit(fptr) {
   159  			size += field.codec.Size(fptr, field.tag)
   160  		}
   161  	}
   162  	return size
   163  }
   164  
   165  func (c *StructCodec) append(data []byte, ptr unsafe.Pointer) []byte {
   166  	for _, field := range c.fields {
   167  		fptr := unsafe.Pointer(uintptr(ptr) + field.offset)
   168  		if field.deref {
   169  			fptr = *(*unsafe.Pointer)(fptr)
   170  		}
   171  		if field.codec.Omit(fptr) {
   172  			continue
   173  		}
   174  		data = field.codec.Append(data, fptr, field.tag)
   175  	}
   176  
   177  	return data
   178  }
   179  
   180  func (c *StructCodec) Read(data []byte, ptr unsafe.Pointer, wt plenccore.WireType) (n int, err error) {
   181  	l := len(data)
   182  
   183  	var offset int
   184  	for offset < l {
   185  		wt, index, n := plenccore.ReadTag(data[offset:])
   186  		offset += n
   187  
   188  		if index >= len(c.fieldsByIndex) || c.fieldsByIndex[index].codec == nil {
   189  			// Field corresponding to index does not exist
   190  			n, err := plenccore.Skip(data[offset:], wt)
   191  			if err != nil {
   192  				return 0, fmt.Errorf("failed to skip field %d in %s. %w", index, c.rtype.Name(), err)
   193  			}
   194  			offset += n
   195  			continue
   196  		}
   197  
   198  		fl := l
   199  		if wt == plenccore.WTLength {
   200  			// For WTLength types we read out the length and ensure the data we
   201  			// read the field from is the right length
   202  			v, n := plenccore.ReadVarUint(data[offset:])
   203  			if n <= 0 {
   204  				return 0, fmt.Errorf("varuint overflow reading field %d of %s", index, c.rtype.Name())
   205  			}
   206  			offset += n
   207  			fl = int(v) + offset
   208  			if fl > l {
   209  				return 0, fmt.Errorf("length %d of field %d of %s exceeds data length", fl, index, c.rtype.Name())
   210  			}
   211  		}
   212  
   213  		d := c.fieldsByIndex[index]
   214  		n, err := d.codec.Read(data[offset:fl], unsafe.Pointer(uintptr(ptr)+d.offset), wt)
   215  		if err != nil {
   216  			return 0, fmt.Errorf("failed reading field %d of %s. %w", index, c.rtype.Name(), err)
   217  		}
   218  		offset += n
   219  	}
   220  
   221  	return offset, nil
   222  }
   223  
   224  func (c *StructCodec) New() unsafe.Pointer {
   225  	return unsafe.Pointer(reflect.New(c.rtype).Pointer())
   226  }
   227  
   228  func (c *StructCodec) WireType() plenccore.WireType {
   229  	return plenccore.WTLength
   230  }
   231  
   232  func (c *StructCodec) Descriptor() Descriptor {
   233  	var d Descriptor
   234  	d.Type = FieldTypeStruct
   235  	d.TypeName = c.rtype.Name()
   236  	d.Elements = make([]Descriptor, len(c.fields))
   237  	for i, f := range c.fields {
   238  		d.Elements[i] = f.codec.Descriptor()
   239  		d.Elements[i].Index = f.index
   240  		d.Elements[i].Name = f.name
   241  	}
   242  	return d
   243  }
   244  
   245  func (c *StructCodec) Size(ptr unsafe.Pointer, tag []byte) int {
   246  	l := c.size(ptr)
   247  	if len(tag) != 0 {
   248  		l += len(tag) + plenccore.SizeVarUint(uint64(l))
   249  	}
   250  	return l
   251  }
   252  
   253  func (c *StructCodec) Append(data []byte, ptr unsafe.Pointer, tag []byte) []byte {
   254  	if len(tag) != 0 {
   255  		data = append(data, tag...)
   256  		data = plenccore.AppendVarUint(data, uint64(c.size(ptr)))
   257  	}
   258  	return c.append(data, ptr)
   259  }