github.com/hamba/avro@v1.8.0/protocol.go (about)

     1  package avro
     2  
import (
	"crypto/md5"
	"encoding/hex"
	"errors"
	"io/ioutil"
	"sort"

	jsoniter "github.com/json-iterator/go"
)
    11  
    12  var (
    13  	protocolReserved = []string{"doc", "types", "messages", "protocol", "namespace"}
    14  	messageReserved  = []string{"doc", "response", "request", "errors", "one-way"}
    15  )
    16  
    17  // Protocol is an Avro protocol.
    18  type Protocol struct {
    19  	name
    20  	properties
    21  
    22  	types    []NamedSchema
    23  	messages map[string]*Message
    24  
    25  	hash string
    26  }
    27  
    28  // NewProtocol creates a protocol instance.
    29  func NewProtocol(name, space string, types []NamedSchema, messages map[string]*Message) (*Protocol, error) {
    30  	n, err := newName(name, space)
    31  	if err != nil {
    32  		return nil, err
    33  	}
    34  
    35  	p := &Protocol{
    36  		name:       n,
    37  		properties: properties{reserved: protocolReserved},
    38  		types:      types,
    39  		messages:   messages,
    40  	}
    41  
    42  	b := md5.Sum([]byte(p.String()))
    43  	p.hash = hex.EncodeToString(b[:])
    44  
    45  	return p, nil
    46  }
    47  
    48  // Message returns a message with the given name or nil.
    49  func (p *Protocol) Message(name string) *Message {
    50  	return p.messages[name]
    51  }
    52  
    53  // Hash returns the MD5 hash of the protocol.
    54  func (p *Protocol) Hash() string {
    55  	return p.hash
    56  }
    57  
    58  // String returns the canonical form of the protocol.
    59  func (p *Protocol) String() string {
    60  	types := ""
    61  	for _, f := range p.types {
    62  		types += f.String() + ","
    63  	}
    64  	if len(types) > 0 {
    65  		types = types[:len(types)-1]
    66  	}
    67  
    68  	messages := ""
    69  	for k, m := range p.messages {
    70  		messages += `"` + k + `":` + m.String() + ","
    71  	}
    72  	if len(messages) > 0 {
    73  		messages = messages[:len(messages)-1]
    74  	}
    75  
    76  	return `{"protocol":"` + p.Name() +
    77  		`","namespace":"` + p.Namespace() +
    78  		`","types":[` + types + `],"messages":{` + messages + `}}`
    79  }
    80  
    81  // Message is an Avro protocol message.
    82  type Message struct {
    83  	properties
    84  
    85  	req    *RecordSchema
    86  	resp   Schema
    87  	errs   *UnionSchema
    88  	oneWay bool
    89  }
    90  
    91  // NewMessage creates a protocol message instance.
    92  func NewMessage(req *RecordSchema, resp Schema, errors *UnionSchema, oneWay bool) *Message {
    93  	return &Message{
    94  		properties: properties{reserved: messageReserved},
    95  		req:        req,
    96  		resp:       resp,
    97  		errs:       errors,
    98  		oneWay:     oneWay,
    99  	}
   100  }
   101  
   102  // Request returns the message request schema.
   103  func (m *Message) Request() *RecordSchema {
   104  	return m.req
   105  }
   106  
   107  // Response returns the message response schema.
   108  func (m *Message) Response() Schema {
   109  	return m.resp
   110  }
   111  
   112  // Errors returns the message errors union schema.
   113  func (m *Message) Errors() *UnionSchema {
   114  	return m.errs
   115  }
   116  
   117  // OneWay determines of the message is a one way message.
   118  func (m *Message) OneWay() bool {
   119  	return m.oneWay
   120  }
   121  
   122  // String returns the canonical form of the message.
   123  func (m *Message) String() string {
   124  	fields := ""
   125  	for _, f := range m.req.fields {
   126  		fields += f.String() + ","
   127  	}
   128  	if len(fields) > 0 {
   129  		fields = fields[:len(fields)-1]
   130  	}
   131  
   132  	str := `{"request":[` + fields + `]`
   133  
   134  	if m.resp != nil {
   135  		str += `,"response":` + m.resp.String()
   136  	}
   137  
   138  	if m.errs != nil && len(m.errs.Types()) > 1 {
   139  		errs, _ := NewUnionSchema(m.errs.Types()[1:])
   140  		str += `,"errors":` + errs.String()
   141  	}
   142  
   143  	str += "}"
   144  	return str
   145  }
   146  
   147  // ParseProtocolFile parses an Avro protocol from a file.
   148  func ParseProtocolFile(path string) (*Protocol, error) {
   149  	s, err := ioutil.ReadFile(path)
   150  	if err != nil {
   151  		return nil, err
   152  	}
   153  
   154  	return ParseProtocol(string(s))
   155  }
   156  
   157  // MustParseProtocol parses an Avro protocol, panicing if there is an error.
   158  func MustParseProtocol(protocol string) *Protocol {
   159  	parsed, err := ParseProtocol(protocol)
   160  	if err != nil {
   161  		panic(err)
   162  	}
   163  
   164  	return parsed
   165  }
   166  
   167  // ParseProtocol parses an Avro protocol.
   168  func ParseProtocol(protocol string) (*Protocol, error) {
   169  	cache := &SchemaCache{}
   170  
   171  	var m map[string]interface{}
   172  	if err := jsoniter.Unmarshal([]byte(protocol), &m); err != nil {
   173  		return nil, err
   174  	}
   175  
   176  	name, err := resolveProtocolName(m)
   177  	if err != nil {
   178  		return nil, err
   179  	}
   180  
   181  	var types []NamedSchema
   182  	if ts, ok := m["types"].([]interface{}); ok {
   183  		types, err = parseProtocolTypes(name.space, ts, cache)
   184  		if err != nil {
   185  			return nil, err
   186  		}
   187  	}
   188  
   189  	messages := map[string]*Message{}
   190  	if msgs, ok := m["messages"].(map[string]interface{}); ok {
   191  		for k, msg := range msgs {
   192  			m, ok := msg.(map[string]interface{})
   193  			if !ok {
   194  				return nil, errors.New("avro: message must be an object")
   195  			}
   196  
   197  			message, err := parseMessage(name.space, m, cache)
   198  			if err != nil {
   199  				return nil, err
   200  			}
   201  
   202  			messages[k] = message
   203  		}
   204  	}
   205  
   206  	proto, _ := NewProtocol(name.name, name.space, types, messages)
   207  
   208  	for k, v := range m {
   209  		proto.AddProp(k, v)
   210  	}
   211  
   212  	return proto, nil
   213  }
   214  
   215  func parseProtocolTypes(namespace string, types []interface{}, cache *SchemaCache) ([]NamedSchema, error) {
   216  	ts := make([]NamedSchema, len(types))
   217  	for i, typ := range types {
   218  		schema, err := parseType(namespace, typ, cache)
   219  		if err != nil {
   220  			return nil, err
   221  		}
   222  
   223  		namedSchema, ok := schema.(NamedSchema)
   224  		if !ok {
   225  			return nil, errors.New("avro: protocol types must be named schemas")
   226  		}
   227  
   228  		ts[i] = namedSchema
   229  	}
   230  
   231  	return ts, nil
   232  }
   233  
   234  func parseMessage(namespace string, m map[string]interface{}, cache *SchemaCache) (*Message, error) {
   235  	req, ok := m["request"].([]interface{})
   236  	if !ok {
   237  		return nil, errors.New("avro: request must have an array of fields")
   238  	}
   239  
   240  	fields := make([]*Field, len(req))
   241  	for i, f := range req {
   242  		field, err := parseField(namespace, f, cache)
   243  		if err != nil {
   244  			return nil, err
   245  		}
   246  
   247  		fields[i] = field
   248  	}
   249  	request := &RecordSchema{
   250  		name:       name{},
   251  		properties: properties{reserved: schemaReserved},
   252  		fields:     fields,
   253  	}
   254  
   255  	var response Schema
   256  	if res, ok := m["response"]; ok {
   257  		schema, err := parseType(namespace, res, cache)
   258  		if err != nil {
   259  			return nil, err
   260  		}
   261  
   262  		if schema.Type() != Null {
   263  			response = schema
   264  		}
   265  	}
   266  
   267  	types := []Schema{NewPrimitiveSchema(String, nil)}
   268  	if errs, ok := m["errors"].([]interface{}); ok {
   269  		for _, e := range errs {
   270  			schema, err := parseType(namespace, e, cache)
   271  			if err != nil {
   272  				return nil, err
   273  			}
   274  
   275  			if rec, ok := schema.(*RecordSchema); ok && !rec.IsError() {
   276  				return nil, errors.New("avro: errors record schema must be of type error")
   277  			}
   278  
   279  			types = append(types, schema)
   280  		}
   281  	}
   282  	errs, err := NewUnionSchema(types)
   283  	if err != nil {
   284  		return nil, err
   285  	}
   286  
   287  	oneWay := false
   288  	if o, ok := m["one-way"].(bool); ok {
   289  		oneWay = o
   290  		if oneWay && (len(errs.Types()) > 1 || response != nil) {
   291  			return nil, errors.New("avro: one-way messages cannot not have a response or errors")
   292  		}
   293  	}
   294  
   295  	if !oneWay && len(errs.Types()) <= 1 && response == nil {
   296  		oneWay = true
   297  	}
   298  
   299  	msg := NewMessage(request, response, errs, oneWay)
   300  
   301  	for k, v := range m {
   302  		msg.AddProp(k, v)
   303  	}
   304  
   305  	return msg, nil
   306  }
   307  
   308  func resolveProtocolName(m map[string]interface{}) (name, error) {
   309  	proto, ok := m["protocol"].(string)
   310  	if !ok {
   311  		return name{}, errors.New("avro: protocol key required")
   312  	}
   313  
   314  	space := ""
   315  	if namespace, ok := m["namespace"].(string); ok {
   316  		if namespace == "" {
   317  			return name{}, errors.New("avro: namespace key must be non-empty or omitted")
   318  		}
   319  
   320  		space = namespace
   321  	}
   322  
   323  	return newName(proto, space)
   324  }