github.com/anacrolix/torrent@v1.61.0/tracker/udp/server/server.go (about)

     1  package udpTrackerServer
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/rand"
     7  	"encoding/binary"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  	"net/netip"
    12  
    13  	"github.com/anacrolix/dht/v2/krpc"
    14  	"github.com/anacrolix/generics"
    15  	"github.com/anacrolix/log"
    16  	"go.opentelemetry.io/otel"
    17  	"go.opentelemetry.io/otel/attribute"
    18  	"go.opentelemetry.io/otel/codes"
    19  	"go.opentelemetry.io/otel/trace"
    20  
    21  	trackerServer "github.com/anacrolix/torrent/tracker/server"
    22  	"github.com/anacrolix/torrent/tracker/udp"
    23  )
    24  
    25  type ConnectionTrackerAddr = string
    26  
    27  type ConnectionTracker interface {
    28  	Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error
    29  	Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error)
    30  }
    31  
    32  type InfoHash = [20]byte
    33  
    34  type AnnounceTracker = trackerServer.AnnounceTracker
    35  
    36  type Server struct {
    37  	ConnTracker  ConnectionTracker
    38  	SendResponse func(ctx context.Context, data []byte, addr net.Addr) (int, error)
    39  	Announce     *trackerServer.AnnounceHandler
    40  }
    41  
    42  type RequestSourceAddr = net.Addr
    43  
    44  var tracer = otel.Tracer("torrent.tracker.udp")
    45  
    46  func (me *Server) HandleRequest(
    47  	ctx context.Context,
    48  	family udp.AddrFamily,
    49  	source RequestSourceAddr,
    50  	body []byte,
    51  ) (err error) {
    52  	ctx, span := tracer.Start(ctx, "Server.HandleRequest",
    53  		trace.WithAttributes(attribute.Int("payload.len", len(body))))
    54  	defer span.End()
    55  	defer func() {
    56  		if err != nil {
    57  			span.SetStatus(codes.Error, err.Error())
    58  		}
    59  	}()
    60  	var h udp.RequestHeader
    61  	var r bytes.Reader
    62  	r.Reset(body)
    63  	err = udp.Read(&r, &h)
    64  	if err != nil {
    65  		err = fmt.Errorf("reading request header: %w", err)
    66  		return err
    67  	}
    68  	switch h.Action {
    69  	case udp.ActionConnect:
    70  		err = me.handleConnect(ctx, source, h.TransactionId)
    71  	case udp.ActionAnnounce:
    72  		err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r)
    73  	default:
    74  		err = fmt.Errorf("unimplemented")
    75  	}
    76  	if err != nil {
    77  		err = fmt.Errorf("handling action %v: %w", h.Action, err)
    78  	}
    79  	return err
    80  }
    81  
    82  func (me *Server) handleAnnounce(
    83  	ctx context.Context,
    84  	addrFamily udp.AddrFamily,
    85  	source RequestSourceAddr,
    86  	connId udp.ConnectionId,
    87  	tid udp.TransactionId,
    88  	r *bytes.Reader,
    89  ) error {
    90  	// Should we set a timeout of 10s or something for the entire response, so that we give up if a
    91  	// retry is imminent?
    92  
    93  	ok, err := me.ConnTracker.Check(ctx, source.String(), connId)
    94  	if err != nil {
    95  		err = fmt.Errorf("checking conn id: %w", err)
    96  		return err
    97  	}
    98  	if !ok {
    99  		return fmt.Errorf("incorrect connection id: %x", connId)
   100  	}
   101  	var req udp.AnnounceRequest
   102  	err = udp.Read(r, &req)
   103  	if err != nil {
   104  		return err
   105  	}
   106  	// TODO: This should be done asynchronously to responding to the announce.
   107  	announceAddr, err := netip.ParseAddrPort(source.String())
   108  	if err != nil {
   109  		err = fmt.Errorf("converting source net.Addr to AnnounceAddr: %w", err)
   110  		return err
   111  	}
   112  	opts := trackerServer.GetPeersOpts{MaxCount: generics.Some[uint](50)}
   113  	if addrFamily == udp.AddrFamilyIpv4 {
   114  		opts.MaxCount = generics.Some[uint](150)
   115  	}
   116  	res := me.Announce.Serve(ctx, req, announceAddr, opts)
   117  	if res.Err != nil {
   118  		return res.Err
   119  	}
   120  	nodeAddrs := make([]krpc.NodeAddr, 0, len(res.Peers))
   121  	for _, p := range res.Peers {
   122  		var ip net.IP
   123  		switch addrFamily {
   124  		default:
   125  			continue
   126  		case udp.AddrFamilyIpv4:
   127  			if !p.Addr().Unmap().Is4() {
   128  				continue
   129  			}
   130  			ipBuf := p.Addr().As4()
   131  			ip = ipBuf[:]
   132  		case udp.AddrFamilyIpv6:
   133  			ipBuf := p.Addr().As16()
   134  			ip = ipBuf[:]
   135  		}
   136  		nodeAddrs = append(nodeAddrs, krpc.NodeAddr{
   137  			IP:   ip[:],
   138  			Port: int(p.Port()),
   139  		})
   140  	}
   141  	var buf bytes.Buffer
   142  	err = udp.Write(&buf, udp.ResponseHeader{
   143  		Action:        udp.ActionAnnounce,
   144  		TransactionId: tid,
   145  	})
   146  	if err != nil {
   147  		return err
   148  	}
   149  	err = udp.Write(&buf, udp.AnnounceResponseHeader{
   150  		Interval: res.Interval.UnwrapOr(5 * 60),
   151  		Seeders:  res.Seeders.Value,
   152  		Leechers: res.Leechers.Value,
   153  	})
   154  	if err != nil {
   155  		return err
   156  	}
   157  	b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary()
   158  	if err != nil {
   159  		err = fmt.Errorf("marshalling compact node addrs: %w", err)
   160  		return err
   161  	}
   162  	buf.Write(b)
   163  	n, err := me.SendResponse(ctx, buf.Bytes(), source)
   164  	if err != nil {
   165  		return err
   166  	}
   167  	if n < buf.Len() {
   168  		err = io.ErrShortWrite
   169  	}
   170  	return err
   171  }
   172  
   173  func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error {
   174  	connId := randomConnectionId()
   175  	err := me.ConnTracker.Add(ctx, source.String(), connId)
   176  	if err != nil {
   177  		err = fmt.Errorf("recording conn id: %w", err)
   178  		return err
   179  	}
   180  	var buf bytes.Buffer
   181  	udp.Write(&buf, udp.ResponseHeader{
   182  		Action:        udp.ActionConnect,
   183  		TransactionId: tid,
   184  	})
   185  	udp.Write(&buf, udp.ConnectionResponse{connId})
   186  	n, err := me.SendResponse(ctx, buf.Bytes(), source)
   187  	if err != nil {
   188  		return err
   189  	}
   190  	if n < buf.Len() {
   191  		err = io.ErrShortWrite
   192  	}
   193  	return err
   194  }
   195  
   196  func randomConnectionId() udp.ConnectionId {
   197  	var b [8]byte
   198  	_, err := rand.Read(b[:])
   199  	if err != nil {
   200  		panic(err)
   201  	}
   202  	return binary.BigEndian.Uint64(b[:])
   203  }
   204  
   205  func RunSimple(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error {
   206  	ctx, cancel := context.WithCancel(ctx)
   207  	defer cancel()
   208  	var b [1500]byte
   209  	// Limit concurrent handled requests.
   210  	sem := make(chan struct{}, 1000)
   211  	for {
   212  		n, addr, err := pc.ReadFrom(b[:])
   213  		ctx, span := tracer.Start(ctx, "handle udp packet")
   214  		if err != nil {
   215  			span.SetStatus(codes.Error, err.Error())
   216  			span.End()
   217  			return err
   218  		}
   219  		select {
   220  		case <-ctx.Done():
   221  			span.SetStatus(codes.Error, err.Error())
   222  			span.End()
   223  			return ctx.Err()
   224  		default:
   225  			span.SetStatus(codes.Error, "concurrency limit reached")
   226  			span.End()
   227  			log.Levelf(log.Debug, "dropping request from %v: concurrency limit reached", addr)
   228  			continue
   229  		case sem <- struct{}{}:
   230  		}
   231  		b := append([]byte(nil), b[:n]...)
   232  		go func() {
   233  			defer span.End()
   234  			defer func() { <-sem }()
   235  			err := s.HandleRequest(ctx, family, addr, b)
   236  			if err != nil {
   237  				log.Printf("error handling %v byte request from %v: %v", n, addr, err)
   238  			}
   239  		}()
   240  	}
   241  }