github.com/anacrolix/torrent@v1.61.0/peer_protocol/msg.go (about)

     1  package peer_protocol
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"encoding"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"io"
    10  )
    11  
    12  // This is a lazy union representing all the possible fields for messages. Go doesn't have ADTs, and
    13  // I didn't choose to use type-assertions. Fields are ordered to minimize struct size and padding.
    14  type Message struct {
    15  	PiecesRoot           [32]byte
    16  	Piece                []byte
    17  	Bitfield             []bool
    18  	ExtendedPayload      []byte
    19  	Hashes               [][32]byte
    20  	Index, Begin, Length Integer
    21  	BaseLayer            Integer
    22  	ProofLayers          Integer
    23  	Port                 uint16
    24  	Type                 MessageType
    25  	ExtendedID           ExtensionNumber
    26  	Keepalive            bool
    27  }
    28  
    29  var _ interface {
    30  	encoding.BinaryUnmarshaler
    31  	encoding.BinaryMarshaler
    32  } = (*Message)(nil)
    33  
    34  func MakeCancelMessage(piece, offset, length Integer) Message {
    35  	return Message{
    36  		Type:   Cancel,
    37  		Index:  piece,
    38  		Begin:  offset,
    39  		Length: length,
    40  	}
    41  }
    42  
    43  func (msg Message) RequestSpec() (ret RequestSpec) {
    44  	return RequestSpec{
    45  		msg.Index,
    46  		msg.Begin,
    47  		func() Integer {
    48  			if msg.Type == Piece {
    49  				return Integer(len(msg.Piece))
    50  			} else {
    51  				return msg.Length
    52  			}
    53  		}(),
    54  	}
    55  }
    56  
    57  func (msg Message) MustMarshalBinary() []byte {
    58  	b, err := msg.MarshalBinary()
    59  	if err != nil {
    60  		panic(err)
    61  	}
    62  	return b
    63  }
    64  
    65  type MessageWriter interface {
    66  	io.ByteWriter
    67  	io.Writer
    68  }
    69  
    70  func (msg *Message) writeHashCommon(buf MessageWriter) (err error) {
    71  	if _, err = buf.Write(msg.PiecesRoot[:]); err != nil {
    72  		return
    73  	}
    74  	for _, d := range []Integer{msg.BaseLayer, msg.Index, msg.Length, msg.ProofLayers} {
    75  		if err = binary.Write(buf, binary.BigEndian, d); err != nil {
    76  			return
    77  		}
    78  	}
    79  	return nil
    80  }
    81  
    82  func (msg *Message) writePayloadTo(buf MessageWriter) (err error) {
    83  	if !msg.Keepalive {
    84  		err = buf.WriteByte(byte(msg.Type))
    85  		if err != nil {
    86  			return
    87  		}
    88  		switch msg.Type {
    89  		case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone:
    90  		case Have, AllowedFast, Suggest:
    91  			err = binary.Write(buf, binary.BigEndian, msg.Index)
    92  		case Request, Cancel, Reject:
    93  			for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} {
    94  				err = binary.Write(buf, binary.BigEndian, i)
    95  				if err != nil {
    96  					break
    97  				}
    98  			}
    99  		case Bitfield:
   100  			_, err = buf.Write(marshalBitfield(msg.Bitfield))
   101  		case Piece:
   102  			for _, i := range []Integer{msg.Index, msg.Begin} {
   103  				err = binary.Write(buf, binary.BigEndian, i)
   104  				if err != nil {
   105  					return
   106  				}
   107  			}
   108  			n, err := buf.Write(msg.Piece)
   109  			if err != nil {
   110  				break
   111  			}
   112  			if n != len(msg.Piece) {
   113  				panic(n)
   114  			}
   115  		case Extended:
   116  			err = buf.WriteByte(byte(msg.ExtendedID))
   117  			if err != nil {
   118  				return
   119  			}
   120  			_, err = buf.Write(msg.ExtendedPayload)
   121  		case Port:
   122  			err = binary.Write(buf, binary.BigEndian, msg.Port)
   123  		case HashRequest, HashReject:
   124  			err = msg.writeHashCommon(buf)
   125  		case Hashes:
   126  			err = msg.writeHashCommon(buf)
   127  			if err != nil {
   128  				return
   129  			}
   130  			for _, h := range msg.Hashes {
   131  				if _, err = buf.Write(h[:]); err != nil {
   132  					return
   133  				}
   134  			}
   135  		default:
   136  			err = fmt.Errorf("unknown message type: %v", msg.Type)
   137  		}
   138  	}
   139  	return
   140  }
   141  
   142  func (msg *Message) WriteTo(w MessageWriter) (err error) {
   143  	length, err := msg.getPayloadLength()
   144  	if err != nil {
   145  		return
   146  	}
   147  	err = binary.Write(w, binary.BigEndian, length)
   148  	if err != nil {
   149  		return
   150  	}
   151  	return msg.writePayloadTo(w)
   152  }
   153  
   154  func (msg *Message) getPayloadLength() (length Integer, err error) {
   155  	var lw lengthWriter
   156  	err = msg.writePayloadTo(&lw)
   157  	length = lw.n
   158  	return
   159  }
   160  
   161  func (msg Message) MarshalBinary() (data []byte, err error) {
   162  	// It might look like you could have a pool of buffers and preallocate the message length
   163  	// prefix, but because we have to return []byte, it becomes non-trivial to make this fast. You
   164  	// will need a benchmark.
   165  	var buf bytes.Buffer
   166  	err = msg.WriteTo(&buf)
   167  	data = buf.Bytes()
   168  	return
   169  }
   170  
   171  func marshalBitfield(bf []bool) (b []byte) {
   172  	b = make([]byte, (len(bf)+7)/8)
   173  	for i, have := range bf {
   174  		if !have {
   175  			continue
   176  		}
   177  		c := b[i/8]
   178  		c |= 1 << uint(7-i%8)
   179  		b[i/8] = c
   180  	}
   181  	return
   182  }
   183  
   184  func (me *Message) UnmarshalBinary(b []byte) error {
   185  	d := Decoder{
   186  		R: bufio.NewReader(bytes.NewReader(b)),
   187  	}
   188  	err := d.Decode(me)
   189  	if err != nil {
   190  		return err
   191  	}
   192  	if d.R.Buffered() != 0 {
   193  		return fmt.Errorf("%d trailing bytes", d.R.Buffered())
   194  	}
   195  	return nil
   196  }
   197  
   198  type lengthWriter struct {
   199  	n Integer
   200  }
   201  
   202  func (l *lengthWriter) WriteByte(c byte) error {
   203  	l.n++
   204  	return nil
   205  }
   206  
   207  func (l *lengthWriter) Write(p []byte) (n int, err error) {
   208  	n = len(p)
   209  	l.n += Integer(n)
   210  	return
   211  }