github.com/anacrolix/torrent@v1.61.0/peer_protocol/msg.go (about) 1 package peer_protocol 2 3 import ( 4 "bufio" 5 "bytes" 6 "encoding" 7 "encoding/binary" 8 "fmt" 9 "io" 10 ) 11 12 // This is a lazy union representing all the possible fields for messages. Go doesn't have ADTs, and 13 // I didn't choose to use type-assertions. Fields are ordered to minimize struct size and padding. 14 type Message struct { 15 PiecesRoot [32]byte 16 Piece []byte 17 Bitfield []bool 18 ExtendedPayload []byte 19 Hashes [][32]byte 20 Index, Begin, Length Integer 21 BaseLayer Integer 22 ProofLayers Integer 23 Port uint16 24 Type MessageType 25 ExtendedID ExtensionNumber 26 Keepalive bool 27 } 28 29 var _ interface { 30 encoding.BinaryUnmarshaler 31 encoding.BinaryMarshaler 32 } = (*Message)(nil) 33 34 func MakeCancelMessage(piece, offset, length Integer) Message { 35 return Message{ 36 Type: Cancel, 37 Index: piece, 38 Begin: offset, 39 Length: length, 40 } 41 } 42 43 func (msg Message) RequestSpec() (ret RequestSpec) { 44 return RequestSpec{ 45 msg.Index, 46 msg.Begin, 47 func() Integer { 48 if msg.Type == Piece { 49 return Integer(len(msg.Piece)) 50 } else { 51 return msg.Length 52 } 53 }(), 54 } 55 } 56 57 func (msg Message) MustMarshalBinary() []byte { 58 b, err := msg.MarshalBinary() 59 if err != nil { 60 panic(err) 61 } 62 return b 63 } 64 65 type MessageWriter interface { 66 io.ByteWriter 67 io.Writer 68 } 69 70 func (msg *Message) writeHashCommon(buf MessageWriter) (err error) { 71 if _, err = buf.Write(msg.PiecesRoot[:]); err != nil { 72 return 73 } 74 for _, d := range []Integer{msg.BaseLayer, msg.Index, msg.Length, msg.ProofLayers} { 75 if err = binary.Write(buf, binary.BigEndian, d); err != nil { 76 return 77 } 78 } 79 return nil 80 } 81 82 func (msg *Message) writePayloadTo(buf MessageWriter) (err error) { 83 if !msg.Keepalive { 84 err = buf.WriteByte(byte(msg.Type)) 85 if err != nil { 86 return 87 } 88 switch msg.Type { 89 case Choke, Unchoke, Interested, NotInterested, HaveAll, HaveNone: 90 case Have, AllowedFast, Suggest: 91 err = binary.Write(buf, binary.BigEndian, msg.Index) 92 case Request, Cancel, Reject: 93 for _, i := range []Integer{msg.Index, msg.Begin, msg.Length} { 94 err = binary.Write(buf, binary.BigEndian, i) 95 if err != nil { 96 break 97 } 98 } 99 case Bitfield: 100 _, err = buf.Write(marshalBitfield(msg.Bitfield)) 101 case Piece: 102 for _, i := range []Integer{msg.Index, msg.Begin} { 103 err = binary.Write(buf, binary.BigEndian, i) 104 if err != nil { 105 return 106 } 107 } 108 n, err := buf.Write(msg.Piece) 109 if err != nil { 110 break 111 } 112 if n != len(msg.Piece) { 113 panic(n) 114 } 115 case Extended: 116 err = buf.WriteByte(byte(msg.ExtendedID)) 117 if err != nil { 118 return 119 } 120 _, err = buf.Write(msg.ExtendedPayload) 121 case Port: 122 err = binary.Write(buf, binary.BigEndian, msg.Port) 123 case HashRequest, HashReject: 124 err = msg.writeHashCommon(buf) 125 case Hashes: 126 err = msg.writeHashCommon(buf) 127 if err != nil { 128 return 129 } 130 for _, h := range msg.Hashes { 131 if _, err = buf.Write(h[:]); err != nil { 132 return 133 } 134 } 135 default: 136 err = fmt.Errorf("unknown message type: %v", msg.Type) 137 } 138 } 139 return 140 } 141 142 func (msg *Message) WriteTo(w MessageWriter) (err error) { 143 length, err := msg.getPayloadLength() 144 if err != nil { 145 return 146 } 147 err = binary.Write(w, binary.BigEndian, length) 148 if err != nil { 149 return 150 } 151 return msg.writePayloadTo(w) 152 } 153 154 func (msg *Message) getPayloadLength() (length Integer, err error) { 155 var lw lengthWriter 156 err = msg.writePayloadTo(&lw) 157 length = lw.n 158 return 159 } 160 161 func (msg Message) MarshalBinary() (data []byte, err error) { 162 // It might look like you could have a pool of buffers and preallocate the message length 163 // prefix, but because we have to return []byte, it becomes non-trivial to make this fast. You 164 // will need a benchmark. 165 var buf bytes.Buffer 166 err = msg.WriteTo(&buf) 167 data = buf.Bytes() 168 return 169 } 170 171 func marshalBitfield(bf []bool) (b []byte) { 172 b = make([]byte, (len(bf)+7)/8) 173 for i, have := range bf { 174 if !have { 175 continue 176 } 177 c := b[i/8] 178 c |= 1 << uint(7-i%8) 179 b[i/8] = c 180 } 181 return 182 } 183 184 func (me *Message) UnmarshalBinary(b []byte) error { 185 d := Decoder{ 186 R: bufio.NewReader(bytes.NewReader(b)), 187 } 188 err := d.Decode(me) 189 if err != nil { 190 return err 191 } 192 if d.R.Buffered() != 0 { 193 return fmt.Errorf("%d trailing bytes", d.R.Buffered()) 194 } 195 return nil 196 } 197 198 type lengthWriter struct { 199 n Integer 200 } 201 202 func (l *lengthWriter) WriteByte(c byte) error { 203 l.n++ 204 return nil 205 } 206 207 func (l *lengthWriter) Write(p []byte) (n int, err error) { 208 n = len(p) 209 l.n += Integer(n) 210 return 211 }