github.com/anacrolix/torrent@v1.61.0/peer_protocol/decoder.go (about)

     1  package peer_protocol
     2  
     3  import (
     4  	"bufio"
     5  	"encoding/binary"
     6  	"fmt"
     7  	"io"
     8  	"sync"
     9  
    10  	g "github.com/anacrolix/generics"
    11  	"github.com/pkg/errors"
    12  )
    13  
// Decoder reads length-prefixed peer protocol messages from R, one per Decode call.
type Decoder struct {
	// R is the source of encoded messages.
	R *bufio.Reader
	// Pool optionally supplies buffers for Piece message payloads. When non-nil, Get must return
	// a *[]byte whose capacity can hold a whole piece payload. A smaller buffer makes decoding
	// fail with "piece data longer than expected". Storing *[]byte rather than []byte presumably
	// avoids an extra allocation every time a slice is put back into the pool. When Pool is nil,
	// a fresh slice is allocated for each Piece message. The chunk size should not change for
	// the life of the decoder.
	Pool *sync.Pool
	// MaxLength is the largest accepted value of a message's length prefix. Decode compares it
	// against the decoded length, which counts only the bytes after the length header. Longer
	// messages are rejected with "message too long". TODO: decide whether the header should be
	// included.
	MaxLength Integer
}
    22  
// decodeReader limits reads to the remaining bytes of a single message. Once all of them have
// been consumed it returns io.EOF. If the underlying reader ends before the message does, Read
// reports io.ErrUnexpectedEOF instead. If you aren't expecting io.EOF at all, you should probably
// wrap it with expectReader.
type decodeReader struct {
	// lr bounds Read calls. lr.N is the number of message bytes still unread.
	lr io.LimitedReader
	// br is the same reader as lr.R, kept with its concrete type so ReadByte can call it
	// directly. ReadByte decrements lr.N itself to keep the count in sync.
	br *bufio.Reader
}
    29  
    30  func (dr *decodeReader) Init(r *bufio.Reader, length int64) {
    31  	dr.lr.R = r
    32  	dr.lr.N = length
    33  	dr.br = r
    34  }
    35  
    36  func (dr *decodeReader) ReadByte() (c byte, err error) {
    37  	if dr.lr.N <= 0 {
    38  		err = io.EOF
    39  		return
    40  	}
    41  	c, err = dr.br.ReadByte()
    42  	if err == nil {
    43  		dr.lr.N--
    44  	}
    45  	return
    46  }
    47  
    48  func (dr *decodeReader) Read(p []byte) (n int, err error) {
    49  	n, err = dr.lr.Read(p)
    50  	if dr.lr.N != 0 && err == io.EOF {
    51  		err = io.ErrUnexpectedEOF
    52  	}
    53  	return
    54  }
    55  
    56  func (dr *decodeReader) UnreadLength() int64 {
    57  	return dr.lr.N
    58  }
    59  
