github.com/anacrolix/torrent@v1.61.0/peer_protocol/decoder.go (about) 1 package peer_protocol 2 3 import ( 4 "bufio" 5 "encoding/binary" 6 "fmt" 7 "io" 8 "sync" 9 10 g "github.com/anacrolix/generics" 11 "github.com/pkg/errors" 12 ) 13 14 type Decoder struct { 15 R *bufio.Reader 16 // This must return *[]byte where the slices can fit data for piece messages. I think we store 17 // *[]byte in the pool to avoid an extra allocation every time we put the slice back into the 18 // pool. The chunk size should not change for the life of the decoder. 19 Pool *sync.Pool 20 MaxLength Integer // TODO: Should this include the length header or not? 21 } 22 23 // This limits reads to the length of a message, returning io.EOF when the end of the message bytes 24 // are reached. If you aren't expecting io.EOF, you should probably wrap it with expectReader. 25 type decodeReader struct { 26 lr io.LimitedReader 27 br *bufio.Reader 28 } 29 30 func (dr *decodeReader) Init(r *bufio.Reader, length int64) { 31 dr.lr.R = r 32 dr.lr.N = length 33 dr.br = r 34 } 35 36 func (dr *decodeReader) ReadByte() (c byte, err error) { 37 if dr.lr.N <= 0 { 38 err = io.EOF 39 return 40 } 41 c, err = dr.br.ReadByte() 42 if err == nil { 43 dr.lr.N-- 44 } 45 return 46 } 47 48 func (dr *decodeReader) Read(p []byte) (n int, err error) { 49 n, err = dr.lr.Read(p) 50 if dr.lr.N != 0 && err == io.EOF { 51 err = io.ErrUnexpectedEOF 52 } 53 return 54 } 55 56 func (dr *decodeReader) UnreadLength() int64 { 57 return dr.lr.N 58 } 59 60 // This expects reads to have enough bytes. io.EOF is mapped to io.ErrUnexpectedEOF. It's probably 61 // not a good idea to pass this to functions that expect to read until the end of something, because 62 // they will probably expect io.EOF. 63 type expectReader struct { 64 dr *decodeReader 65 } 66 67 func (er expectReader) ReadByte() (c byte, err error) { 68 c, err = er.dr.ReadByte() 69 if err == io.EOF { 70 err = io.ErrUnexpectedEOF 71 } 72 return 73 } 74 75 func (er expectReader) Read(p []byte) (n int, err error) { 76 n, err = er.dr.Read(p) 77 if err == io.EOF { 78 err = io.ErrUnexpectedEOF 79 } 80 return 81 } 82 83 func (er expectReader) UnreadLength() int64 { 84 return er.dr.UnreadLength() 85 } 86 87 // io.EOF is returned if the source terminates cleanly on a message boundary. TODO: Raise error 88 // level for protocol errors, log them, or add an error type. 89 func (d *Decoder) Decode(msg *Message) (err error) { 90 var dr decodeReader 91 { 92 var length Integer 93 err = length.Read(d.R) 94 if err != nil { 95 return fmt.Errorf("reading message length: %w", err) 96 } 97 if length > d.MaxLength { 98 return errors.New("message too long") 99 } 100 if length == 0 { 101 msg.Keepalive = true 102 return 103 } 104 dr.Init(d.R, int64(length)) 105 } 106 r := expectReader{&dr} 107 c, err := r.ReadByte() 108 if err != nil { 109 return 110 } 111 msg.Type = MessageType(c) 112 err = readMessageAfterType(msg, &r, d.Pool) 113 if err != nil { 114 err = fmt.Errorf("reading fields for message type %v: %w", msg.Type, err) 115 return 116 } 117 if r.UnreadLength() != 0 { 118 err = fmt.Errorf("%v unused bytes in message type %v", r.UnreadLength(), msg.Type) 119 } 120 return 121 } 122 123 func readMessageAfterType(msg *Message, r *expectReader, piecePool *sync.Pool) (err error) { 124 switch msg.Type { 125 case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone: 126 case Have, AllowedFast, Suggest: 127 err = msg.Index.Read(r) 128 case Request, Cancel, Reject: 129 for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} { 130 err = data.Read(r) 131 if err != nil { 132 break 133 } 134 } 135 case Bitfield: 136 b := make([]byte, r.UnreadLength()) 137 _, err = io.ReadFull(r, b) 138 msg.Bitfield = unmarshalBitfield(b) 139 case Piece: 140 for _, pi := range []*Integer{&msg.Index, &msg.Begin} { 141 err = pi.Read(r) 142 if err != nil { 143 return 144 } 145 } 146 dataLen := r.UnreadLength() 147 if piecePool == nil { 148 msg.Piece = make([]byte, dataLen) 149 } else { 150 msg.Piece = *piecePool.Get().(*[]byte) 151 if int64(cap(msg.Piece)) < dataLen { 152 return errors.New("piece data longer than expected") 153 } 154 msg.Piece = msg.Piece[:dataLen] 155 } 156 _, err = io.ReadFull(r, msg.Piece) 157 case Extended: 158 var b byte 159 b, err = r.ReadByte() 160 if err != nil { 161 break 162 } 163 msg.ExtendedID = ExtensionNumber(b) 164 msg.ExtendedPayload = make([]byte, r.UnreadLength()) 165 _, err = io.ReadFull(r, msg.ExtendedPayload) 166 case Port: 167 err = binary.Read(r, binary.BigEndian, &msg.Port) 168 case HashRequest, HashReject: 169 err = readHashRequest(r, msg) 170 case Hashes: 171 err = readHashRequest(r, msg) 172 numHashes := (r.UnreadLength() + 31) / 32 173 g.MakeSliceWithCap(&msg.Hashes, numHashes) 174 for range numHashes { 175 var oneHash [32]byte 176 _, err = io.ReadFull(r, oneHash[:]) 177 if err != nil { 178 err = fmt.Errorf("error while reading hashes: %w", err) 179 return 180 } 181 msg.Hashes = append(msg.Hashes, oneHash) 182 } 183 default: 184 err = errors.New("unhandled message type") 185 } 186 return 187 } 188 189 func readHashRequest(r io.Reader, msg *Message) (err error) { 190 _, err = io.ReadFull(r, msg.PiecesRoot[:]) 191 if err != nil { 192 return 193 } 194 return readSeq(r, &msg.BaseLayer, &msg.Index, &msg.Length, &msg.ProofLayers) 195 } 196 197 func readSeq(r io.Reader, data ...any) (err error) { 198 for _, d := range data { 199 err = binary.Read(r, binary.BigEndian, d) 200 if err != nil { 201 return 202 } 203 } 204 return 205 } 206 207 func unmarshalBitfield(b []byte) (bf []bool) { 208 for _, c := range b { 209 for i := 7; i >= 0; i-- { 210 bf = append(bf, (c>>uint(i))&1 == 1) 211 } 212 } 213 return 214 }