github.com/decred/dcrlnd@v0.7.6/lnwire/query_short_chan_ids.go (about) 1 package lnwire 2 3 import ( 4 "bytes" 5 "compress/zlib" 6 "fmt" 7 "io" 8 "sort" 9 "sync" 10 11 "github.com/decred/dcrd/chaincfg/chainhash" 12 ) 13 14 // ShortChanIDEncoding is an enum-like type that represents exactly how a set 15 // of short channel ID's is encoded on the wire. The set of encodings allows us 16 // to take advantage of the structure of a list of short channel ID's to 17 // achieving a high degree of compression. 18 type ShortChanIDEncoding uint8 19 20 const ( 21 // EncodingSortedPlain signals that the set of short channel ID's is 22 // encoded using the regular encoding, in a sorted order. 23 EncodingSortedPlain ShortChanIDEncoding = 0 24 25 // EncodingSortedZlib signals that the set of short channel ID's is 26 // encoded by first sorting the set of channel ID's, as then 27 // compressing them using zlib. 28 EncodingSortedZlib ShortChanIDEncoding = 1 29 ) 30 31 const ( 32 // maxZlibBufSize is the max number of bytes that we'll accept from a 33 // zlib decoding instance. We do this in order to limit the total 34 // amount of memory allocated during a decoding instance. 35 maxZlibBufSize = 67413630 36 ) 37 38 // ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose 39 // items were not sorted. 40 type ErrUnsortedSIDs struct { 41 prevSID ShortChannelID 42 curSID ShortChannelID 43 } 44 45 // Error returns a human-readable description of the error. 46 func (e ErrUnsortedSIDs) Error() string { 47 return fmt.Sprintf("current sid: %v isn't greater than last sid: %v", 48 e.curSID, e.prevSID) 49 } 50 51 // zlibDecodeMtx is a package level mutex that we'll use in order to ensure 52 // that we'll only attempt a single zlib decoding instance at a time. This 53 // allows us to also further bound our memory usage. 54 var zlibDecodeMtx sync.Mutex 55 56 // ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we 57 // came across an unknown short channel ID encoding, and therefore were unable 58 // to continue parsing. 59 func ErrUnknownShortChanIDEncoding(encoding ShortChanIDEncoding) error { 60 return fmt.Errorf("unknown short chan id encoding: %v", encoding) 61 } 62 63 // QueryShortChanIDs is a message that allows the sender to query a set of 64 // channel announcement and channel update messages that correspond to the set 65 // of encoded short channel ID's. The encoding of the short channel ID's is 66 // detailed in the query message ensuring that the receiver knows how to 67 // properly decode each encode short channel ID which may be encoded using a 68 // compression format. The receiver should respond with a series of channel 69 // announcement and channel updates, finally sending a ReplyShortChanIDsEnd 70 // message. 71 type QueryShortChanIDs struct { 72 // ChainHash denotes the target chain that we're querying for the 73 // channel ID's of. 74 ChainHash chainhash.Hash 75 76 // EncodingType is a signal to the receiver of the message that 77 // indicates exactly how the set of short channel ID's that follow have 78 // been encoded. 79 EncodingType ShortChanIDEncoding 80 81 // ShortChanIDs is a slice of decoded short channel ID's. 82 ShortChanIDs []ShortChannelID 83 84 // ExtraData is the set of data that was appended to this message to 85 // fill out the full maximum transport message size. These fields can 86 // be used to specify optional data such as custom TLV fields. 87 ExtraData ExtraOpaqueData 88 89 // noSort indicates whether or not to sort the short channel ids before 90 // writing them out. 91 // 92 // NOTE: This should only be used during testing. 93 noSort bool 94 } 95 96 // NewQueryShortChanIDs creates a new QueryShortChanIDs message. 97 func NewQueryShortChanIDs(h chainhash.Hash, e ShortChanIDEncoding, 98 s []ShortChannelID) *QueryShortChanIDs { 99 100 return &QueryShortChanIDs{ 101 ChainHash: h, 102 EncodingType: e, 103 ShortChanIDs: s, 104 } 105 } 106 107 // A compile time check to ensure QueryShortChanIDs implements the 108 // lnwire.Message interface. 109 var _ Message = (*QueryShortChanIDs)(nil) 110 111 // Decode deserializes a serialized QueryShortChanIDs message stored in the 112 // passed io.Reader observing the specified protocol version. 113 // 114 // This is part of the lnwire.Message interface. 115 func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error { 116 err := ReadElements(r, q.ChainHash[:]) 117 if err != nil { 118 return err 119 } 120 121 q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r) 122 if err != nil { 123 return err 124 } 125 126 return q.ExtraData.Decode(r) 127 } 128 129 // decodeShortChanIDs decodes a set of short channel ID's that have been 130 // encoded. The first byte of the body details how the short chan ID's were 131 // encoded. We'll use this type to govern exactly how we go about encoding the 132 // set of short channel ID's. 133 func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) { 134 // First, we'll attempt to read the number of bytes in the body of the 135 // set of encoded short channel ID's. 136 var numBytesResp uint16 137 err := ReadElements(r, &numBytesResp) 138 if err != nil { 139 return 0, nil, err 140 } 141 142 if numBytesResp == 0 { 143 return 0, nil, nil 144 } 145 146 queryBody := make([]byte, numBytesResp) 147 if _, err := io.ReadFull(r, queryBody); err != nil { 148 return 0, nil, err 149 } 150 151 // The first byte is the encoding type, so we'll extract that so we can 152 // continue our parsing. 153 encodingType := ShortChanIDEncoding(queryBody[0]) 154 155 // Before continuing, we'll snip off the first byte of the query body 156 // as that was just the encoding type. 157 queryBody = queryBody[1:] 158 159 // Otherwise, depending on the encoding type, we'll decode the encode 160 // short channel ID's in a different manner. 161 switch encodingType { 162 163 // In this encoding, we'll simply read a sort array of encoded short 164 // channel ID's from the buffer. 165 case EncodingSortedPlain: 166 // If after extracting the encoding type, the number of 167 // remaining bytes is not a whole multiple of the size of an 168 // encoded short channel ID (8 bytes), then we'll return a 169 // parsing error. 170 if len(queryBody)%8 != 0 { 171 return 0, nil, fmt.Errorf("whole number of short "+ 172 "chan ID's cannot be encoded in len=%v", 173 len(queryBody)) 174 } 175 176 // As each short channel ID is encoded as 8 bytes, we can 177 // compute the number of bytes encoded based on the size of the 178 // query body. 179 numShortChanIDs := len(queryBody) / 8 180 if numShortChanIDs == 0 { 181 return encodingType, nil, nil 182 } 183 184 // Finally, we'll read out the exact number of short channel 185 // ID's to conclude our parsing. 186 shortChanIDs := make([]ShortChannelID, numShortChanIDs) 187 bodyReader := bytes.NewReader(queryBody) 188 var lastChanID ShortChannelID 189 for i := 0; i < numShortChanIDs; i++ { 190 if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil { 191 return 0, nil, fmt.Errorf("unable to parse "+ 192 "short chan ID: %v", err) 193 } 194 195 // We'll ensure that this short chan ID is greater than 196 // the last one. This is a requirement within the 197 // encoding, and if violated can aide us in detecting 198 // malicious payloads. This can only be true starting 199 // at the second chanID. 200 cid := shortChanIDs[i] 201 if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() { 202 return 0, nil, ErrUnsortedSIDs{lastChanID, cid} 203 } 204 lastChanID = cid 205 } 206 207 return encodingType, shortChanIDs, nil 208 209 // In this encoding, we'll use zlib to decode the compressed payload. 210 // However, we'll pay attention to ensure that we don't open our selves 211 // up to a memory exhaustion attack. 212 case EncodingSortedZlib: 213 // We'll obtain an ultimately release the zlib decode mutex. 214 // This guards us against allocating too much memory to decode 215 // each instance from concurrent peers. 216 zlibDecodeMtx.Lock() 217 defer zlibDecodeMtx.Unlock() 218 219 // At this point, if there's no body remaining, then only the encoding 220 // type was specified, meaning that there're no further bytes to be 221 // parsed. 222 if len(queryBody) == 0 { 223 return encodingType, nil, nil 224 } 225 226 // Before we start to decode, we'll create a limit reader over 227 // the current reader. This will ensure that we can control how 228 // much memory we're allocating during the decoding process. 229 limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{ 230 R: bytes.NewReader(queryBody), 231 N: maxZlibBufSize, 232 }) 233 if err != nil { 234 return 0, nil, fmt.Errorf("unable to create zlib reader: %v", err) 235 } 236 237 var ( 238 shortChanIDs []ShortChannelID 239 lastChanID ShortChannelID 240 i int 241 ) 242 for { 243 // We'll now attempt to read the next short channel ID 244 // encoded in the payload. 245 var cid ShortChannelID 246 err := ReadElements(limitedDecompressor, &cid) 247 248 switch { 249 // If we get an EOF error, then that either means we've 250 // read all that's contained in the buffer, or have hit 251 // our limit on the number of bytes we'll read. In 252 // either case, we'll return what we have so far. 253 case err == io.ErrUnexpectedEOF || err == io.EOF: 254 return encodingType, shortChanIDs, nil 255 256 // Otherwise, we hit some other sort of error, possibly 257 // an invalid payload, so we'll exit early with the 258 // error. 259 case err != nil: 260 return 0, nil, fmt.Errorf("unable to "+ 261 "deflate next short chan "+ 262 "ID: %v", err) 263 } 264 265 // We successfully read the next ID, so we'll collect 266 // that in the set of final ID's to return. 267 shortChanIDs = append(shortChanIDs, cid) 268 269 // Finally, we'll ensure that this short chan ID is 270 // greater than the last one. This is a requirement 271 // within the encoding, and if violated can aide us in 272 // detecting malicious payloads. This can only be true 273 // starting at the second chanID. 274 if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() { 275 return 0, nil, ErrUnsortedSIDs{lastChanID, cid} 276 } 277 278 lastChanID = cid 279 i++ 280 } 281 282 default: 283 // If we've been sent an encoding type that we don't know of, 284 // then we'll return a parsing error as we can't continue if 285 // we're unable to encode them. 286 return 0, nil, ErrUnknownShortChanIDEncoding(encodingType) 287 } 288 } 289 290 // Encode serializes the target QueryShortChanIDs into the passed io.Writer 291 // observing the protocol version specified. 292 // 293 // This is part of the lnwire.Message interface. 294 func (q *QueryShortChanIDs) Encode(w *bytes.Buffer, pver uint32) error { 295 // First, we'll write out the chain hash. 296 if err := WriteBytes(w, q.ChainHash[:]); err != nil { 297 return err 298 } 299 300 // Base on our encoding type, we'll write out the set of short channel 301 // ID's. 302 err := encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort) 303 if err != nil { 304 return err 305 } 306 307 return WriteBytes(w, q.ExtraData) 308 } 309 310 // encodeShortChanIDs encodes the passed short channel ID's into the passed 311 // io.Writer, respecting the specified encoding type. 312 func encodeShortChanIDs(w *bytes.Buffer, encodingType ShortChanIDEncoding, 313 shortChanIDs []ShortChannelID, noSort bool) error { 314 315 // For both of the current encoding types, the channel ID's are to be 316 // sorted in place, so we'll do that now. The sorting is applied unless 317 // we were specifically requested not to for testing purposes. 318 if !noSort { 319 sort.Slice(shortChanIDs, func(i, j int) bool { 320 return shortChanIDs[i].ToUint64() < 321 shortChanIDs[j].ToUint64() 322 }) 323 } 324 325 switch encodingType { 326 327 // In this encoding, we'll simply write a sorted array of encoded short 328 // channel ID's from the buffer. 329 case EncodingSortedPlain: 330 // First, we'll write out the number of bytes of the query 331 // body. We add 1 as the response will have the encoding type 332 // prepended to it. 333 numBytesBody := uint16(len(shortChanIDs)*8) + 1 334 if err := WriteUint16(w, numBytesBody); err != nil { 335 return err 336 } 337 338 // We'll then write out the encoding that that follows the 339 // actual encoded short channel ID's. 340 err := WriteShortChanIDEncoding(w, encodingType) 341 if err != nil { 342 return err 343 } 344 345 // Now that we know they're sorted, we can write out each short 346 // channel ID to the buffer. 347 for _, chanID := range shortChanIDs { 348 if err := WriteShortChannelID(w, chanID); err != nil { 349 return fmt.Errorf("unable to write short chan "+ 350 "ID: %v", err) 351 } 352 } 353 354 return nil 355 356 // For this encoding we'll first write out a serialized version of all 357 // the channel ID's into a buffer, then zlib encode that. The final 358 // payload is what we'll write out to the passed io.Writer. 359 // 360 // TODO(roasbeef): assumes the caller knows the proper chunk size to 361 // pass to avoid bin-packing here 362 case EncodingSortedZlib: 363 // If we don't have anything at all to write, then we'll write 364 // an empty payload so we don't include things like the zlib 365 // header when the remote party is expecting no actual short 366 // channel IDs. 367 var compressedPayload []byte 368 if len(shortChanIDs) > 0 { 369 // We'll make a new write buffer to hold the bytes of 370 // shortChanIDs. 371 var wb bytes.Buffer 372 373 // Next, we'll write out all the channel ID's directly 374 // into the zlib writer, which will do compressing on 375 // the fly. 376 for _, chanID := range shortChanIDs { 377 err := WriteShortChannelID(&wb, chanID) 378 if err != nil { 379 return fmt.Errorf( 380 "unable to write short chan "+ 381 "ID: %v", err, 382 ) 383 } 384 } 385 386 // With shortChanIDs written into wb, we'll create a 387 // zlib writer and write all the compressed bytes. 388 var zlibBuffer bytes.Buffer 389 zlibWriter := zlib.NewWriter(&zlibBuffer) 390 391 if _, err := zlibWriter.Write(wb.Bytes()); err != nil { 392 return fmt.Errorf( 393 "unable to write compressed short chan"+ 394 "ID: %w", err) 395 } 396 397 // Now that we've written all the elements, we'll 398 // ensure the compressed stream is written to the 399 // underlying buffer. 400 if err := zlibWriter.Close(); err != nil { 401 return fmt.Errorf("unable to finalize "+ 402 "compression: %v", err) 403 } 404 405 compressedPayload = zlibBuffer.Bytes() 406 } 407 408 // Now that we have all the items compressed, we can compute 409 // what the total payload size will be. We add one to account 410 // for the byte to encode the type. 411 // 412 // If we don't have any actual bytes to write, then we'll end 413 // up emitting one byte for the length, followed by the 414 // encoding type, and nothing more. The spec isn't 100% clear 415 // in this area, but we do this as this is what most of the 416 // other implementations do. 417 numBytesBody := len(compressedPayload) + 1 418 419 // Finally, we can write out the number of bytes, the 420 // compression type, and finally the buffer itself. 421 if err := WriteUint16(w, uint16(numBytesBody)); err != nil { 422 return err 423 } 424 err := WriteShortChanIDEncoding(w, encodingType) 425 if err != nil { 426 return err 427 } 428 429 return WriteBytes(w, compressedPayload) 430 431 default: 432 // If we're trying to encode with an encoding type that we 433 // don't know of, then we'll return a parsing error as we can't 434 // continue if we're unable to encode them. 435 return ErrUnknownShortChanIDEncoding(encodingType) 436 } 437 } 438 439 // MsgType returns the integer uniquely identifying this message type on the 440 // wire. 441 // 442 // This is part of the lnwire.Message interface. 443 func (q *QueryShortChanIDs) MsgType() MessageType { 444 return MsgQueryShortChanIDs 445 }