github.com/segmentio/encoding@v0.4.0/proto/message.go (about)

     1  package proto
     2  
     3  import (
     4  	"encoding/binary"
     5  	"fmt"
     6  	"io"
     7  	"math"
     8  	"reflect"
     9  	"unsafe"
    10  )
    11  
    // Message is an interface implemented by types that support being encoded to
    // and decoded from protobuf.
    type Message interface {
    	// Size is the size of the protobuf representation (in bytes).
    	Size() int

    	// Marshal writes the message to the byte slice passed as argument.
    	// Callers should pass a slice of at least Size() bytes; the message
    	// encoder in this package checks this before calling Marshal.
    	Marshal([]byte) error

    	// Unmarshal reads the message from the byte slice passed as argument.
    	Unmarshal([]byte) error
    }
    24  
    // RawMessage represents a raw protobuf-encoded message.
    //
    // Marshal and Rewrite emit its bytes verbatim, and Unmarshal stores a private
    // copy of its input.
    type RawMessage []byte

    // Size satisfies the Message interface; it is the number of bytes in m.
    func (m RawMessage) Size() int { return len(m) }

    // Marshal satisfies the Message interface. It copies m into the front of b.
    //
    // If b is shorter than m, Marshal returns io.ErrShortBuffer and leaves b
    // untouched. Writing a truncated message would silently corrupt the output.
    func (m RawMessage) Marshal(b []byte) error {
    	if len(b) < len(m) {
    		return io.ErrShortBuffer
    	}
    	copy(b, m)
    	return nil
    }

    // Unmarshal satisfies the Message interface. It replaces *m with a copy of
    // b, so the result does not share memory with the caller's buffer.
    func (m *RawMessage) Unmarshal(b []byte) error {
    	*m = make([]byte, len(b))
    	copy(*m, b)
    	return nil
    }

    // Rewrite satisfies the Rewriter interface. It appends m to out; the second
    // argument is ignored.
    func (m RawMessage) Rewrite(out, _ []byte) ([]byte, error) {
    	return append(out, m...), nil
    }
    48  
    49  // FieldNumber represents a protobuf field number.
    50  type FieldNumber uint
    51  
    52  func (f FieldNumber) Bool(v bool) RawMessage {
    53  	var x uint64
    54  	if v {
    55  		x = 1
    56  	}
    57  	return AppendVarint(nil, f, x)
    58  }
    59  
    60  func (f FieldNumber) Int(v int) RawMessage {
    61  	return f.Int64(int64(v))
    62  }
    63  
    64  func (f FieldNumber) Int32(v int32) RawMessage {
    65  	return f.Int64(int64(v))
    66  }
    67  
    68  func (f FieldNumber) Int64(v int64) RawMessage {
    69  	return AppendVarint(nil, f, uint64(v))
    70  }
    71  
    72  func (f FieldNumber) Uint(v uint) RawMessage {
    73  	return f.Uint64(uint64(v))
    74  }
    75  
    76  func (f FieldNumber) Uint32(v uint32) RawMessage {
    77  	return f.Uint64(uint64(v))
    78  }
    79  
    80  func (f FieldNumber) Uint64(v uint64) RawMessage {
    81  	return AppendVarint(nil, f, v)
    82  }
    83  
    84  func (f FieldNumber) Fixed32(v uint32) RawMessage {
    85  	return AppendFixed32(nil, f, v)
    86  }
    87  
    88  func (f FieldNumber) Fixed64(v uint64) RawMessage {
    89  	return AppendFixed64(nil, f, v)
    90  }
    91  
    92  func (f FieldNumber) Float32(v float32) RawMessage {
    93  	return f.Fixed32(math.Float32bits(v))
    94  }
    95  
    96  func (f FieldNumber) Float64(v float64) RawMessage {
    97  	return f.Fixed64(math.Float64bits(v))
    98  }
    99  
   100  func (f FieldNumber) String(v string) RawMessage {
   101  	return f.Bytes([]byte(v))
   102  }
   103  
   104  func (f FieldNumber) Bytes(v []byte) RawMessage {
   105  	return AppendVarlen(nil, f, v)
   106  }
   107  
   108  // Value constructs a RawMessage for field number f from v.
   109  func (f FieldNumber) Value(v interface{}) RawMessage {
   110  	switch x := v.(type) {
   111  	case bool:
   112  		return f.Bool(x)
   113  	case int:
   114  		return f.Int(x)
   115  	case int32:
   116  		return f.Int32(x)
   117  	case int64:
   118  		return f.Int64(x)
   119  	case uint:
   120  		return f.Uint(x)
   121  	case uint32:
   122  		return f.Uint32(x)
   123  	case uint64:
   124  		return f.Uint64(x)
   125  	case float32:
   126  		return f.Float32(x)
   127  	case float64:
   128  		return f.Float64(x)
   129  	case string:
   130  		return f.String(x)
   131  	case []byte:
   132  		return f.Bytes(x)
   133  	default:
   134  		panic("cannot rewrite value of unsupported type")
   135  	}
   136  }
   137  
   138  // The WireType enumeration represents the different protobuf wire types.
   139  type WireType uint
   140  
   141  const (
   142  	Varint  WireType = 0
   143  	Fixed64 WireType = 1
   144  	Varlen  WireType = 2
   145  	Fixed32 WireType = 5
   146  	// Wire types 3 and 4 were used for StartGroup and EndGroup, but are
   147  	// deprecated so we don't expose them here.
   148  	//
   149  	// https://developers.google.com/protocol-buffers/docs/encoding#structure
   150  )
   151  
   152  func (wt WireType) String() string {
   153  	return wireType(wt).String()
   154  }
   155  
   156  func Append(m RawMessage, f FieldNumber, t WireType, v []byte) RawMessage {
   157  	b := [20]byte{}
   158  	n, _ := encodeVarint(b[:], EncodeTag(f, t))
   159  	if t == Varlen {
   160  		n1, _ := encodeVarint(b[n:], uint64(len(v)))
   161  		n += n1
   162  	}
   163  	m = append(m, b[:n]...)
   164  	m = append(m, v...)
   165  	return m
   166  }
   167  
   168  func AppendVarint(m RawMessage, f FieldNumber, v uint64) RawMessage {
   169  	b := [10]byte{}
   170  	n, _ := encodeVarint(b[:], v)
   171  	return Append(m, f, Varint, b[:n])
   172  }
   173  
   174  func AppendVarlen(m RawMessage, f FieldNumber, v []byte) RawMessage {
   175  	return Append(m, f, Varlen, v)
   176  }
   177  
   178  func AppendFixed32(m RawMessage, f FieldNumber, v uint32) RawMessage {
   179  	b := [4]byte{}
   180  	binary.LittleEndian.PutUint32(b[:], v)
   181  	return Append(m, f, Fixed32, b[:])
   182  }
   183  
   184  func AppendFixed64(m RawMessage, f FieldNumber, v uint64) RawMessage {
   185  	b := [8]byte{}
   186  	binary.LittleEndian.PutUint64(b[:], v)
   187  	return Append(m, f, Fixed64, b[:])
   188  }
   189  
   190  func Parse(m []byte) (FieldNumber, WireType, RawValue, RawMessage, error) {
   191  	tag, n, err := decodeVarint(m)
   192  	if err != nil {
   193  		return 0, 0, nil, m, fmt.Errorf("decoding protobuf field number: %w", err)
   194  	}
   195  	m = m[n:]
   196  	f, t := DecodeTag(tag)
   197  
   198  	switch t {
   199  	case Varint:
   200  		_, n, err := decodeVarint(m)
   201  		if err != nil {
   202  			return f, t, nil, m, fmt.Errorf("decoding varint field %d: %w", f, err)
   203  		}
   204  		if len(m) < n {
   205  			return f, t, nil, m, fmt.Errorf("decoding varint field %d: %w", f, io.ErrUnexpectedEOF)
   206  		}
   207  		return f, t, RawValue(m[:n]), m[n:], nil
   208  
   209  	case Varlen:
   210  		l, n, err := decodeVarint(m) // length
   211  		if err != nil {
   212  			return f, t, nil, m, fmt.Errorf("decoding varlen field %d: %w", f, err)
   213  		}
   214  		if uint64(len(m)-n) < l {
   215  			return f, t, nil, m, fmt.Errorf("decoding varlen field %d: %w", f, io.ErrUnexpectedEOF)
   216  		}
   217  		return f, t, RawValue(m[n : n+int(l)]), m[n+int(l):], nil
   218  
   219  	case Fixed32:
   220  		if len(m) < 4 {
   221  			return f, t, nil, m, fmt.Errorf("decoding fixed32 field %d: %w", f, io.ErrUnexpectedEOF)
   222  		}
   223  		return f, t, RawValue(m[:4]), m[4:], nil
   224  
   225  	case Fixed64:
   226  		if len(m) < 8 {
   227  			return f, t, nil, m, fmt.Errorf("decoding fixed64 field %d: %w", f, io.ErrUnexpectedEOF)
   228  		}
   229  		return f, t, RawValue(m[:8]), m[8:], nil
   230  
   231  	default:
   232  		return f, t, nil, m, fmt.Errorf("invalid wire type: %d", t)
   233  	}
   234  }
   235  
   236  // Scan calls fn for each protobuf field in the message b.
   237  //
   238  // The iteration stops when all fields have been scanned, fn returns false, or
   239  // an error is seen.
   240  func Scan(b []byte, fn func(FieldNumber, WireType, RawValue) (bool, error)) error {
   241  	for len(b) != 0 {
   242  		f, t, v, m, err := Parse(b)
   243  		if err != nil {
   244  			return err
   245  		}
   246  		if ok, err := fn(f, t, v); !ok {
   247  			return err
   248  		}
   249  		b = m
   250  	}
   251  	return nil
   252  }
   253  
   254  // RawValue represents a single protobuf value.
   255  //
   256  // RawValue instances are returned by Parse and share the backing array of the
   257  // RawMessage that they were decoded from.
   258  type RawValue []byte
   259  
   260  // Varint decodes v as a varint.
   261  //
   262  // The content of v will always be a valid varint if v was returned by a call to
   263  // Parse and the associated wire type was Varint. In other cases, the behavior
   264  // of Varint is undefined.
   265  func (v RawValue) Varint() uint64 {
   266  	u, _, _ := decodeVarint(v)
   267  	return u
   268  }
   269  
   270  // Fixed32 decodes v as a fixed32.
   271  //
   272  // The content of v will always be a valid fixed32 if v was returned by a call
   273  // to Parse and the associated wire type was Fixed32. In other cases, the
   274  // behavior of Fixed32 is undefined.
   275  func (v RawValue) Fixed32() uint32 {
   276  	return binary.LittleEndian.Uint32(v)
   277  }
   278  
   279  // Fixed64 decodes v as a fixed64.
   280  //
   281  // The content of v will always be a valid fixed64 if v was returned by a call
   282  // to Parse and the associated wire type was Fixed64. In other cases, the
   283  // behavior of Fixed64 is undefined.
   284  func (v RawValue) Fixed64() uint64 {
   285  	return binary.LittleEndian.Uint64(v)
   286  }
   287  
   288  var (
   289  	_ Message  = &RawMessage{}
   290  	_ Rewriter = RawMessage{}
   291  )
   292  
   293  func messageCodecOf(t reflect.Type) *codec {
   294  	return &codec{
   295  		wire:   varlen,
   296  		size:   messageSizeFuncOf(t),
   297  		encode: messageEncodeFuncOf(t),
   298  		decode: messageDecodeFuncOf(t),
   299  	}
   300  }
   301  
   302  func messageSizeFuncOf(t reflect.Type) sizeFunc {
   303  	return func(p unsafe.Pointer, flags flags) int {
   304  		if p != nil {
   305  			if m := reflect.NewAt(t, p).Interface().(Message); m != nil {
   306  				size := m.Size()
   307  				if flags.has(toplevel) {
   308  					return size
   309  				}
   310  				return sizeOfVarlen(size)
   311  			}
   312  		}
   313  		return 0
   314  	}
   315  }
   316  
   317  func messageEncodeFuncOf(t reflect.Type) encodeFunc {
   318  	return func(b []byte, p unsafe.Pointer, flags flags) (int, error) {
   319  		if p != nil {
   320  			if m := reflect.NewAt(t, p).Interface().(Message); m != nil {
   321  				size := m.Size()
   322  
   323  				if flags.has(toplevel) {
   324  					if len(b) < size {
   325  						return 0, io.ErrShortBuffer
   326  					}
   327  					return len(b), m.Marshal(b)
   328  				}
   329  
   330  				vlen := sizeOfVarlen(size)
   331  				if len(b) < vlen {
   332  					return 0, io.ErrShortBuffer
   333  				}
   334  
   335  				n, err := encodeVarint(b, uint64(size))
   336  				if err != nil {
   337  					return n, err
   338  				}
   339  
   340  				return vlen, m.Marshal(b[n:])
   341  			}
   342  		}
   343  		return 0, nil
   344  	}
   345  }
   346  
   347  func messageDecodeFuncOf(t reflect.Type) decodeFunc {
   348  	return func(b []byte, p unsafe.Pointer, flags flags) (int, error) {
   349  		m := reflect.NewAt(t, p).Interface().(Message)
   350  
   351  		if flags.has(toplevel) {
   352  			return len(b), m.Unmarshal(b)
   353  		}
   354  
   355  		v, n, err := decodeVarlen(b)
   356  		if err != nil {
   357  			return n, err
   358  		}
   359  
   360  		return n + len(v), m.Unmarshal(v)
   361  	}
   362  }