github.com/aacfactory/avro@v1.2.12/internal/base/protocol.go (about)

     1  package base
     2  
import (
	"crypto/md5"
	"encoding/hex"
	"errors"
	"fmt"
	"os"
	"sort"
	"strings"

	jsoniter "github.com/json-iterator/go"
	"github.com/mitchellh/mapstructure"
)
    13  
    14  var (
    15  	protocolReserved = []string{"doc", "types", "messages", "protocol", "namespace"}
    16  	messageReserved  = []string{"doc", "response", "request", "errors", "one-way"}
    17  )
    18  
    19  type protocolConfig struct {
    20  	doc   string
    21  	props map[string]any
    22  }
    23  
    24  // ProtocolOption is a function that sets a protocol option.
    25  type ProtocolOption func(*protocolConfig)
    26  
    27  // WithProtoDoc sets the doc on a protocol.
    28  func WithProtoDoc(doc string) ProtocolOption {
    29  	return func(opts *protocolConfig) {
    30  		opts.doc = doc
    31  	}
    32  }
    33  
    34  // WithProtoProps sets the properties on a protocol.
    35  func WithProtoProps(props map[string]any) ProtocolOption {
    36  	return func(opts *protocolConfig) {
    37  		opts.props = props
    38  	}
    39  }
    40  
// Protocol is an Avro protocol.
//
// Instances are built by NewProtocol, which fills every field and computes
// hash once at construction time.
type Protocol struct {
	// name carries the protocol name and namespace.
	name
	// properties holds the custom (non-reserved) protocol properties.
	properties

	// types holds the named schemas declared by the protocol, in the order
	// they were supplied to NewProtocol.
	types    []NamedSchema
	// messages maps each message name to its definition.
	messages map[string]*Message

	doc string

	// hash is the hex-encoded MD5 digest of String(), set by NewProtocol.
	hash string
}
    53  
    54  // NewProtocol creates a protocol instance.
    55  func NewProtocol(
    56  	name, namepsace string,
    57  	types []NamedSchema,
    58  	messages map[string]*Message,
    59  	opts ...ProtocolOption,
    60  ) (*Protocol, error) {
    61  	var cfg protocolConfig
    62  	for _, opt := range opts {
    63  		opt(&cfg)
    64  	}
    65  
    66  	n, err := newName(name, namepsace, nil)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  
    71  	p := &Protocol{
    72  		name:       n,
    73  		properties: newProperties(cfg.props, protocolReserved),
    74  		types:      types,
    75  		messages:   messages,
    76  		doc:        cfg.doc,
    77  	}
    78  
    79  	b := md5.Sum([]byte(p.String()))
    80  	p.hash = hex.EncodeToString(b[:])
    81  
    82  	return p, nil
    83  }
    84  
    85  // Message returns a message with the given name or nil.
    86  func (p *Protocol) Message(name string) *Message {
    87  	return p.messages[name]
    88  }
    89  
    90  // Doc returns the protocol doc.
    91  func (p *Protocol) Doc() string {
    92  	return p.doc
    93  }
    94  
    95  // Hash returns the MD5 hash of the protocol.
    96  func (p *Protocol) Hash() string {
    97  	return p.hash
    98  }
    99  
   100  // Types returns the types of the protocol.
   101  func (p *Protocol) Types() []NamedSchema {
   102  	return p.types
   103  }
   104  
   105  // String returns the canonical form of the protocol.
   106  func (p *Protocol) String() string {
   107  	types := ""
   108  	for _, f := range p.types {
   109  		types += f.String() + ","
   110  	}
   111  	if len(types) > 0 {
   112  		types = types[:len(types)-1]
   113  	}
   114  
   115  	messages := ""
   116  	for k, m := range p.messages {
   117  		messages += `"` + k + `":` + m.String() + ","
   118  	}
   119  	if len(messages) > 0 {
   120  		messages = messages[:len(messages)-1]
   121  	}
   122  
   123  	return `{"protocol":"` + p.Name() +
   124  		`","namespace":"` + p.Namespace() +
   125  		`","types":[` + types + `],"messages":{` + messages + `}}`
   126  }
   127  