// expectReader wraps a decodeReader for fields that must be fully present. Every io.EOF from the
// wrapped decodeReader is mapped to io.ErrUnexpectedEOF, so running out of message bytes is
// reported as truncation. It's probably not a good idea to pass this to functions that expect to
// read until the end of something, because they will probably expect io.EOF.
type expectReader struct {
	// dr is the per-message reader whose results are translated.
	dr *decodeReader
}
    66  
    67  func (er expectReader) ReadByte() (c byte, err error) {
    68  	c, err = er.dr.ReadByte()
    69  	if err == io.EOF {
    70  		err = io.ErrUnexpectedEOF
    71  	}
    72  	return
    73  }
    74  
    75  func (er expectReader) Read(p []byte) (n int, err error) {
    76  	n, err = er.dr.Read(p)
    77  	if err == io.EOF {
    78  		err = io.ErrUnexpectedEOF
    79  	}
    80  	return
    81  }
    82  
    83  func (er expectReader) UnreadLength() int64 {
    84  	return er.dr.UnreadLength()
    85  }
    86  
    87  // io.EOF is returned if the source terminates cleanly on a message boundary. TODO: Raise error
    88  // level for protocol errors, log them, or add an error type.
    89  func (d *Decoder) Decode(msg *Message) (err error) {
    90  	var dr decodeReader
    91  	{
    92  		var length Integer
    93  		err = length.Read(d.R)
    94  		if err != nil {
    95  			return fmt.Errorf("reading message length: %w", err)
    96  		}
    97  		if length > d.MaxLength {
    98  			return errors.New("message too long")
    99  		}
   100  		if length == 0 {
   101  			msg.Keepalive = true
   102  			return
   103  		}
   104  		dr.Init(d.R, int64(length))
   105  	}
   106  	r := expectReader{&dr}
   107  	c, err := r.ReadByte()
   108  	if err != nil {
   109  		return
   110  	}
   111  	msg.Type = MessageType(c)
   112  	err = readMessageAfterType(msg, &r, d.Pool)
   113  	if err != nil {
   114  		err = fmt.Errorf("reading fields for message type %v: %w", msg.Type, err)
   115  		return
   116  	}
   117  	if r.UnreadLength() != 0 {
   118  		err = fmt.Errorf("%v unused bytes in message type %v", r.UnreadLength(), msg.Type)
   119  	}
   120  	return
   121  }
   122  
   123  func readMessageAfterType(msg *Message, r *expectReader, piecePool *sync.Pool) (err error) {
   124  	switch msg.Type {
   125  	case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
   126  	case Have, AllowedFast, Suggest:
   127  		err = msg.Index.Read(r)
   128  	case Request, Cancel, Reject:
   129  		for _, data := range []*Integer{&msg.Index, &msg.Begin, &msg.Length} {
   130  			err = data.Read(r)
   131  			if err != nil {
   132  				break
   133  			}
   134  		}
   135  	case Bitfield:
   136  		b := make([]byte, r.UnreadLength())
   137  		_, err = io.ReadFull(r, b)
   138  		msg.Bitfield = unmarshalBitfield(b)
   139  	case Piece:
   140  		for _, pi := range []*Integer{&msg.Index, &msg.Begin} {
   141  			err = pi.Read(r)
   142  			if err != nil {
   143  				return
   144  			}
   145  		}
   146  		dataLen := r.UnreadLength()
   147  		if piecePool == nil {
   148  			msg.Piece = make([]byte, dataLen)
   149  		} else {
   150  			msg.Piece = *piecePool.Get().(*[]byte)
   151  			if int64(cap(msg.Piece)) < dataLen {
   152  				return errors.New("piece data longer than expected")
   153  			}
   154  			msg.Piece = msg.Piece[:dataLen]
   155  		}
   156  		_, err = io.ReadFull(r, msg.Piece)
   157  	case Extended:
   158  		var b byte
   159  		b, err = r.ReadByte()
   160  		if err != nil {
   161  			break
   162  		}
   163  		msg.ExtendedID = ExtensionNumber(b)
   164  		msg.ExtendedPayload = make([]byte, r.UnreadLength())
   165  		_, err = io.ReadFull(r, msg.ExtendedPayload)
   166  	case Port:
   167  		err = binary.Read(r, binary.BigEndian, &msg.Port)
   168  	case HashRequest, HashReject:
   169  		err = readHashRequest(r, msg)
   170  	case Hashes:
   171  		err = readHashRequest(r, msg)
   172  		numHashes := (r.UnreadLength() + 31) / 32
   173  		g.MakeSliceWithCap(&msg.Hashes, numHashes)
   174  		for range numHashes {
   175  			var oneHash [32]byte
   176  			_, err = io.ReadFull(r, oneHash[:])
   177  			if err != nil {
   178  				err = fmt.Errorf("error while reading hashes: %w", err)
   179  				return
   180  			}
   181  			msg.Hashes = append(msg.Hashes, oneHash)
   182  		}
   183  	default:
   184  		err = errors.New("unhandled message type")
   185  	}
   186  	return
   187  }
   188  
   189  func readHashRequest(r io.Reader, msg *Message) (err error) {
   190  	_, err = io.ReadFull(r, msg.PiecesRoot[:])
   191  	if err != nil {
   192  		return
   193  	}
   194  	return readSeq(r, &msg.BaseLayer, &msg.Index, &msg.Length, &msg.ProofLayers)
   195  }
   196  
   197  func readSeq(r io.Reader, data ...any) (err error) {
   198  	for _, d := range data {
   199  		err = binary.Read(r, binary.BigEndian, d)
   200  		if err != nil {
   201  			return
   202  		}
   203  	}
   204  	return
   205  }
   206  
   207  func unmarshalBitfield(b []byte) (bf []bool) {
   208  	for _, c := range b {
   209  		for i := 7; i >= 0; i-- {
   210  			bf = append(bf, (c>>uint(i))&1 == 1)
   211  		}
   212  	}
   213  	return
   214  }