github.com/anacrolix/torrent@v1.61.0/tracker/udp/client.go (about)

     1  package udp
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/binary"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"time"
    11  
    12  	"github.com/anacrolix/dht/v2/krpc"
    13  	"github.com/anacrolix/log"
    14  	"github.com/protolambda/ctxlock"
    15  )
    16  
// Client interacts with UDP trackers via its Writer and Dispatcher. It has no knowledge of
// connection specifics.
//
// Encoded request packets are written to Writer; responses are delivered through transactions
// created on Dispatcher.
type Client struct {
	// mu serializes obtaining a connection ID and writing the request that uses it (see
	// writeRequest). connId is the ID from the last successful connect round trip and
	// connIdIssued is when it was obtained; a zero connIdIssued makes the default reconnect
	// policy perform a new connect before the next non-connect request.
	mu           ctxlock.Lock
	connId       ConnectionId
	connIdIssued time.Time

	// If non-nil, this replaces the default reconnect policy (see shouldReconnect).
	shouldReconnectOverride func() bool

	// Dispatcher creates the transactions that receive tracker responses; Writer receives each
	// fully encoded request packet.
	Dispatcher *Dispatcher
	Writer     io.Writer
}
    29  
    30  func (cl *Client) Announce(
    31  	ctx context.Context, req AnnounceRequest, opts Options,
    32  	// Decides whether the response body is IPv6 or IPv4, see BEP 15.
    33  	ipv6 func(net.Addr) bool,
    34  ) (
    35  	respHdr AnnounceResponseHeader,
    36  	// A slice of krpc.NodeAddr, likely wrapped in an appropriate unmarshalling wrapper.
    37  	peers AnnounceResponsePeers,
    38  	err error,
    39  ) {
    40  	respBody, addr, err := cl.request(ctx, ActionAnnounce, append(mustMarshal(req), opts.Encode()...))
    41  	if err != nil {
    42  		return
    43  	}
    44  	r := bytes.NewBuffer(respBody)
    45  	err = Read(r, &respHdr)
    46  	if err != nil {
    47  		err = fmt.Errorf("reading response header: %w", err)
    48  		return
    49  	}
    50  	if ipv6(addr) {
    51  		peers = &krpc.CompactIPv6NodeAddrs{}
    52  	} else {
    53  		peers = &krpc.CompactIPv4NodeAddrs{}
    54  	}
    55  	err = peers.UnmarshalBinary(r.Bytes())
    56  	if err != nil {
    57  		err = fmt.Errorf("reading response peers: %w", err)
    58  	}
    59  	return
    60  }
    61  
    62  // There's no way to pass options in a scrape, since we don't when the request body ends.
    63  func (cl *Client) Scrape(
    64  	ctx context.Context, ihs []InfoHash,
    65  ) (
    66  	out ScrapeResponse, err error,
    67  ) {
    68  	respBody, _, err := cl.request(ctx, ActionScrape, mustMarshal(ScrapeRequest(ihs)))
    69  	if err != nil {
    70  		return
    71  	}
    72  	r := bytes.NewBuffer(respBody)
    73  	for r.Len() != 0 {
    74  		var item ScrapeInfohashResult
    75  		err = Read(r, &item)
    76  		if err != nil {
    77  			return
    78  		}
    79  		out = append(out, item)
    80  	}
    81  	if len(out) > len(ihs) {
    82  		err = fmt.Errorf("got %v results but expected %v", len(out), len(ihs))
    83  		return
    84  	}
    85  	return
    86  }
    87  
    88  func (cl *Client) shouldReconnectDefault() bool {
    89  	return cl.connIdIssued.IsZero() || time.Since(cl.connIdIssued) >= time.Minute
    90  }
    91  
    92  func (cl *Client) shouldReconnect() bool {
    93  	if cl.shouldReconnectOverride != nil {
    94  		return cl.shouldReconnectOverride()
    95  	}
    96  	return cl.shouldReconnectDefault()
    97  }
    98  
    99  func (cl *Client) connect(ctx context.Context) (err error) {
   100  	if !cl.shouldReconnect() {
   101  		return nil
   102  	}
   103  	return cl.doConnectRoundTrip(ctx)
   104  }
   105  
   106  // This just does the connect request and updates local state if it succeeds.
   107  func (cl *Client) doConnectRoundTrip(ctx context.Context) (err error) {
   108  	respBody, _, err := cl.request(ctx, ActionConnect, nil)
   109  	if err != nil {
   110  		return err
   111  	}
   112  	var connResp ConnectionResponse
   113  	err = binary.Read(bytes.NewReader(respBody), binary.BigEndian, &connResp)
   114  	if err != nil {
   115  		return
   116  	}
   117  	cl.connId = connResp.ConnectionId
   118  	cl.connIdIssued = time.Now()
   119  	//log.Printf("conn id set to %x", cl.connId)
   120  	return
   121  }
   122  
   123  func (cl *Client) connIdForRequest(ctx context.Context, action Action) (id ConnectionId, err error) {
   124  	if action == ActionConnect {
   125  		id = ConnectRequestConnectionId
   126  		return
   127  	}
   128  	err = cl.connect(ctx)
   129  	if err != nil {
   130  		return
   131  	}
   132  	id = cl.connId
   133  	return
   134  }
   135  
   136  func (cl *Client) writeRequest(
   137  	ctx context.Context, action Action, body []byte, tId TransactionId, buf *bytes.Buffer,
   138  ) (
   139  	err error,
   140  ) {
   141  	var connId ConnectionId
   142  	if action == ActionConnect {
   143  		connId = ConnectRequestConnectionId
   144  	} else {
   145  		// We lock here while establishing a connection ID, and then ensuring that the request is
   146  		// written before allowing the connection ID to change again. This is to ensure the server
   147  		// doesn't assign us another ID before we've sent this request. Note that this doesn't allow
   148  		// for us to return if the context is cancelled while we wait to obtain a new ID.
   149  		err = cl.mu.LockCtx(ctx)
   150  		if err != nil {
   151  			return fmt.Errorf("locking connection id: %w", err)
   152  		}
   153  		defer cl.mu.Unlock()
   154  		connId, err = cl.connIdForRequest(ctx, action)
   155  		if err != nil {
   156  			return
   157  		}
   158  	}
   159  	buf.Reset()
   160  	err = Write(buf, RequestHeader{
   161  		ConnectionId:  connId,
   162  		Action:        action,
   163  		TransactionId: tId,
   164  	})
   165  	if err != nil {
   166  		panic(err)
   167  	}
   168  	buf.Write(body)
   169  	_, err = cl.Writer.Write(buf.Bytes())
   170  	//log.Printf("sent request with conn id %x", connId)
   171  	return
   172  }
   173  
   174  func (cl *Client) requestWriter(
   175  	ctx context.Context,
   176  	action Action,
   177  	body []byte,
   178  	tId TransactionId,
   179  ) (err error) {
   180  	var buf bytes.Buffer
   181  	for n := 0; ; n++ {
   182  		err = cl.writeRequest(ctx, action, body, tId, &buf)
   183  		if err != nil {
   184  			return
   185  		}
   186  		select {
   187  		case <-ctx.Done():
   188  			return ctx.Err()
   189  		case <-time.After(timeout(n)):
   190  		}
   191  	}
   192  }
   193  
   194  const ConnectionIdMissmatchNul = "Connection ID missmatch.\x00"
   195  
   196  type ErrorResponse struct {
   197  	Message string
   198  }
   199  
   200  func (me ErrorResponse) Error() string {
   201  	return fmt.Sprintf("error response: %#q", me.Message)
   202  }
   203  
   204  func (cl *Client) request(
   205  	ctx context.Context,
   206  	action Action,
   207  	body []byte,
   208  ) (respBody []byte, addr net.Addr, err error) {
   209  	respChan := make(chan DispatchedResponse, 1)
   210  	t := cl.Dispatcher.NewTransaction(func(dr DispatchedResponse) {
   211  		respChan <- dr
   212  	})
   213  	defer t.End()
   214  	ctx, cancel := context.WithCancel(ctx)
   215  	defer cancel()
   216  	writeErr := make(chan error, 1)
   217  	go func() {
   218  		writeErr <- cl.requestWriter(ctx, action, body, t.Id())
   219  	}()
   220  	select {
   221  	case dr := <-respChan:
   222  		if dr.Header.Action == action {
   223  			respBody = dr.Body
   224  			addr = dr.Addr
   225  		} else if dr.Header.Action == ActionError {
   226  			// udp://tracker.torrent.eu.org:451/announce frequently returns "Connection ID
   227  			// missmatch.\x00"
   228  			stringBody := string(dr.Body)
   229  			err = ErrorResponse{Message: stringBody}
   230  			if stringBody == ConnectionIdMissmatchNul {
   231  				err = log.WithLevel(log.Debug, err)
   232  			}
   233  			// Force a reconnection. Probably any error is worth doing this for, but the one we're
   234  			// specifically interested in is ConnectionIdMissmatchNul.
   235  			cl.connIdIssued = time.Time{}
   236  		} else {
   237  			err = fmt.Errorf("unexpected response action %v", dr.Header.Action)
   238  		}
   239  	case err = <-writeErr:
   240  		err = fmt.Errorf("write error: %w", err)
   241  	case <-ctx.Done():
   242  		err = context.Cause(ctx)
   243  	}
   244  	return
   245  }