github.com/MetalBlockchain/metalgo@v1.11.9/codec/hierarchycodec/codec.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package hierarchycodec
     5  
     6  import (
     7  	"fmt"
     8  	"reflect"
     9  	"sync"
    10  
    11  	"github.com/MetalBlockchain/metalgo/codec"
    12  	"github.com/MetalBlockchain/metalgo/codec/reflectcodec"
    13  	"github.com/MetalBlockchain/metalgo/utils/bimap"
    14  	"github.com/MetalBlockchain/metalgo/utils/wrappers"
    15  )
    16  
    17  var (
    18  	_ Codec              = (*hierarchyCodec)(nil)
    19  	_ codec.Codec        = (*hierarchyCodec)(nil)
    20  	_ codec.Registry     = (*hierarchyCodec)(nil)
    21  	_ codec.GeneralCodec = (*hierarchyCodec)(nil)
    22  )
    23  
    24  // Codec marshals and unmarshals
    25  type Codec interface {
    26  	codec.Registry
    27  	codec.Codec
    28  	SkipRegistrations(int)
    29  	NextGroup()
    30  }
    31  
    32  type typeID struct {
    33  	groupID uint16
    34  	typeID  uint16
    35  }
    36  
    37  // Codec handles marshaling and unmarshaling of structs
    38  type hierarchyCodec struct {
    39  	codec.Codec
    40  
    41  	lock            sync.RWMutex
    42  	currentGroupID  uint16
    43  	nextTypeID      uint16
    44  	registeredTypes *bimap.BiMap[typeID, reflect.Type]
    45  }
    46  
    47  // New returns a new, concurrency-safe codec
    48  func New(tagNames []string) Codec {
    49  	hCodec := &hierarchyCodec{
    50  		currentGroupID:  0,
    51  		nextTypeID:      0,
    52  		registeredTypes: bimap.New[typeID, reflect.Type](),
    53  	}
    54  	hCodec.Codec = reflectcodec.New(hCodec, tagNames)
    55  	return hCodec
    56  }
    57  
    58  // NewDefault returns a new codec with reasonable default values
    59  func NewDefault() Codec {
    60  	return New([]string{reflectcodec.DefaultTagName})
    61  }
    62  
    63  // SkipRegistrations some number of type IDs
    64  func (c *hierarchyCodec) SkipRegistrations(num int) {
    65  	c.lock.Lock()
    66  	c.nextTypeID += uint16(num)
    67  	c.lock.Unlock()
    68  }
    69  
    70  // NextGroup moves to the next group registry
    71  func (c *hierarchyCodec) NextGroup() {
    72  	c.lock.Lock()
    73  	c.currentGroupID++
    74  	c.nextTypeID = 0
    75  	c.lock.Unlock()
    76  }
    77  
    78  // RegisterType is used to register types that may be unmarshaled into an interface
    79  // [val] is a value of the type being registered
    80  func (c *hierarchyCodec) RegisterType(val interface{}) error {
    81  	c.lock.Lock()
    82  	defer c.lock.Unlock()
    83  
    84  	valType := reflect.TypeOf(val)
    85  	if c.registeredTypes.HasValue(valType) {
    86  		return fmt.Errorf("%w: %v", codec.ErrDuplicateType, valType)
    87  	}
    88  
    89  	valTypeID := typeID{
    90  		groupID: c.currentGroupID,
    91  		typeID:  c.nextTypeID,
    92  	}
    93  	c.nextTypeID++
    94  
    95  	c.registeredTypes.Put(valTypeID, valType)
    96  	return nil
    97  }
    98  
    99  func (*hierarchyCodec) PrefixSize(reflect.Type) int {
   100  	// see PackPrefix implementation
   101  	return wrappers.ShortLen + wrappers.ShortLen
   102  }
   103  
   104  func (c *hierarchyCodec) PackPrefix(p *wrappers.Packer, valueType reflect.Type) error {
   105  	c.lock.RLock()
   106  	defer c.lock.RUnlock()
   107  
   108  	typeID, ok := c.registeredTypes.GetKey(valueType) // Get the type ID of the value being marshaled
   109  	if !ok {
   110  		return fmt.Errorf("can't marshal unregistered type %q", valueType)
   111  	}
   112  	// Pack type ID so we know what to unmarshal this into
   113  	p.PackShort(typeID.groupID)
   114  	p.PackShort(typeID.typeID)
   115  	return p.Err
   116  }
   117  
   118  func (c *hierarchyCodec) UnpackPrefix(p *wrappers.Packer, valueType reflect.Type) (reflect.Value, error) {
   119  	c.lock.RLock()
   120  	defer c.lock.RUnlock()
   121  
   122  	groupID := p.UnpackShort()     // Get the group ID
   123  	typeIDShort := p.UnpackShort() // Get the type ID
   124  	if p.Err != nil {
   125  		return reflect.Value{}, fmt.Errorf("couldn't unmarshal interface: %w", p.Err)
   126  	}
   127  	t := typeID{
   128  		groupID: groupID,
   129  		typeID:  typeIDShort,
   130  	}
   131  	// Get a type that implements the interface
   132  	implementingType, ok := c.registeredTypes.GetValue(t)
   133  	if !ok {
   134  		return reflect.Value{}, fmt.Errorf("couldn't unmarshal interface: unknown type ID %+v", t)
   135  	}
   136  	// Ensure type actually does implement the interface
   137  	if !implementingType.Implements(valueType) {
   138  		return reflect.Value{}, fmt.Errorf("couldn't unmarshal interface: %s %w %s",
   139  			implementingType,
   140  			codec.ErrDoesNotImplementInterface,
   141  			valueType,
   142  		)
   143  	}
   144  	return reflect.New(implementingType).Elem(), nil // instance of the proper type
   145  }