github.com/decred/dcrlnd@v0.7.6/channeldb/migration/lnwire21/query_short_chan_ids.go (about) 1 package lnwire 2 3 import ( 4 "bytes" 5 "compress/zlib" 6 "fmt" 7 "io" 8 "sort" 9 "sync" 10 11 "github.com/decred/dcrd/chaincfg/chainhash" 12 ) 13 14 // ShortChanIDEncoding is an enum-like type that represents exactly how a set 15 // of short channel ID's is encoded on the wire. The set of encodings allows us 16 // to take advantage of the structure of a list of short channel ID's to 17 // achieving a high degree of compression. 18 type ShortChanIDEncoding uint8 19 20 const ( 21 // EncodingSortedPlain signals that the set of short channel ID's is 22 // encoded using the regular encoding, in a sorted order. 23 EncodingSortedPlain ShortChanIDEncoding = 0 24 25 // EncodingSortedZlib signals that the set of short channel ID's is 26 // encoded by first sorting the set of channel ID's, as then 27 // compressing them using zlib. 28 EncodingSortedZlib ShortChanIDEncoding = 1 29 ) 30 31 const ( 32 // maxZlibBufSize is the max number of bytes that we'll accept from a 33 // zlib decoding instance. We do this in order to limit the total 34 // amount of memory allocated during a decoding instance. 35 maxZlibBufSize = 67413630 36 ) 37 38 // ErrUnsortedSIDs is returned when decoding a QueryShortChannelID request whose 39 // items were not sorted. 40 type ErrUnsortedSIDs struct { 41 prevSID ShortChannelID 42 curSID ShortChannelID 43 } 44 45 // Error returns a human-readable description of the error. 46 func (e ErrUnsortedSIDs) Error() string { 47 return fmt.Sprintf("current sid: %v isn't greater than last sid: %v", 48 e.curSID, e.prevSID) 49 } 50 51 // zlibDecodeMtx is a package level mutex that we'll use in order to ensure 52 // that we'll only attempt a single zlib decoding instance at a time. This 53 // allows us to also further bound our memory usage. 54 var zlibDecodeMtx sync.Mutex 55 56 // ErrUnknownShortChanIDEncoding is a parametrized error that indicates that we 57 // came across an unknown short channel ID encoding, and therefore were unable 58 // to continue parsing. 59 func ErrUnknownShortChanIDEncoding(encoding ShortChanIDEncoding) error { 60 return fmt.Errorf("unknown short chan id encoding: %v", encoding) 61 } 62 63 // QueryShortChanIDs is a message that allows the sender to query a set of 64 // channel announcement and channel update messages that correspond to the set 65 // of encoded short channel ID's. The encoding of the short channel ID's is 66 // detailed in the query message ensuring that the receiver knows how to 67 // properly decode each encode short channel ID which may be encoded using a 68 // compression format. The receiver should respond with a series of channel 69 // announcement and channel updates, finally sending a ReplyShortChanIDsEnd 70 // message. 71 type QueryShortChanIDs struct { 72 // ChainHash denotes the target chain that we're querying for the 73 // channel ID's of. 74 ChainHash chainhash.Hash 75 76 // EncodingType is a signal to the receiver of the message that 77 // indicates exactly how the set of short channel ID's that follow have 78 // been encoded. 79 EncodingType ShortChanIDEncoding 80 81 // ShortChanIDs is a slice of decoded short channel ID's. 82 ShortChanIDs []ShortChannelID 83 84 // noSort indicates whether or not to sort the short channel ids before 85 // writing them out. 86 // 87 // NOTE: This should only be used during testing. 88 noSort bool 89 } 90 91 // NewQueryShortChanIDs creates a new QueryShortChanIDs message. 92 func NewQueryShortChanIDs(h chainhash.Hash, e ShortChanIDEncoding, 93 s []ShortChannelID) *QueryShortChanIDs { 94 95 return &QueryShortChanIDs{ 96 ChainHash: h, 97 EncodingType: e, 98 ShortChanIDs: s, 99 } 100 } 101 102 // A compile time check to ensure QueryShortChanIDs implements the 103 // lnwire.Message interface. 104 var _ Message = (*QueryShortChanIDs)(nil) 105 106 // Decode deserializes a serialized QueryShortChanIDs message stored in the 107 // passed io.Reader observing the specified protocol version. 108 // 109 // This is part of the lnwire.Message interface. 110 func (q *QueryShortChanIDs) Decode(r io.Reader, pver uint32) error { 111 err := ReadElements(r, q.ChainHash[:]) 112 if err != nil { 113 return err 114 } 115 116 q.EncodingType, q.ShortChanIDs, err = decodeShortChanIDs(r) 117 118 return err 119 } 120 121 // decodeShortChanIDs decodes a set of short channel ID's that have been 122 // encoded. The first byte of the body details how the short chan ID's were 123 // encoded. We'll use this type to govern exactly how we go about encoding the 124 // set of short channel ID's. 125 func decodeShortChanIDs(r io.Reader) (ShortChanIDEncoding, []ShortChannelID, error) { 126 // First, we'll attempt to read the number of bytes in the body of the 127 // set of encoded short channel ID's. 128 var numBytesResp uint16 129 err := ReadElements(r, &numBytesResp) 130 if err != nil { 131 return 0, nil, err 132 } 133 134 if numBytesResp == 0 { 135 return 0, nil, nil 136 } 137 138 queryBody := make([]byte, numBytesResp) 139 if _, err := io.ReadFull(r, queryBody); err != nil { 140 return 0, nil, err 141 } 142 143 // The first byte is the encoding type, so we'll extract that so we can 144 // continue our parsing. 145 encodingType := ShortChanIDEncoding(queryBody[0]) 146 147 // Before continuing, we'll snip off the first byte of the query body 148 // as that was just the encoding type. 149 queryBody = queryBody[1:] 150 151 // Otherwise, depending on the encoding type, we'll decode the encode 152 // short channel ID's in a different manner. 153 switch encodingType { 154 155 // In this encoding, we'll simply read a sort array of encoded short 156 // channel ID's from the buffer. 157 case EncodingSortedPlain: 158 // If after extracting the encoding type, the number of 159 // remaining bytes is not a whole multiple of the size of an 160 // encoded short channel ID (8 bytes), then we'll return a 161 // parsing error. 162 if len(queryBody)%8 != 0 { 163 return 0, nil, fmt.Errorf("whole number of short "+ 164 "chan ID's cannot be encoded in len=%v", 165 len(queryBody)) 166 } 167 168 // As each short channel ID is encoded as 8 bytes, we can 169 // compute the number of bytes encoded based on the size of the 170 // query body. 171 numShortChanIDs := len(queryBody) / 8 172 if numShortChanIDs == 0 { 173 return encodingType, nil, nil 174 } 175 176 // Finally, we'll read out the exact number of short channel 177 // ID's to conclude our parsing. 178 shortChanIDs := make([]ShortChannelID, numShortChanIDs) 179 bodyReader := bytes.NewReader(queryBody) 180 var lastChanID ShortChannelID 181 for i := 0; i < numShortChanIDs; i++ { 182 if err := ReadElements(bodyReader, &shortChanIDs[i]); err != nil { 183 return 0, nil, fmt.Errorf("unable to parse "+ 184 "short chan ID: %v", err) 185 } 186 187 // We'll ensure that this short chan ID is greater than 188 // the last one. This is a requirement within the 189 // encoding, and if violated can aide us in detecting 190 // malicious payloads. This can only be true starting 191 // at the second chanID. 192 cid := shortChanIDs[i] 193 if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() { 194 return 0, nil, ErrUnsortedSIDs{lastChanID, cid} 195 } 196 lastChanID = cid 197 } 198 199 return encodingType, shortChanIDs, nil 200 201 // In this encoding, we'll use zlib to decode the compressed payload. 202 // However, we'll pay attention to ensure that we don't open our selves 203 // up to a memory exhaustion attack. 204 case EncodingSortedZlib: 205 // We'll obtain an ultimately release the zlib decode mutex. 206 // This guards us against allocating too much memory to decode 207 // each instance from concurrent peers. 208 zlibDecodeMtx.Lock() 209 defer zlibDecodeMtx.Unlock() 210 211 // At this point, if there's no body remaining, then only the encoding 212 // type was specified, meaning that there're no further bytes to be 213 // parsed. 214 if len(queryBody) == 0 { 215 return encodingType, nil, nil 216 } 217 218 // Before we start to decode, we'll create a limit reader over 219 // the current reader. This will ensure that we can control how 220 // much memory we're allocating during the decoding process. 221 limitedDecompressor, err := zlib.NewReader(&io.LimitedReader{ 222 R: bytes.NewReader(queryBody), 223 N: maxZlibBufSize, 224 }) 225 if err != nil { 226 return 0, nil, fmt.Errorf("unable to create zlib reader: %v", err) 227 } 228 229 var ( 230 shortChanIDs []ShortChannelID 231 lastChanID ShortChannelID 232 i int 233 ) 234 for { 235 // We'll now attempt to read the next short channel ID 236 // encoded in the payload. 237 var cid ShortChannelID 238 err := ReadElements(limitedDecompressor, &cid) 239 240 switch { 241 // If we get an EOF error, then that either means we've 242 // read all that's contained in the buffer, or have hit 243 // our limit on the number of bytes we'll read. In 244 // either case, we'll return what we have so far. 245 case err == io.ErrUnexpectedEOF || err == io.EOF: 246 return encodingType, shortChanIDs, nil 247 248 // Otherwise, we hit some other sort of error, possibly 249 // an invalid payload, so we'll exit early with the 250 // error. 251 case err != nil: 252 return 0, nil, fmt.Errorf("unable to "+ 253 "deflate next short chan "+ 254 "ID: %v", err) 255 } 256 257 // We successfully read the next ID, so we'll collect 258 // that in the set of final ID's to return. 259 shortChanIDs = append(shortChanIDs, cid) 260 261 // Finally, we'll ensure that this short chan ID is 262 // greater than the last one. This is a requirement 263 // within the encoding, and if violated can aide us in 264 // detecting malicious payloads. This can only be true 265 // starting at the second chanID. 266 if i > 0 && cid.ToUint64() <= lastChanID.ToUint64() { 267 return 0, nil, ErrUnsortedSIDs{lastChanID, cid} 268 } 269 270 lastChanID = cid 271 i++ 272 } 273 274 default: 275 // If we've been sent an encoding type that we don't know of, 276 // then we'll return a parsing error as we can't continue if 277 // we're unable to encode them. 278 return 0, nil, ErrUnknownShortChanIDEncoding(encodingType) 279 } 280 } 281 282 // Encode serializes the target QueryShortChanIDs into the passed io.Writer 283 // observing the protocol version specified. 284 // 285 // This is part of the lnwire.Message interface. 286 func (q *QueryShortChanIDs) Encode(w io.Writer, pver uint32) error { 287 // First, we'll write out the chain hash. 288 err := WriteElements(w, q.ChainHash[:]) 289 if err != nil { 290 return err 291 } 292 293 // Base on our encoding type, we'll write out the set of short channel 294 // ID's. 295 return encodeShortChanIDs(w, q.EncodingType, q.ShortChanIDs, q.noSort) 296 } 297 298 // encodeShortChanIDs encodes the passed short channel ID's into the passed 299 // io.Writer, respecting the specified encoding type. 300 func encodeShortChanIDs(w io.Writer, encodingType ShortChanIDEncoding, 301 shortChanIDs []ShortChannelID, noSort bool) error { 302 303 // For both of the current encoding types, the channel ID's are to be 304 // sorted in place, so we'll do that now. The sorting is applied unless 305 // we were specifically requested not to for testing purposes. 306 if !noSort { 307 sort.Slice(shortChanIDs, func(i, j int) bool { 308 return shortChanIDs[i].ToUint64() < 309 shortChanIDs[j].ToUint64() 310 }) 311 } 312 313 switch encodingType { 314 315 // In this encoding, we'll simply write a sorted array of encoded short 316 // channel ID's from the buffer. 317 case EncodingSortedPlain: 318 // First, we'll write out the number of bytes of the query 319 // body. We add 1 as the response will have the encoding type 320 // prepended to it. 321 numBytesBody := uint16(len(shortChanIDs)*8) + 1 322 if err := WriteElements(w, numBytesBody); err != nil { 323 return err 324 } 325 326 // We'll then write out the encoding that that follows the 327 // actual encoded short channel ID's. 328 if err := WriteElements(w, encodingType); err != nil { 329 return err 330 } 331 332 // Now that we know they're sorted, we can write out each short 333 // channel ID to the buffer. 334 for _, chanID := range shortChanIDs { 335 if err := WriteElements(w, chanID); err != nil { 336 return fmt.Errorf("unable to write short chan "+ 337 "ID: %v", err) 338 } 339 } 340 341 return nil 342 343 // For this encoding we'll first write out a serialized version of all 344 // the channel ID's into a buffer, then zlib encode that. The final 345 // payload is what we'll write out to the passed io.Writer. 346 // 347 // TODO(roasbeef): assumes the caller knows the proper chunk size to 348 // pass to avoid bin-packing here 349 case EncodingSortedZlib: 350 // We'll make a new buffer, then wrap that with a zlib writer 351 // so we can write directly to the buffer and encode in a 352 // streaming manner. 353 var buf bytes.Buffer 354 zlibWriter := zlib.NewWriter(&buf) 355 356 // If we don't have anything at all to write, then we'll write 357 // an empty payload so we don't include things like the zlib 358 // header when the remote party is expecting no actual short 359 // channel IDs. 360 var compressedPayload []byte 361 if len(shortChanIDs) > 0 { 362 // Next, we'll write out all the channel ID's directly 363 // into the zlib writer, which will do compressing on 364 // the fly. 365 for _, chanID := range shortChanIDs { 366 err := WriteElements(zlibWriter, chanID) 367 if err != nil { 368 return fmt.Errorf("unable to write short chan "+ 369 "ID: %v", err) 370 } 371 } 372 373 // Now that we've written all the elements, we'll 374 // ensure the compressed stream is written to the 375 // underlying buffer. 376 if err := zlibWriter.Close(); err != nil { 377 return fmt.Errorf("unable to finalize "+ 378 "compression: %v", err) 379 } 380 381 compressedPayload = buf.Bytes() 382 } 383 384 // Now that we have all the items compressed, we can compute 385 // what the total payload size will be. We add one to account 386 // for the byte to encode the type. 387 // 388 // If we don't have any actual bytes to write, then we'll end 389 // up emitting one byte for the length, followed by the 390 // encoding type, and nothing more. The spec isn't 100% clear 391 // in this area, but we do this as this is what most of the 392 // other implementations do. 393 numBytesBody := len(compressedPayload) + 1 394 395 // Finally, we can write out the number of bytes, the 396 // compression type, and finally the buffer itself. 397 if err := WriteElements(w, uint16(numBytesBody)); err != nil { 398 return err 399 } 400 if err := WriteElements(w, encodingType); err != nil { 401 return err 402 } 403 404 _, err := w.Write(compressedPayload) 405 return err 406 407 default: 408 // If we're trying to encode with an encoding type that we 409 // don't know of, then we'll return a parsing error as we can't 410 // continue if we're unable to encode them. 411 return ErrUnknownShortChanIDEncoding(encodingType) 412 } 413 } 414 415 // MsgType returns the integer uniquely identifying this message type on the 416 // wire. 417 // 418 // This is part of the lnwire.Message interface. 419 func (q *QueryShortChanIDs) MsgType() MessageType { 420 return MsgQueryShortChanIDs 421 } 422 423 // MaxPayloadLength returns the maximum allowed payload size for a 424 // QueryShortChanIDs complete message observing the specified protocol version. 425 // 426 // This is part of the lnwire.Message interface. 427 func (q *QueryShortChanIDs) MaxPayloadLength(uint32) uint32 { 428 return MaxMessagePayload 429 }