github.com/decred/dcrlnd@v0.7.6/tlv/stream.go (about)

     1  package tlv
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"io/ioutil"
     8  	"math"
     9  )
    10  
    11  // MaxRecordSize is the maximum size of a particular record that will be parsed
    12  // by a stream decoder. This value is currently chosen to be equal to the
    13  // maximum message size permitted by BOLT 1, as no record should be bigger than
    14  // an entire message.
    15  const MaxRecordSize = 65535 // 65535 bytes (64 KiB - 1)
    16  
    17  // ErrStreamNotCanonical signals that the records of a stream are not sorted by
    18  // strictly increasing type, i.e. they are out of order or contain a duplicate.
    19  var ErrStreamNotCanonical = errors.New("tlv stream is not canonical")
    20  
    21  // ErrRecordTooLarge signals that a decoded record declares a length greater
    22  // than MaxRecordSize and will therefore not be parsed.
    23  var ErrRecordTooLarge = errors.New("record is too large")
    24  
    25  // Stream defines a TLV stream that can be used for encoding or decoding a set
    26  // of TLV Records.
    27  type Stream struct {
    28  	records []Record // known records; NewStream requires ascending type order
    29  	buf     [8]byte // scratch space passed to the varint and record codecs
    30  }
    31  
    32  // NewStream creates a new TLV Stream given an encoding codec, a decoding codec,
    33  // and a set of known records.
    34  func NewStream(records ...Record) (*Stream, error) {
    35  	// Assert that the ordering of the Records is canonical and appear in
    36  	// ascending order of type.
    37  	var (
    38  		min      Type
    39  		overflow bool
    40  	)
    41  	for _, record := range records {
    42  		if overflow || record.typ < min {
    43  			return nil, ErrStreamNotCanonical
    44  		}
    45  		if record.encoder == nil {
    46  			record.encoder = ENOP
    47  		}
    48  		if record.decoder == nil {
    49  			record.decoder = DNOP
    50  		}
    51  		if record.typ == math.MaxUint64 {
    52  			overflow = true
    53  		}
    54  		min = record.typ + 1
    55  	}
    56  
    57  	return &Stream{
    58  		records: records,
    59  	}, nil
    60  }
    61  
    62  // MustNewStream creates a new TLV Stream given an encoding codec, a decoding
    63  // codec, and a set of known records. If an error is encountered in creating the
    64  // stream, this method will panic instead of returning the error.
    65  func MustNewStream(records ...Record) *Stream {
    66  	stream, err := NewStream(records...)
    67  	if err != nil {
    68  		panic(err.Error())
    69  	}
    70  	return stream
    71  }
    72  
    73  // Encode writes a Stream to the passed io.Writer. Each of the Records known to
    74  // the Stream is written in ascending order of their type so as to be canonical.
    75  //
    76  // The stream is constructed by concatenating the individual, serialized Records
    77  // where each record has the following format:
    78  //
    79  //	[varint: type]
    80  //	[varint: length]
    81  //	[length: value]
    82  //
    83  // An error is returned if the io.Writer fails to accept bytes from the
    84  // encoding, and nothing else. The ordering of the Records is asserted upon the
    85  // creation of a Stream, and thus the output will be by definition canonical.
    86  func (s *Stream) Encode(w io.Writer) error {
    87  	// Iterate through all known records, if any, serializing each record's
    88  	// type, length and value.
    89  	for i := range s.records {
    90  		rec := &s.records[i]
    91  
    92  		// Write the record's type as a varint.
    93  		err := WriteVarInt(w, uint64(rec.typ), &s.buf)
    94  		if err != nil {
    95  			return err
    96  		}
    97  
    98  		// Write the record's length as a varint.
    99  		err = WriteVarInt(w, rec.Size(), &s.buf)
   100  		if err != nil {
   101  			return err
   102  		}
   103  
   104  		// Encode the current record's value using the stream's codec.
   105  		err = rec.encoder(w, rec.value, &s.buf)
   106  		if err != nil {
   107  			return err
   108  		}
   109  	}
   110  
   111  	return nil
   112  }
   113  
   114  // Decode deserializes TLV Stream from the passed io.Reader. The Stream will
   115  // inspect each record that is parsed and check to see if it has a corresponding
   116  // Record to facilitate deserialization of that field. If the record is unknown,
   117  // the Stream will discard the record's bytes and proceed to the subsequent
   118  // record.
   119  //
   120  // Each record has the following format:
   121  //
   122  //	[varint: type]
   123  //	[varint: length]
   124  //	[length: value]
   125  //
   126  // A series of (possibly zero) records are concatenated into a stream, this
   127  // example contains two records:
   128  //
   129  //	(t: 0x01, l: 0x04, v: 0xff, 0xff, 0xff, 0xff)
   130  //	(t: 0x02, l: 0x01, v: 0x01)
   131  //
   132  // This method asserts that the byte stream is canonical, namely that each
   133  // record is unique and that all records are sorted in ascending order. An
   134  // ErrNotCanonicalStream error is returned if the encoded TLV stream is not.
   135  //
   136  // We permit an io.EOF error only when reading the type byte which signals that
   137  // the last record was read cleanly and we should stop parsing. All other io.EOF
   138  // or io.ErrUnexpectedEOF errors are returned.
   139  func (s *Stream) Decode(r io.Reader) error {
   140  	_, err := s.decode(r, nil)
   141  	return err
   142  }
   143  
   144  // DecodeWithParsedTypes is identical to Decode, but if successful, returns a
   145  // TypeMap containing the types of all records that were decoded or ignored from
   146  // the stream.
   147  func (s *Stream) DecodeWithParsedTypes(r io.Reader) (TypeMap, error) {
   148  	return s.decode(r, make(TypeMap))
   149  }
   150  
   151  // decode is a helper function that performs the basis of stream decoding. If
   152  // the caller needs the set of parsed types, it must provide an initialized
   153  // parsedTypes, otherwise the returned TypeMap will be nil.
   154  func (s *Stream) decode(r io.Reader, parsedTypes TypeMap) (TypeMap, error) {
   155  	var (
   156  		typ       Type
   157  		min       Type
   158  		recordIdx int
   159  		overflow  bool
   160  	)
   161  
   162  	// Iterate through all possible type identifiers. As types are read from
   163  	// the io.Reader, min will skip forward to the last read type.
   164  	for {
   165  		// Read the next varint type.
   166  		t, err := ReadVarInt(r, &s.buf)
   167  		switch {
   168  
   169  		// We'll silence an EOF when zero bytes remain, meaning the
   170  		// stream was cleanly encoded.
   171  		case err == io.EOF:
   172  			return parsedTypes, nil
   173  
   174  		// Other unexpected errors.
   175  		case err != nil:
   176  			return nil, err
   177  		}
   178  
   179  		typ = Type(t)
   180  
   181  		// Assert that this type is greater than any previously read.
   182  		// If we've already overflowed and we parsed another type, the
   183  		// stream is not canonical. This check prevents us from accepts
   184  		// encodings that have duplicate records or from accepting an
   185  		// unsorted series.
   186  		if overflow || typ < min {
   187  			return nil, ErrStreamNotCanonical
   188  		}
   189  
   190  		// Read the varint length.
   191  		length, err := ReadVarInt(r, &s.buf)
   192  		switch {
   193  
   194  		// We'll convert any EOFs to ErrUnexpectedEOF, since this
   195  		// results in an invalid record.
   196  		case err == io.EOF:
   197  			return nil, io.ErrUnexpectedEOF
   198  
   199  		// Other unexpected errors.
   200  		case err != nil:
   201  			return nil, err
   202  		}
   203  
   204  		// Place a soft limit on the size of a sane record, which
   205  		// prevents malicious encoders from causing us to allocate an
   206  		// unbounded amount of memory when decoding variable-sized
   207  		// fields.
   208  		if length > MaxRecordSize {
   209  			return nil, ErrRecordTooLarge
   210  		}
   211  
   212  		// Search the records known to the stream for this type. We'll
   213  		// begin the search and recordIdx and walk forward until we find
   214  		// it or the next record's type is larger.
   215  		rec, newIdx, ok := s.getRecord(typ, recordIdx)
   216  		switch {
   217  
   218  		// We know of this record type, proceed to decode the value.
   219  		// This method asserts that length bytes are read in the
   220  		// process, and returns an error if the number of bytes is not
   221  		// exactly length.
   222  		case ok:
   223  			err := rec.decoder(r, rec.value, &s.buf, length)
   224  			switch {
   225  
   226  			// We'll convert any EOFs to ErrUnexpectedEOF, since this
   227  			// results in an invalid record.
   228  			case err == io.EOF:
   229  				return nil, io.ErrUnexpectedEOF
   230  
   231  			// Other unexpected errors.
   232  			case err != nil:
   233  				return nil, err
   234  			}
   235  
   236  			// Record the successfully decoded type if the caller
   237  			// provided an initialized TypeMap.
   238  			if parsedTypes != nil {
   239  				parsedTypes[typ] = nil
   240  			}
   241  
   242  		// Otherwise, the record type is unknown and is odd, discard the
   243  		// number of bytes specified by length.
   244  		default:
   245  			// If the caller provided an initialized TypeMap, record
   246  			// the encoded bytes.
   247  			var b *bytes.Buffer
   248  			writer := ioutil.Discard
   249  			if parsedTypes != nil {
   250  				b = bytes.NewBuffer(make([]byte, 0, length))
   251  				writer = b
   252  			}
   253  
   254  			_, err := io.CopyN(writer, r, int64(length))
   255  			switch {
   256  
   257  			// We'll convert any EOFs to ErrUnexpectedEOF, since this
   258  			// results in an invalid record.
   259  			case err == io.EOF:
   260  				return nil, io.ErrUnexpectedEOF
   261  
   262  			// Other unexpected errors.
   263  			case err != nil:
   264  				return nil, err
   265  			}
   266  
   267  			if parsedTypes != nil {
   268  				parsedTypes[typ] = b.Bytes()
   269  			}
   270  		}
   271  
   272  		// Update our record index so that we can begin our next search
   273  		// from where we left off.
   274  		recordIdx = newIdx
   275  
   276  		// If we've parsed the largest possible type, the next loop will
   277  		// overflow back to zero. However, we need to attempt parsing
   278  		// the next type to ensure that the stream is empty.
   279  		if typ == math.MaxUint64 {
   280  			overflow = true
   281  		}
   282  
   283  		// Finally, set our lower bound on the next accepted type.
   284  		min = typ + 1
   285  	}
   286  }
   287  
   288  // getRecord searches for a record matching typ known to the stream. The boolean
   289  // return value indicates whether the record is known to the stream. The integer
   290  // return value carries the index from where getRecord should be invoked on the
   291  // subsequent call. The first call to getRecord should always use an idx of 0.
   292  func (s *Stream) getRecord(typ Type, idx int) (Record, int, bool) {
   293  	for idx < len(s.records) {
   294  		record := s.records[idx]
   295  		switch {
   296  
   297  		// Found target record, return it to the caller. The next index
   298  		// returned points to the immediately following record.
   299  		case record.typ == typ:
   300  			return record, idx + 1, true
   301  
   302  		// This record's type is lower than the target. Advance our
   303  		// index and continue to the next record which will have a
   304  		// strictly higher type.
   305  		case record.typ < typ:
   306  			idx++
   307  			continue
   308  
   309  		// This record's type is larger than the target, hence we have
   310  		// no record matching the current type. Return the current index
   311  		// so that we can start our search from here when processing the
   312  		// next tlv record.
   313  		default:
   314  			return Record{}, idx, false
   315  		}
   316  	}
   317  
   318  	// All known records are exhausted.
   319  	return Record{}, idx, false
   320  }