github.com/hamba/avro/v2@v2.22.1-0.20240518180522-aff3955acf7d/registry/decoder.go (about)

     1  package registry
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  
     9  	"github.com/hamba/avro/v2"
    10  )
    11  
    12  // DecoderFunc is a function used to customize the Decoder.
    13  type DecoderFunc func(*Decoder)
    14  
    15  // WithAPI sets the avro configuration on the decoder.
    16  func WithAPI(api avro.API) DecoderFunc {
    17  	return func(d *Decoder) {
    18  		d.api = api
    19  	}
    20  }
    21  
    22  // Decoder decodes confluent wire formatted avro payloads.
    23  type Decoder struct {
    24  	client *Client
    25  	api    avro.API
    26  }
    27  
    28  // NewDecoder returns a decoder that will get schemas from client.
    29  func NewDecoder(client *Client, opts ...DecoderFunc) *Decoder {
    30  	d := &Decoder{
    31  		client: client,
    32  		api:    avro.DefaultConfig,
    33  	}
    34  	for _, opt := range opts {
    35  		opt(d)
    36  	}
    37  	return d
    38  }
    39  
    40  // Decode decodes data into v.
    41  // The data must be formatted using the Confluent wire format, otherwise
    42  // and error will be returned.
    43  // See:
    44  // https://docs.confluent.io/3.2.0/schema-registry/docs/serializer-formatter.html#wire-format.
    45  func (d *Decoder) Decode(ctx context.Context, data []byte, v any) error {
    46  	if len(data) < 6 {
    47  		return errors.New("data too short")
    48  	}
    49  
    50  	id, err := extractSchemaID(data)
    51  	if err != nil {
    52  		return fmt.Errorf("extracting schema id: %w", err)
    53  	}
    54  
    55  	schema, err := d.client.GetSchema(ctx, id)
    56  	if err != nil {
    57  		return fmt.Errorf("getting schema: %w", err)
    58  	}
    59  
    60  	return d.api.Unmarshal(schema, data[5:], v)
    61  }
    62  
    63  func extractSchemaID(data []byte) (int, error) {
    64  	if len(data) < 5 {
    65  		return 0, errors.New("data too short")
    66  	}
    67  	if data[0] != 0 {
    68  		return 0, fmt.Errorf("invalid magic byte: %x", data[0])
    69  	}
    70  	return int(binary.BigEndian.Uint32(data[1:5])), nil
    71  }