// Message is an Avro protocol message.
//
// Instances are built by NewMessage; parseMessage constructs them from the
// protocol JSON.
type Message struct {
	// properties holds the custom (non-reserved) message properties.
	properties

	// req holds the message parameters as the fields of a record schema.
	req    *RecordSchema
	// resp is the response schema. It is nil when there is no response;
	// parseMessage stores a "null" response as nil.
	resp   Schema
	// errs is the error union. When built by parseMessage, its first member
	// is an implicit string schema, which String leaves out.
	errs   *UnionSchema
	oneWay bool

	doc string
}
   139  
   140  // NewMessage creates a protocol message instance.
   141  func NewMessage(req *RecordSchema, resp Schema, errors *UnionSchema, oneWay bool, opts ...ProtocolOption) *Message {
   142  	var cfg protocolConfig
   143  	for _, opt := range opts {
   144  		opt(&cfg)
   145  	}
   146  
   147  	return &Message{
   148  		properties: newProperties(cfg.props, messageReserved),
   149  		req:        req,
   150  		resp:       resp,
   151  		errs:       errors,
   152  		oneWay:     oneWay,
   153  		doc:        cfg.doc,
   154  	}
   155  }
   156  
   157  // Request returns the message request schema.
   158  func (m *Message) Request() *RecordSchema {
   159  	return m.req
   160  }
   161  
   162  // Response returns the message response schema.
   163  func (m *Message) Response() Schema {
   164  	return m.resp
   165  }
   166  
   167  // Errors returns the message errors union schema.
   168  func (m *Message) Errors() *UnionSchema {
   169  	return m.errs
   170  }
   171  
   172  // OneWay determines of the message is a one way message.
   173  func (m *Message) OneWay() bool {
   174  	return m.oneWay
   175  }
   176  
   177  // Doc returns the message doc.
   178  func (m *Message) Doc() string {
   179  	return m.doc
   180  }
   181  
   182  // String returns the canonical form of the message.
   183  func (m *Message) String() string {
   184  	fields := ""
   185  	for _, f := range m.req.fields {
   186  		fields += f.String() + ","
   187  	}
   188  	if len(fields) > 0 {
   189  		fields = fields[:len(fields)-1]
   190  	}
   191  
   192  	str := `{"request":[` + fields + `]`
   193  	if m.resp != nil {
   194  		str += `,"response":` + m.resp.String()
   195  	}
   196  	if m.errs != nil && len(m.errs.Types()) > 1 {
   197  		errs, _ := NewUnionSchema(m.errs.Types()[1:])
   198  		str += `,"errors":` + errs.String()
   199  	}
   200  	str += "}"
   201  	return str
   202  }
   203  
   204  // ParseProtocolFile parses an Avro protocol from a file.
   205  func ParseProtocolFile(path string) (*Protocol, error) {
   206  	s, err := os.ReadFile(path)
   207  	if err != nil {
   208  		return nil, err
   209  	}
   210  
   211  	return ParseProtocol(string(s))
   212  }
   213  
   214  // MustParseProtocol parses an Avro protocol, panicing if there is an error.
   215  func MustParseProtocol(protocol string) *Protocol {
   216  	parsed, err := ParseProtocol(protocol)
   217  	if err != nil {
   218  		panic(err)
   219  	}
   220  
   221  	return parsed
   222  }
   223  
   224  // ParseProtocol parses an Avro protocol.
   225  func ParseProtocol(protocol string) (*Protocol, error) {
   226  	cache := &SchemaCache{}
   227  
   228  	var m map[string]any
   229  	if err := jsoniter.Unmarshal([]byte(protocol), &m); err != nil {
   230  		return nil, err
   231  	}
   232  
   233  	seen := seenCache{}
   234  	return parseProtocol(m, seen, cache)
   235  }
   236  
