github.com/anacrolix/torrent@v1.61.0/tracker/udp/client.go (about) 1 package udp 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/binary" 7 "fmt" 8 "io" 9 "net" 10 "time" 11 12 "github.com/anacrolix/dht/v2/krpc" 13 "github.com/anacrolix/log" 14 "github.com/protolambda/ctxlock" 15 ) 16 17 // Client interacts with UDP trackers via its Writer and Dispatcher. It has no knowledge of 18 // connection specifics. 19 type Client struct { 20 mu ctxlock.Lock 21 connId ConnectionId 22 connIdIssued time.Time 23 24 shouldReconnectOverride func() bool 25 26 Dispatcher *Dispatcher 27 Writer io.Writer 28 } 29 30 func (cl *Client) Announce( 31 ctx context.Context, req AnnounceRequest, opts Options, 32 // Decides whether the response body is IPv6 or IPv4, see BEP 15. 33 ipv6 func(net.Addr) bool, 34 ) ( 35 respHdr AnnounceResponseHeader, 36 // A slice of krpc.NodeAddr, likely wrapped in an appropriate unmarshalling wrapper. 37 peers AnnounceResponsePeers, 38 err error, 39 ) { 40 respBody, addr, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...)) 41 if err != nil { 42 return 43 } 44 r := bytes.NewBuffer(respBody) 45 err = Read(r, &respHdr) 46 if err != nil { 47 err = fmt.Errorf("reading response header: %w", err) 48 return 49 } 50 if ipv6(addr) { 51 peers = &krpc.CompactIPv6NodeAddrs{} 52 } else { 53 peers = &krpc.CompactIPv4NodeAddrs{} 54 } 55 err = peers.UnmarshalBinary(r.Bytes()) 56 if err != nil { 57 err = fmt.Errorf("reading response peers: %w", err) 58 } 59 return 60 } 61 62 // There's no way to pass options in a scrape, since we don't when the request body ends. 63 func (cl *Client) Scrape( 64 ctx context.Context, ihs []InfoHash, 65 ) ( 66 out ScrapeResponse, err error, 67 ) { 68 respBody, _, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs))) 69 if err != nil { 70 return 71 } 72 r := bytes.NewBuffer(respBody) 73 for r.Len() != 0 { 74 var item ScrapeInfohashResult 75 err = Read(r, &item) 76 if err != nil { 77 return 78 } 79 out = append(out, item) 80 } 81 if len(out) > len(ihs) { 82 err = fmt.Errorf("got %v results but expected %v", len(out), len(ihs)) 83 return 84 } 85 return 86 } 87 88 func (cl *Client) shouldReconnectDefault() bool { 89 return cl.connIdIssued.IsZero() || time.Since(cl.connIdIssued) >= time.Minute 90 } 91 92 func (cl *Client) shouldReconnect() bool { 93 if cl.shouldReconnectOverride != nil { 94 return cl.shouldReconnectOverride() 95 } 96 return cl.shouldReconnectDefault() 97 } 98 99 func (cl *Client) connect(ctx context.Context) (err error) { 100 if !cl.shouldReconnect() { 101 return nil 102 } 103 return cl.doConnectRoundTrip(ctx) 104 } 105 106 // This just does the connect request and updates local state if it succeeds. 107 func (cl *Client) doConnectRoundTrip(ctx context.Context) (err error) { 108 respBody, _, err := cl.request(ctx, ActionConnect, nil) 109 if err != nil { 110 return err 111 } 112 var connResp ConnectionResponse 113 err = binary.Read(bytes.NewReader(respBody), binary.BigEndian, &connResp) 114 if err != nil { 115 return 116 } 117 cl.connId = connResp.ConnectionId 118 cl.connIdIssued = time.Now() 119 //log.Printf("conn id set to %x", cl.connId) 120 return 121 } 122 123 func (cl *Client) connIdForRequest(ctx context.Context, action Action) (id ConnectionId, err error) { 124 if action == ActionConnect { 125 id = ConnectRequestConnectionId 126 return 127 } 128 err = cl.connect(ctx) 129 if err != nil { 130 return 131 } 132 id = cl.connId 133 return 134 } 135 136 func (cl *Client) writeRequest( 137 ctx context.Context, action Action, body []byte, tId TransactionId, buf *bytes.Buffer, 138 ) ( 139 err error, 140 ) { 141 var connId ConnectionId 142 if action == ActionConnect { 143 connId = ConnectRequestConnectionId 144 } else { 145 // We lock here while establishing a connection ID, and then ensuring that the request is 146 // written before allowing the connection ID to change again. This is to ensure the server 147 // doesn't assign us another ID before we've sent this request. Note that this doesn't allow 148 // for us to return if the context is cancelled while we wait to obtain a new ID. 149 err = cl.mu.LockCtx(ctx) 150 if err != nil { 151 return fmt.Errorf("locking connection id: %w", err) 152 } 153 defer cl.mu.Unlock() 154 connId, err = cl.connIdForRequest(ctx, action) 155 if err != nil { 156 return 157 } 158 } 159 buf.Reset() 160 err = Write(buf, RequestHeader{ 161 ConnectionId: connId, 162 Action: action, 163 TransactionId: tId, 164 }) 165 if err != nil { 166 panic(err) 167 } 168 buf.Write(body) 169 _, err = cl.Writer.Write(buf.Bytes()) 170 //log.Printf("sent request with conn id %x", connId) 171 return 172 } 173 174 func (cl *Client) requestWriter( 175 ctx context.Context, 176 action Action, 177 body []byte, 178 tId TransactionId, 179 ) (err error) { 180 var buf bytes.Buffer 181 for n := 0; ; n++ { 182 err = cl.writeRequest(ctx, action, body, tId, &buf) 183 if err != nil { 184 return 185 } 186 select { 187 case <-ctx.Done(): 188 return ctx.Err() 189 case <-time.After(timeout(n)): 190 } 191 } 192 } 193 194 const ConnectionIdMissmatchNul = "Connection ID missmatch.\x00" 195 196 type ErrorResponse struct { 197 Message string 198 } 199 200 func (me ErrorResponse) Error() string { 201 return fmt.Sprintf("error response: %#q", me.Message) 202 } 203 204 func (cl *Client) request( 205 ctx context.Context, 206 action Action, 207 body []byte, 208 ) (respBody []byte, addr net.Addr, err error) { 209 respChan := make(chan DispatchedResponse, 1) 210 t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) { 211 respChan <- dr 212 }) 213 defer t.End() 214 ctx, cancel := context.WithCancel(ctx) 215 defer cancel() 216 writeErr := make(chan error, 1) 217 go func() { 218 writeErr <- cl.requestWriter(ctx, action, body, t.Id()) 219 }() 220 select { 221 case dr := <-respChan: 222 if dr.Header.Action == action { 223 respBody = dr.Body 224 addr = dr.Addr 225 } else if dr.Header.Action == ActionError { 226 // udp://tracker.torrent.eu.org:451/announce frequently returns "Connection ID 227 // missmatch.\x00" 228 stringBody := string(dr.Body) 229 err = ErrorResponse{Message: stringBody} 230 if stringBody == ConnectionIdMissmatchNul { 231 err = log.WithLevel(log.Debug, err) 232 } 233 // Force a reconnection. Probably any error is worth doing this for, but the one we're 234 // specifically interested in is ConnectionIdMissmatchNul. 235 cl.connIdIssued = time.Time{} 236 } else { 237 err = fmt.Errorf("unexpected response action %v", dr.Header.Action) 238 } 239 case err = <-writeErr: 240 err = fmt.Errorf("write error: %w", err) 241 case <-ctx.Done(): 242 err = context.Cause(ctx) 243 } 244 return 245 }