github.com/decred/dcrlnd@v0.7.6/channeldb/migration/lnwire21/query_short_chan_ids.go (about)

     1  package lnwire
     2  
     3  import (
     4  	"bytes"
     5  	"compress/zlib"
     6  	"fmt"
     7  	"io"
     8  	"sort"
     9  	"sync"
    10  
    11  	"github.com/decred/dcrd/chaincfg/chainhash"
    12  )
    13  
    14  // ShortChanIDEncoding is an enum-like type that represents exactly how a set
    15  // of short channel ID's is encoded on the wire. The set of encodings allows us
    16  // to take advantage of the structure of a list of short channel ID's to
    17  // achieving a high degree of compression.
    18  type ShortChanIDEncoding uint8
    19  
    20  const (
    21  	// EncodingSortedPlain signals that the set of short channel ID's is
    22  	// encoded using the regular encoding, in a sorted order.
    23  	EncodingSortedPlain ShortChanIDEncoding = 0
    24  
    25  	// EncodingSortedZlib signals that the set of short channel ID's is
    26  	// encoded by first sorting the set of channel ID's, as then
    27  	// compressing them using zlib.
    28  	EncodingSortedZlib ShortChanIDEncoding = 1
    29  )
    30  
    31  const (
    32  	// maxZlibBufSize is the max number of bytes that we'll accept from a
    33  	// zlib decoding instance. We do this in order to limit the total
    34  	// amount of memory allocated during a decoding instance.
    35  	maxZlibBufSize = 67413630
    36  )
    37  
    38  // ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose
    39  // items were not sorted.
    40  type ErrUnsortedSIDs struct {
    41  	prevSID ShortChannelID
    42  	curSID  ShortChannelID
    43  }
    44  
    45  // Error returns a human-readable description of the error.
    46  func (e ErrUnsortedSIDs) Error() string {
    47  	return fmt.Sprintf("current sid: %v isn't greater than last sid: %v",
    48  		e.curSID, e.prevSID)
    49  }
    50  
    51  // zlibDecodeMtx is a package level mutex that we'll use in order to ensure
    52  // that we'll only attempt a single zlib decoding instance at a time. This
    53  // allows us to also further bound our memory usage.
    54  var zlibDecodeMtx sync.Mutex
    55  
    56  // ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we
    57  // came across an unknown short channel ID encoding, and therefore were unable
    58  // to continue parsing.
    59  func ErrUnknownShortChanIDEncoding(encoding ShortChanIDEncoding) error {
    60  	return fmt.Errorf("unknown short chan id encoding: %v", encoding)
    61  }
    62  
    63  // QueryShortChanIDs is a message that allows the sender to query a set of
    64  // channel announcement and channel update messages that correspond to the set
    65  // of encoded short channel ID's. The encoding of the short channel ID's is
    66  // detailed in the query message ensuring that the receiver knows how to
    67  // properly decode each encode short channel ID which may be encoded using a
    68  // compression format. The receiver should respond with a series of channel
    69  // announcement and channel updates, finally sending a ReplyShortChanIDsEnd
    70  // message.
    71  type QueryShortChanIDs struct {
    72  	// ChainHash denotes the target chain that we're querying for the
    73  	// channel ID's of.
    74  	ChainHash chainhash.Hash
    75  
    76  	// EncodingType is a signal to the receiver of the message that
    77  	// indicates exactly how the set of short channel ID's that follow have
    78  	// been encoded.
    79  	EncodingType ShortChanIDEncoding
    80  
    81  	// ShortChanIDs is a slice of decoded short channel ID's.
    82  	ShortChanIDs []ShortChannelID
    83  
    84  	// noSort indicates whether or not to sort the short channel ids before
    85  	// writing them out.
    86  	//
    87  	// NOTE: This should only be used during testing.
    88  	noSort bool
    89  }
    90  
    91  // NewQueryShortChanIDs creates a new QueryShortChanIDs message.
    92  func NewQueryShortChanIDs(h chainhash.Hash, e ShortChanIDEncoding,
    93  	s []ShortChannelID) *QueryShortChanIDs {
    94  
    95  	return &QueryShortChanIDs{
    96  		ChainHash:    h,
    97  		EncodingType: e,
    98  		ShortChanIDs: s,
    99  	}
   100  }
   101  
   102  // A compile time check to ensure QueryShortChanIDs implements the
   103  // lnwire.Message interface.
   104  var _ Message = (*QueryShortChanIDs)(nil)
   105  
   106  // Decode deserializes a serialized QueryShortChanIDs message stored in the
   107  // passed io.Reader observing the specified protocol version.
   108  //
   109  // This is part of the lnwire.Message interface.
   110  func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error {
   111  	err := ReadElements(r, q.ChainHash[:])
   112  	if err != nil {
   113  		return err
   114  	}
   115  
   116  	q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r)
   117  
   118  	return err
   119  }
   120  
   121  // decodeShortChanIDs decodes a set of short channel ID's that have been
   122  // encoded. The first byte of the body details how the short chan ID's were
   123  // encoded. We'll use this type to govern exactly how we go about encoding the
   124  // set of short channel ID's.
   125  func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) {
   126  	// First, we'll attempt to read the number of bytes in the body of the
   127  	// set of encoded short channel ID's.
   128  	var numBytesResp uint16
   129  	err := ReadElements(r, &numBytesResp)
   130  	if err != nil {
   131  		return 0, nil, err
   132  	}
   133  
   134  	if numBytesResp == 0 {
   135  		return 0, nil, nil
   136  	}
   137  
   138  	queryBody := make([]byte, numBytesResp)
   139  	if _, err := io.ReadFull(r, queryBody); err != nil {
   140  		return 0, nil, err
   141  	}
   142  
   143  	// The first byte is the encoding type, so we'll extract that so we can
   144  	// continue our parsing.
   145  	encodingType := ShortChanIDEncoding(queryBody[0])
   146  
   147  	// Before continuing, we'll snip off the first byte of the query body
   148  	// as that was just the encoding type.
   149  	queryBody = queryBody[1:]
   150  
   151  	// Otherwise, depending on the encoding type, we'll decode the encode
   152  	// short channel ID's in a different manner.
   153  	switch encodingType {
   154  
   155  	// In this encoding, we'll simply read a sort array of encoded short
   156  	// channel ID's from the buffer.
   157  	case EncodingSortedPlain:
   158  		// If after extracting the encoding type, the number of
   159  		// remaining bytes is not a whole multiple of the size of an
   160  		// encoded short channel ID (8 bytes), then we'll return a
   161  		// parsing error.
   162  		if len(queryBody)%8 != 0 {
   163  			return 0, nil, fmt.Errorf("whole number of short "+
   164  				"chan ID's cannot be encoded in len=%v",
   165  				len(queryBody))
   166  		}
   167  
   168  		// As each short channel ID is encoded as 8 bytes, we can
   169  		// compute the number of bytes encoded based on the size of the
   170  		// query body.
   171  		numShortChanIDs := len(queryBody) / 8
   172  		if numShortChanIDs == 0 {
   173  			return encodingType, nil, nil
   174  		}
   175  
   176  		// Finally, we'll read out the exact number of short channel
   177  		// ID's to conclude our parsing.
   178  		shortChanIDs := make([]ShortChannelID, numShortChanIDs)
   179  		bodyReader := bytes.NewReader(queryBody)
   180  		var lastChanID ShortChannelID
   181  		for i := 0; i < numShortChanIDs; i++ {
   182  			if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil {
   183  				return 0, nil, fmt.Errorf("unable to parse "+
   184  					"short chan ID: %v", err)
   185  			}
   186  
   187  			// We'll ensure that this short chan ID is greater than
   188  			// the last one. This is a requirement within the
   189  			// encoding, and if violated can aide us in detecting
   190  			// malicious payloads. This can only be true starting
   191  			// at the second chanID.
   192  			cid := shortChanIDs[i]
   193  			if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
   194  				return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
   195  			}
   196  			lastChanID = cid
   197  		}
   198  
   199  		return encodingType, shortChanIDs, nil
   200  
   201  	// In this encoding, we'll use zlib to decode the compressed payload.
   202  	// However, we'll pay attention to ensure that we don't open our selves
   203  	// up to a memory exhaustion attack.
   204  	case EncodingSortedZlib:
   205  		// We'll obtain an ultimately release the zlib decode mutex.
   206  		// This guards us against allocating too much memory to decode
   207  		// each instance from concurrent peers.
   208  		zlibDecodeMtx.Lock()
   209  		defer zlibDecodeMtx.Unlock()
   210  
   211  		// At this point, if there's no body remaining, then only the encoding
   212  		// type was specified, meaning that there're no further bytes to be
   213  		// parsed.
   214  		if len(queryBody) == 0 {
   215  			return encodingType, nil, nil
   216  		}
   217  
   218  		// Before we start to decode, we'll create a limit reader over
   219  		// the current reader. This will ensure that we can control how
   220  		// much memory we're allocating during the decoding process.
   221  		limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{
   222  			R: bytes.NewReader(queryBody),
   223  			N: maxZlibBufSize,
   224  		})
   225  		if err != nil {
   226  			return 0, nil, fmt.Errorf("unable to create zlib reader: %v", err)
   227  		}
   228  
   229  		var (
   230  			shortChanIDs []ShortChannelID
   231  			lastChanID   ShortChannelID
   232  			i            int
   233  		)
   234  		for {
   235  			// We'll now attempt to read the next short channel ID
   236  			// encoded in the payload.
   237  			var cid ShortChannelID
   238  			err := ReadElements(limitedDecompressor, &cid)
   239  
   240  			switch {
   241  			// If we get an EOF error, then that either means we've
   242  			// read all that's contained in the buffer, or have hit
   243  			// our limit on the number of bytes we'll read. In
   244  			// either case, we'll return what we have so far.
   245  			case err == io.ErrUnexpectedEOF || err == io.EOF:
   246  				return encodingType, shortChanIDs, nil
   247  
   248  			// Otherwise, we hit some other sort of error, possibly
   249  			// an invalid payload, so we'll exit early with the
   250  			// error.
   251  			case err != nil:
   252  				return 0, nil, fmt.Errorf("unable to "+
   253  					"deflate next short chan "+
   254  					"ID: %v", err)
   255  			}
   256  
   257  			// We successfully read the next ID, so we'll collect
   258  			// that in the set of final ID's to return.
   259  			shortChanIDs = append(shortChanIDs, cid)
   260  
   261  			// Finally, we'll ensure that this short chan ID is
   262  			// greater than the last one. This is a requirement
   263  			// within the encoding, and if violated can aide us in
   264  			// detecting malicious payloads. This can only be true
   265  			// starting at the second chanID.
   266  			if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
   267  				return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
   268  			}
   269  
   270  			lastChanID = cid
   271  			i++
   272  		}
   273  
   274  	default:
   275  		// If we've been sent an encoding type that we don't know of,
   276  		// then we'll return a parsing error as we can't continue if
   277  		// we're unable to encode them.
   278  		return 0, nil, ErrUnknownShortChanIDEncoding(encodingType)
   279  	}
   280  }
   281  
   282  // Encode serializes the target QueryShortChanIDs into the passed io.Writer
   283  // observing the protocol version specified.
   284  //
   285  // This is part of the lnwire.Message interface.
   286  func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error {
   287  	// First, we'll write out the chain hash.
   288  	err := WriteElements(w, q.ChainHash[:])
   289  	if err != nil {
   290  		return err
   291  	}
   292  
   293  	// Base on our encoding type, we'll write out the set of short channel
   294  	// ID's.
   295  	return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort)
   296  }
   297  
   298  // encodeShortChanIDs encodes the passed short channel ID's into the passed
   299  // io.Writer, respecting the specified encoding type.
   300  func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding,
   301  	shortChanIDs []ShortChannelID, noSort bool) error {
   302  
   303  	// For both of the current encoding types, the channel ID's are to be
   304  	// sorted in place, so we'll do that now. The sorting is applied unless
   305  	// we were specifically requested not to for testing purposes.
   306  	if !noSort {
   307  		sort.Slice(shortChanIDs, func(i, j int) bool {
   308  			return shortChanIDs[i].ToUint64() <
   309  				shortChanIDs[j].ToUint64()
   310  		})
   311  	}
   312  
   313  	switch encodingType {
   314  
   315  	// In this encoding, we'll simply write a sorted array of encoded short
   316  	// channel ID's from the buffer.
   317  	case EncodingSortedPlain:
   318  		// First, we'll write out the number of bytes of the query
   319  		// body. We add 1 as the response will have the encoding type
   320  		// prepended to it.
   321  		numBytesBody := uint16(len(shortChanIDs)*8) + 1
   322  		if err := WriteElements(w, numBytesBody); err != nil {
   323  			return err
   324  		}
   325  
   326  		// We'll then write out the encoding that that follows the
   327  		// actual encoded short channel ID's.
   328  		if err := WriteElements(w, encodingType); err != nil {
   329  			return err
   330  		}
   331  
   332  		// Now that we know they're sorted, we can write out each short
   333  		// channel ID to the buffer.
   334  		for _, chanID := range shortChanIDs {
   335  			if err := WriteElements(w, chanID); err != nil {
   336  				return fmt.Errorf("unable to write short chan "+
   337  					"ID: %v", err)
   338  			}
   339  		}
   340  
   341  		return nil
   342  
   343  	// For this encoding we'll first write out a serialized version of all
   344  	// the channel ID's into a buffer, then zlib encode that. The final
   345  	// payload is what we'll write out to the passed io.Writer.
   346  	//
   347  	// TODO(roasbeef): assumes the caller knows the proper chunk size to
   348  	// pass to avoid bin-packing here
   349  	case EncodingSortedZlib:
   350  		// We'll make a new buffer, then wrap that with a zlib writer
   351  		// so we can write directly to the buffer and encode in a
   352  		// streaming manner.
   353  		var buf bytes.Buffer
   354  		zlibWriter := zlib.NewWriter(&buf)
   355  
   356  		// If we don't have anything at all to write, then we'll write
   357  		// an empty payload so we don't include things like the zlib
   358  		// header when the remote party is expecting no actual short
   359  		// channel IDs.
   360  		var compressedPayload []byte
   361  		if len(shortChanIDs) > 0 {
   362  			// Next, we'll write out all the channel ID's directly
   363  			// into the zlib writer, which will do compressing on
   364  			// the fly.
   365  			for _, chanID := range shortChanIDs {
   366  				err := WriteElements(zlibWriter, chanID)
   367  				if err != nil {
   368  					return fmt.Errorf("unable to write short chan "+
   369  						"ID: %v", err)
   370  				}
   371  			}
   372  
   373  			// Now that we've written all the elements, we'll
   374  			// ensure the compressed stream is written to the
   375  			// underlying buffer.
   376  			if err := zlibWriter.Close(); err != nil {
   377  				return fmt.Errorf("unable to finalize "+
   378  					"compression: %v", err)
   379  			}
   380  
   381  			compressedPayload = buf.Bytes()
   382  		}
   383  
   384  		// Now that we have all the items compressed, we can compute
   385  		// what the total payload size will be. We add one to account
   386  		// for the byte to encode the type.
   387  		//
   388  		// If we don't have any actual bytes to write, then we'll end
   389  		// up emitting one byte for the length, followed by the
   390  		// encoding type, and nothing more. The spec isn't 100% clear
   391  		// in this area, but we do this as this is what most of the
   392  		// other implementations do.
   393  		numBytesBody := len(compressedPayload) + 1
   394  
   395  		// Finally, we can write out the number of bytes, the
   396  		// compression type, and finally the buffer itself.
   397  		if err := WriteElements(w, uint16(numBytesBody)); err != nil {
   398  			return err
   399  		}
   400  		if err := WriteElements(w, encodingType); err != nil {
   401  			return err
   402  		}
   403  
   404  		_, err := w.Write(compressedPayload)
   405  		return err
   406  
   407  	default:
   408  		// If we're trying to encode with an encoding type that we
   409  		// don't know of, then we'll return a parsing error as we can't
   410  		// continue if we're unable to encode them.
   411  		return ErrUnknownShortChanIDEncoding(encodingType)
   412  	}
   413  }
   414  
   415  // MsgType returns the integer uniquely identifying this message type on the
   416  // wire.
   417  //
   418  // This is part of the lnwire.Message interface.
   419  func (q *QueryShortChanIDs) MsgType() MessageType {
   420  	return MsgQueryShortChanIDs
   421  }
   422  
   423  // MaxPayloadLength returns the maximum allowed payload size for a
   424  // QueryShortChanIDs complete message observing the specified protocol version.
   425  //
   426  // This is part of the lnwire.Message interface.
   427  func (q *QueryShortChanIDs) MaxPayloadLength(uint32) uint32 {
   428  	return MaxMessagePayload
   429  }