// protocol is the decoding target for the top-level protocol JSON object; it
// is filled by decodeMap in parseProtocol. The tags are mapstructure keys.
// Props uses ",remain", so it collects every key that no other field claims
// (the custom protocol properties).
type protocol struct {
	Protocol  string                    `mapstructure:"protocol"`
	Namespace string                    `mapstructure:"namespace"`
	Doc       string                    `mapstructure:"doc"`
	Types     []any                     `mapstructure:"types"`
	Messages  map[string]map[string]any `mapstructure:"messages"`
	Props     map[string]any            `mapstructure:",remain"`
}
   245  
   246  func parseProtocol(m map[string]any, seen seenCache, cache *SchemaCache) (*Protocol, error) {
   247  	var (
   248  		p    protocol
   249  		meta mapstructure.Metadata
   250  	)
   251  	if err := decodeMap(m, &p, &meta); err != nil {
   252  		return nil, fmt.Errorf("avro: error decoding protocol: %w", err)
   253  	}
   254  
   255  	if err := checkParsedName(p.Protocol, p.Namespace, hasKey(meta.Keys, "namespace")); err != nil {
   256  		return nil, err
   257  	}
   258  
   259  	var (
   260  		types []NamedSchema
   261  		err   error
   262  	)
   263  	if len(p.Types) > 0 {
   264  		types, err = parseProtocolTypes(p.Namespace, p.Types, seen, cache)
   265  		if err != nil {
   266  			return nil, err
   267  		}
   268  	}
   269  
   270  	messages := map[string]*Message{}
   271  	if len(p.Messages) > 0 {
   272  		for k, msg := range p.Messages {
   273  			message, err := parseMessage(p.Namespace, msg, seen, cache)
   274  			if err != nil {
   275  				return nil, err
   276  			}
   277  
   278  			messages[k] = message
   279  		}
   280  	}
   281  
   282  	return NewProtocol(p.Protocol, p.Namespace, types, messages, WithProtoDoc(p.Doc), WithProtoProps(p.Props))
   283  }
   284  
   285  func parseProtocolTypes(namespace string, types []any, seen seenCache, cache *SchemaCache) ([]NamedSchema, error) {
   286  	ts := make([]NamedSchema, len(types))
   287  	for i, typ := range types {
   288  		schema, err := parseType(namespace, typ, seen, cache)
   289  		if err != nil {
   290  			return nil, err
   291  		}
   292  
   293  		namedSchema, ok := schema.(NamedSchema)
   294  		if !ok {
   295  			return nil, errors.New("avro: protocol types must be named schemas")
   296  		}
   297  
   298  		ts[i] = namedSchema
   299  	}
   300  
   301  	return ts, nil
   302  }
   303  
// message is the decoding target for one entry of the protocol "messages"
// object; it is filled by decodeMap in parseMessage. The tags are mapstructure
// keys. Props uses ",remain", so it collects every key that no other field
// claims (the custom message properties).
type message struct {
	Doc      string           `mapstructure:"doc"`
	Request  []map[string]any `mapstructure:"request"`
	Response any              `mapstructure:"response"`
	Errors   []any            `mapstructure:"errors"`
	OneWay   bool             `mapstructure:"one-way"`
	Props    map[string]any   `mapstructure:",remain"`
}
   312  
   313  func parseMessage(namespace string, m map[string]any, seen seenCache, cache *SchemaCache) (*Message, error) {
   314  	var (
   315  		msg  message
   316  		meta mapstructure.Metadata
   317  	)
   318  	if err := decodeMap(m, &msg, &meta); err != nil {
   319  		return nil, fmt.Errorf("avro: error decoding message: %w", err)
   320  	}
   321  
   322  	fields := make([]*Field, len(msg.Request))
   323  	for i, f := range msg.Request {
   324  		field, err := parseField(namespace, f, seen, cache)
   325  		if err != nil {
   326  			return nil, err
   327  		}
   328  		fields[i] = field
   329  	}
   330  	request := &RecordSchema{
   331  		name:       name{},
   332  		properties: properties{},
   333  		fields:     fields,
   334  	}
   335  
   336  	var response Schema
   337  	if msg.Response != nil {
   338  		schema, err := parseType(namespace, msg.Response, seen, cache)
   339  		if err != nil {
   340  			return nil, err
   341  		}
   342  
   343  		if schema.Type() != Null {
   344  			response = schema
   345  		}
   346  	}
   347  
   348  	types := []Schema{NewPrimitiveSchema(String, nil)}
   349  	if len(msg.Errors) > 0 {
   350  		for _, e := range msg.Errors {
   351  			schema, err := parseType(namespace, e, seen, cache)
   352  			if err != nil {
   353  				return nil, err
   354  			}
   355  
   356  			if rec, ok := schema.(*RecordSchema); ok && !rec.IsError() {
   357  				return nil, errors.New("avro: errors record schema must be of type error")
   358  			}
   359  
   360  			types = append(types, schema)
   361  		}
   362  	}
   363  	errs, err := NewUnionSchema(types)
   364  	if err != nil {
   365  		return nil, err
   366  	}
   367  
   368  	oneWay := msg.OneWay
   369  	if hasKey(meta.Keys, "one-way") && oneWay && (len(errs.Types()) > 1 || response != nil) {
   370  		return nil, errors.New("avro: one-way messages cannot not have a response or errors")
   371  	}
   372  	if !oneWay && len(errs.Types()) <= 1 && response == nil {
   373  		oneWay = true
   374  	}
   375  
   376  	return NewMessage(request, response, errs, oneWay, WithProtoDoc(msg.Doc), WithProtoProps(msg.Props)), nil
   377  }