github.com/decred/dcrlnd@v0.7.6/lnwire/query_short_chan_ids.go (about)

     1  package lnwire
     2  
     3  import (
     4  	"bytes"
     5  	"compress/zlib"
     6  	"fmt"
     7  	"io"
     8  	"sort"
     9  	"sync"
    10  
    11  	"github.com/decred/dcrd/chaincfg/chainhash"
    12  )
    13  
    14  // ShortChanIDEncoding is an enum-like type that represents exactly how a set
    15  // of short channel ID's is encoded on the wire. The set of encodings allows us
    16  // to take advantage of the structure of a list of short channel ID's to
    17  // achieving a high degree of compression.
    18  type ShortChanIDEncoding uint8
    19  
    20  const (
    21  	// EncodingSortedPlain signals that the set of short channel ID's is
    22  	// encoded using the regular encoding, in a sorted order.
    23  	EncodingSortedPlain ShortChanIDEncoding = 0
    24  
    25  	// EncodingSortedZlib signals that the set of short channel ID's is
    26  	// encoded by first sorting the set of channel ID's, as then
    27  	// compressing them using zlib.
    28  	EncodingSortedZlib ShortChanIDEncoding = 1
    29  )
    30  
    31  const (
    32  	// maxZlibBufSize is the max number of bytes that we'll accept from a
    33  	// zlib decoding instance. We do this in order to limit the total
    34  	// amount of memory allocated during a decoding instance.
    35  	maxZlibBufSize = 67413630
    36  )
    37  
    38  // ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose
    39  // items were not sorted.
    40  type ErrUnsortedSIDs struct {
    41  	prevSID ShortChannelID
    42  	curSID  ShortChannelID
    43  }
    44  
    45  // Error returns a human-readable description of the error.
    46  func (e ErrUnsortedSIDs) Error() string {
    47  	return fmt.Sprintf("current sid: %v isn't greater than last sid: %v",
    48  		e.curSID, e.prevSID)
    49  }
    50  
    51  // zlibDecodeMtx is a package level mutex that we'll use in order to ensure
    52  // that we'll only attempt a single zlib decoding instance at a time. This
    53  // allows us to also further bound our memory usage.
    54  var zlibDecodeMtx sync.Mutex
    55  
    56  // ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we
    57  // came across an unknown short channel ID encoding, and therefore were unable
    58  // to continue parsing.
    59  func ErrUnknownShortChanIDEncoding(encoding ShortChanIDEncoding) error {
    60  	return fmt.Errorf("unknown short chan id encoding: %v", encoding)
    61  }
    62  
    63  // QueryShortChanIDs is a message that allows the sender to query a set of
    64  // channel announcement and channel update messages that correspond to the set
    65  // of encoded short channel ID's. The encoding of the short channel ID's is
    66  // detailed in the query message ensuring that the receiver knows how to
    67  // properly decode each encode short channel ID which may be encoded using a
    68  // compression format. The receiver should respond with a series of channel
    69  // announcement and channel updates, finally sending a ReplyShortChanIDsEnd
    70  // message.
    71  type QueryShortChanIDs struct {
    72  	// ChainHash denotes the target chain that we're querying for the
    73  	// channel ID's of.
    74  	ChainHash chainhash.Hash
    75  
    76  	// EncodingType is a signal to the receiver of the message that
    77  	// indicates exactly how the set of short channel ID's that follow have
    78  	// been encoded.
    79  	EncodingType ShortChanIDEncoding
    80  
    81  	// ShortChanIDs is a slice of decoded short channel ID's.
    82  	ShortChanIDs []ShortChannelID
    83  
    84  	// ExtraData is the set of data that was appended to this message to
    85  	// fill out the full maximum transport message size. These fields can
    86  	// be used to specify optional data such as custom TLV fields.
    87  	ExtraData ExtraOpaqueData
    88  
    89  	// noSort indicates whether or not to sort the short channel ids before
    90  	// writing them out.
    91  	//
    92  	// NOTE: This should only be used during testing.
    93  	noSort bool
    94  }
    95  
    96  // NewQueryShortChanIDs creates a new QueryShortChanIDs message.
    97  func NewQueryShortChanIDs(h chainhash.Hash, e ShortChanIDEncoding,
    98  	s []ShortChannelID) *QueryShortChanIDs {
    99  
   100  	return &QueryShortChanIDs{
   101  		ChainHash:    h,
   102  		EncodingType: e,
   103  		ShortChanIDs: s,
   104  	}
   105  }
   106  
   107  // A compile time check to ensure QueryShortChanIDs implements the
   108  // lnwire.Message interface.
   109  var _ Message = (*QueryShortChanIDs)(nil)
   110  
   111  // Decode deserializes a serialized QueryShortChanIDs message stored in the
   112  // passed io.Reader observing the specified protocol version.
   113  //
   114  // This is part of the lnwire.Message interface.
   115  func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error {
   116  	err := ReadElements(r, q.ChainHash[:])
   117  	if err != nil {
   118  		return err
   119  	}
   120  
   121  	q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r)
   122  	if err != nil {
   123  		return err
   124  	}
   125  
   126  	return q.ExtraData.Decode(r)
   127  }
   128  
   129  // decodeShortChanIDs decodes a set of short channel ID's that have been
   130  // encoded. The first byte of the body details how the short chan ID's were
   131  // encoded. We'll use this type to govern exactly how we go about encoding the
   132  // set of short channel ID's.
   133  func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) {
   134  	// First, we'll attempt to read the number of bytes in the body of the
   135  	// set of encoded short channel ID's.
   136  	var numBytesResp uint16
   137  	err := ReadElements(r, &numBytesResp)
   138  	if err != nil {
   139  		return 0, nil, err
   140  	}
   141  
   142  	if numBytesResp == 0 {
   143  		return 0, nil, nil
   144  	}
   145  
   146  	queryBody := make([]byte, numBytesResp)
   147  	if _, err := io.ReadFull(r, queryBody); err != nil {
   148  		return 0, nil, err
   149  	}
   150  
   151  	// The first byte is the encoding type, so we'll extract that so we can
   152  	// continue our parsing.
   153  	encodingType := ShortChanIDEncoding(queryBody[0])
   154  
   155  	// Before continuing, we'll snip off the first byte of the query body
   156  	// as that was just the encoding type.
   157  	queryBody = queryBody[1:]
   158  
   159  	// Otherwise, depending on the encoding type, we'll decode the encode
   160  	// short channel ID's in a different manner.
   161  	switch encodingType {
   162  
   163  	// In this encoding, we'll simply read a sort array of encoded short
   164  	// channel ID's from the buffer.
   165  	case EncodingSortedPlain:
   166  		// If after extracting the encoding type, the number of
   167  		// remaining bytes is not a whole multiple of the size of an
   168  		// encoded short channel ID (8 bytes), then we'll return a
   169  		// parsing error.
   170  		if len(queryBody)%8 != 0 {
   171  			return 0, nil, fmt.Errorf("whole number of short "+
   172  				"chan ID's cannot be encoded in len=%v",
   173  				len(queryBody))
   174  		}
   175  
   176  		// As each short channel ID is encoded as 8 bytes, we can
   177  		// compute the number of bytes encoded based on the size of the
   178  		// query body.
   179  		numShortChanIDs := len(queryBody) / 8
   180  		if numShortChanIDs == 0 {
   181  			return encodingType, nil, nil
   182  		}
   183  
   184  		// Finally, we'll read out the exact number of short channel
   185  		// ID's to conclude our parsing.
   186  		shortChanIDs := make([]ShortChannelID, numShortChanIDs)
   187  		bodyReader := bytes.NewReader(queryBody)
   188  		var lastChanID ShortChannelID
   189  		for i := 0; i < numShortChanIDs; i++ {
   190  			if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil {
   191  				return 0, nil, fmt.Errorf("unable to parse "+
   192  					"short chan ID: %v", err)
   193  			}
   194  
   195  			// We'll ensure that this short chan ID is greater than
   196  			// the last one. This is a requirement within the
   197  			// encoding, and if violated can aide us in detecting
   198  			// malicious payloads. This can only be true starting
   199  			// at the second chanID.
   200  			cid := shortChanIDs[i]
   201  			if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
   202  				return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
   203  			}
   204  			lastChanID = cid
   205  		}
   206  
   207  		return encodingType, shortChanIDs, nil
   208  
   209  	// In this encoding, we'll use zlib to decode the compressed payload.
   210  	// However, we'll pay attention to ensure that we don't open our selves
   211  	// up to a memory exhaustion attack.
   212  	case EncodingSortedZlib:
   213  		// We'll obtain an ultimately release the zlib decode mutex.
   214  		// This guards us against allocating too much memory to decode
   215  		// each instance from concurrent peers.
   216  		zlibDecodeMtx.Lock()
   217  		defer zlibDecodeMtx.Unlock()
   218  
   219  		// At this point, if there's no body remaining, then only the encoding
   220  		// type was specified, meaning that there're no further bytes to be
   221  		// parsed.
   222  		if len(queryBody) == 0 {
   223  			return encodingType, nil, nil
   224  		}
   225  
   226  		// Before we start to decode, we'll create a limit reader over
   227  		// the current reader. This will ensure that we can control how
   228  		// much memory we're allocating during the decoding process.
   229  		limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{
   230  			R: bytes.NewReader(queryBody),
   231  			N: maxZlibBufSize,
   232  		})
   233  		if err != nil {
   234  			return 0, nil, fmt.Errorf("unable to create zlib reader: %v", err)
   235  		}
   236  
   237  		var (
   238  			shortChanIDs []ShortChannelID
   239  			lastChanID   ShortChannelID
   240  			i            int
   241  		)
   242  		for {
   243  			// We'll now attempt to read the next short channel ID
   244  			// encoded in the payload.
   245  			var cid ShortChannelID
   246  			err := ReadElements(limitedDecompressor, &cid)
   247  
   248  			switch {
   249  			// If we get an EOF error, then that either means we've
   250  			// read all that's contained in the buffer, or have hit
   251  			// our limit on the number of bytes we'll read. In
   252  			// either case, we'll return what we have so far.
   253  			case err == io.ErrUnexpectedEOF || err == io.EOF:
   254  				return encodingType, shortChanIDs, nil
   255  
   256  			// Otherwise, we hit some other sort of error, possibly
   257  			// an invalid payload, so we'll exit early with the
   258  			// error.
   259  			case err != nil:
   260  				return 0, nil, fmt.Errorf("unable to "+
   261  					"deflate next short chan "+
   262  					"ID: %v", err)
   263  			}
   264  
   265  			// We successfully read the next ID, so we'll collect
   266  			// that in the set of final ID's to return.
   267  			shortChanIDs = append(shortChanIDs, cid)
   268  
   269  			// Finally, we'll ensure that this short chan ID is
   270  			// greater than the last one. This is a requirement
   271  			// within the encoding, and if violated can aide us in
   272  			// detecting malicious payloads. This can only be true
   273  			// starting at the second chanID.
   274  			if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() {
   275  				return 0, nil, ErrUnsortedSIDs{lastChanID, cid}
   276  			}
   277  
   278  			lastChanID = cid
   279  			i++
   280  		}
   281  
   282  	default:
   283  		// If we've been sent an encoding type that we don't know of,
   284  		// then we'll return a parsing error as we can't continue if
   285  		// we're unable to encode them.
   286  		return 0, nil, ErrUnknownShortChanIDEncoding(encodingType)
   287  	}
   288  }
   289  
   290  // Encode serializes the target QueryShortChanIDs into the passed io.Writer
   291  // observing the protocol version specified.
   292  //
   293  // This is part of the lnwire.Message interface.
   294  func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error {
   295  	// First, we'll write out the chain hash.
   296  	if err := WriteBytes(w, q.ChainHash[:]); err != nil {
   297  		return err
   298  	}
   299  
   300  	// Base on our encoding type, we'll write out the set of short channel
   301  	// ID's.
   302  	err := encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort)
   303  	if err != nil {
   304  		return err
   305  	}
   306  
   307  	return WriteBytes(w, q.ExtraData)
   308  }
   309  
   310  // encodeShortChanIDs encodes the passed short channel ID's into the passed
   311  // io.Writer, respecting the specified encoding type.
   312  func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding,
   313  	shortChanIDs []ShortChannelID, noSort bool) error {
   314  
   315  	// For both of the current encoding types, the channel ID's are to be
   316  	// sorted in place, so we'll do that now. The sorting is applied unless
   317  	// we were specifically requested not to for testing purposes.
   318  	if !noSort {
   319  		sort.Slice(shortChanIDs, func(i, j int) bool {
   320  			return shortChanIDs[i].ToUint64() <
   321  				shortChanIDs[j].ToUint64()
   322  		})
   323  	}
   324  
   325  	switch encodingType {
   326  
   327  	// In this encoding, we'll simply write a sorted array of encoded short
   328  	// channel ID's from the buffer.
   329  	case EncodingSortedPlain:
   330  		// First, we'll write out the number of bytes of the query
   331  		// body. We add 1 as the response will have the encoding type
   332  		// prepended to it.
   333  		numBytesBody := uint16(len(shortChanIDs)*8) + 1
   334  		if err := WriteUint16(w, numBytesBody); err != nil {
   335  			return err
   336  		}
   337  
   338  		// We'll then write out the encoding that that follows the
   339  		// actual encoded short channel ID's.
   340  		err := WriteShortChanIDEncoding(w, encodingType)
   341  		if err != nil {
   342  			return err
   343  		}
   344  
   345  		// Now that we know they're sorted, we can write out each short
   346  		// channel ID to the buffer.
   347  		for _, chanID := range shortChanIDs {
   348  			if err := WriteShortChannelID(w, chanID); err != nil {
   349  				return fmt.Errorf("unable to write short chan "+
   350  					"ID: %v", err)
   351  			}
   352  		}
   353  
   354  		return nil
   355  
   356  	// For this encoding we'll first write out a serialized version of all
   357  	// the channel ID's into a buffer, then zlib encode that. The final
   358  	// payload is what we'll write out to the passed io.Writer.
   359  	//
   360  	// TODO(roasbeef): assumes the caller knows the proper chunk size to
   361  	// pass to avoid bin-packing here
   362  	case EncodingSortedZlib:
   363  		// If we don't have anything at all to write, then we'll write
   364  		// an empty payload so we don't include things like the zlib
   365  		// header when the remote party is expecting no actual short
   366  		// channel IDs.
   367  		var compressedPayload []byte
   368  		if len(shortChanIDs) > 0 {
   369  			// We'll make a new write buffer to hold the bytes of
   370  			// shortChanIDs.
   371  			var wb bytes.Buffer
   372  
   373  			// Next, we'll write out all the channel ID's directly
   374  			// into the zlib writer, which will do compressing on
   375  			// the fly.
   376  			for _, chanID := range shortChanIDs {
   377  				err := WriteShortChannelID(&wb, chanID)
   378  				if err != nil {
   379  					return fmt.Errorf(
   380  						"unable to write short chan "+
   381  							"ID: %v", err,
   382  					)
   383  				}
   384  			}
   385  
   386  			// With shortChanIDs written into wb, we'll create a
   387  			// zlib writer and write all the compressed bytes.
   388  			var zlibBuffer bytes.Buffer
   389  			zlibWriter := zlib.NewWriter(&zlibBuffer)
   390  
   391  			if _, err := zlibWriter.Write(wb.Bytes()); err != nil {
   392  				return fmt.Errorf(
   393  					"unable to write compressed short chan"+
   394  						"ID: %w", err)
   395  			}
   396  
   397  			// Now that we've written all the elements, we'll
   398  			// ensure the compressed stream is written to the
   399  			// underlying buffer.
   400  			if err := zlibWriter.Close(); err != nil {
   401  				return fmt.Errorf("unable to finalize "+
   402  					"compression: %v", err)
   403  			}
   404  
   405  			compressedPayload = zlibBuffer.Bytes()
   406  		}
   407  
   408  		// Now that we have all the items compressed, we can compute
   409  		// what the total payload size will be. We add one to account
   410  		// for the byte to encode the type.
   411  		//
   412  		// If we don't have any actual bytes to write, then we'll end
   413  		// up emitting one byte for the length, followed by the
   414  		// encoding type, and nothing more. The spec isn't 100% clear
   415  		// in this area, but we do this as this is what most of the
   416  		// other implementations do.
   417  		numBytesBody := len(compressedPayload) + 1
   418  
   419  		// Finally, we can write out the number of bytes, the
   420  		// compression type, and finally the buffer itself.
   421  		if err := WriteUint16(w, uint16(numBytesBody)); err != nil {
   422  			return err
   423  		}
   424  		err := WriteShortChanIDEncoding(w, encodingType)
   425  		if err != nil {
   426  			return err
   427  		}
   428  
   429  		return WriteBytes(w, compressedPayload)
   430  
   431  	default:
   432  		// If we're trying to encode with an encoding type that we
   433  		// don't know of, then we'll return a parsing error as we can't
   434  		// continue if we're unable to encode them.
   435  		return ErrUnknownShortChanIDEncoding(encodingType)
   436  	}
   437  }
   438  
   439  // MsgType returns the integer uniquely identifying this message type on the
   440  // wire.
   441  //
   442  // This is part of the lnwire.Message interface.
   443  func (q *QueryShortChanIDs) MsgType() MessageType {
   444  	return MsgQueryShortChanIDs
   445  }