github.com/anacrolix/torrent@v1.61.0/tracker/udp/server/server.go (about) 1 package udpTrackerServer 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/rand" 7 "encoding/binary" 8 "fmt" 9 "io" 10 "net" 11 "net/netip" 12 13 "github.com/anacrolix/dht/v2/krpc" 14 "github.com/anacrolix/generics" 15 "github.com/anacrolix/log" 16 "go.opentelemetry.io/otel" 17 "go.opentelemetry.io/otel/attribute" 18 "go.opentelemetry.io/otel/codes" 19 "go.opentelemetry.io/otel/trace" 20 21 trackerServer "github.com/anacrolix/torrent/tracker/server" 22 "github.com/anacrolix/torrent/tracker/udp" 23 ) 24 25 type ConnectionTrackerAddr = string 26 27 type ConnectionTracker interface { 28 Add(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) error 29 Check(ctx context.Context, addr ConnectionTrackerAddr, id udp.ConnectionId) (bool, error) 30 } 31 32 type InfoHash = [20]byte 33 34 type AnnounceTracker = trackerServer.AnnounceTracker 35 36 type Server struct { 37 ConnTracker ConnectionTracker 38 SendResponse func(ctx context.Context, data []byte, addr net.Addr) (int, error) 39 Announce *trackerServer.AnnounceHandler 40 } 41 42 type RequestSourceAddr = net.Addr 43 44 var tracer = otel.Tracer("torrent.tracker.udp") 45 46 func (me *Server) HandleRequest( 47 ctx context.Context, 48 family udp.AddrFamily, 49 source RequestSourceAddr, 50 body []byte, 51 ) (err error) { 52 ctx, span := tracer.Start(ctx, "Server.HandleRequest", 53 trace.WithAttributes(attribute.Int("payload.len", len(body)))) 54 defer span.End() 55 defer func() { 56 if err != nil { 57 span.SetStatus(codes.Error, err.Error()) 58 } 59 }() 60 var h udp.RequestHeader 61 var r bytes.Reader 62 r.Reset(body) 63 err = udp.Read(&r, &h) 64 if err != nil { 65 err = fmt.Errorf("reading request header: %w", err) 66 return err 67 } 68 switch h.Action { 69 case udp.ActionConnect: 70 err = me.handleConnect(ctx, source, h.TransactionId) 71 case udp.ActionAnnounce: 72 err = me.handleAnnounce(ctx, family, source, h.ConnectionId, h.TransactionId, &r) 73 default: 74 err = fmt.Errorf("unimplemented") 75 } 76 if err != nil { 77 err = fmt.Errorf("handling action %v: %w", h.Action, err) 78 } 79 return err 80 } 81 82 func (me *Server) handleAnnounce( 83 ctx context.Context, 84 addrFamily udp.AddrFamily, 85 source RequestSourceAddr, 86 connId udp.ConnectionId, 87 tid udp.TransactionId, 88 r *bytes.Reader, 89 ) error { 90 // Should we set a timeout of 10s or something for the entire response, so that we give up if a 91 // retry is imminent? 92 93 ok, err := me.ConnTracker.Check(ctx, source.String(), connId) 94 if err != nil { 95 err = fmt.Errorf("checking conn id: %w", err) 96 return err 97 } 98 if !ok { 99 return fmt.Errorf("incorrect connection id: %x", connId) 100 } 101 var req udp.AnnounceRequest 102 err = udp.Read(r, &req) 103 if err != nil { 104 return err 105 } 106 // TODO: This should be done asynchronously to responding to the announce. 107 announceAddr, err := netip.ParseAddrPort(source.String()) 108 if err != nil { 109 err = fmt.Errorf("converting source net.Addr to AnnounceAddr: %w", err) 110 return err 111 } 112 opts := trackerServer.GetPeersOpts{MaxCount: generics.Some[uint](50)} 113 if addrFamily == udp.AddrFamilyIpv4 { 114 opts.MaxCount = generics.Some[uint](150) 115 } 116 res := me.Announce.Serve(ctx, req, announceAddr, opts) 117 if res.Err != nil { 118 return res.Err 119 } 120 nodeAddrs := make([]krpc.NodeAddr, 0, len(res.Peers)) 121 for _, p := range res.Peers { 122 var ip net.IP 123 switch addrFamily { 124 default: 125 continue 126 case udp.AddrFamilyIpv4: 127 if !p.Addr().Unmap().Is4() { 128 continue 129 } 130 ipBuf := p.Addr().As4() 131 ip = ipBuf[:] 132 case udp.AddrFamilyIpv6: 133 ipBuf := p.Addr().As16() 134 ip = ipBuf[:] 135 } 136 nodeAddrs = append(nodeAddrs, krpc.NodeAddr{ 137 IP: ip[:], 138 Port: int(p.Port()), 139 }) 140 } 141 var buf bytes.Buffer 142 err = udp.Write(&buf, udp.ResponseHeader{ 143 Action: udp.ActionAnnounce, 144 TransactionId: tid, 145 }) 146 if err != nil { 147 return err 148 } 149 err = udp.Write(&buf, udp.AnnounceResponseHeader{ 150 Interval: res.Interval.UnwrapOr(5 * 60), 151 Seeders: res.Seeders.Value, 152 Leechers: res.Leechers.Value, 153 }) 154 if err != nil { 155 return err 156 } 157 b, err := udp.GetNodeAddrsCompactMarshaler(nodeAddrs, addrFamily).MarshalBinary() 158 if err != nil { 159 err = fmt.Errorf("marshalling compact node addrs: %w", err) 160 return err 161 } 162 buf.Write(b) 163 n, err := me.SendResponse(ctx, buf.Bytes(), source) 164 if err != nil { 165 return err 166 } 167 if n < buf.Len() { 168 err = io.ErrShortWrite 169 } 170 return err 171 } 172 173 func (me *Server) handleConnect(ctx context.Context, source RequestSourceAddr, tid udp.TransactionId) error { 174 connId := randomConnectionId() 175 err := me.ConnTracker.Add(ctx, source.String(), connId) 176 if err != nil { 177 err = fmt.Errorf("recording conn id: %w", err) 178 return err 179 } 180 var buf bytes.Buffer 181 udp.Write(&buf, udp.ResponseHeader{ 182 Action: udp.ActionConnect, 183 TransactionId: tid, 184 }) 185 udp.Write(&buf, udp.ConnectionResponse{connId}) 186 n, err := me.SendResponse(ctx, buf.Bytes(), source) 187 if err != nil { 188 return err 189 } 190 if n < buf.Len() { 191 err = io.ErrShortWrite 192 } 193 return err 194 } 195 196 func randomConnectionId() udp.ConnectionId { 197 var b [8]byte 198 _, err := rand.Read(b[:]) 199 if err != nil { 200 panic(err) 201 } 202 return binary.BigEndian.Uint64(b[:]) 203 } 204 205 func RunSimple(ctx context.Context, s *Server, pc net.PacketConn, family udp.AddrFamily) error { 206 ctx, cancel := context.WithCancel(ctx) 207 defer cancel() 208 var b [1500]byte 209 // Limit concurrent handled requests. 210 sem := make(chan struct{}, 1000) 211 for { 212 n, addr, err := pc.ReadFrom(b[:]) 213 ctx, span := tracer.Start(ctx, "handle udp packet") 214 if err != nil { 215 span.SetStatus(codes.Error, err.Error()) 216 span.End() 217 return err 218 } 219 select { 220 case <-ctx.Done(): 221 span.SetStatus(codes.Error, err.Error()) 222 span.End() 223 return ctx.Err() 224 default: 225 span.SetStatus(codes.Error, "concurrency limit reached") 226 span.End() 227 log.Levelf(log.Debug, "dropping request from %v: concurrency limit reached", addr) 228 continue 229 case sem <- struct{}{}: 230 } 231 b := append([]byte(nil), b[:n]...) 232 go func() { 233 defer span.End() 234 defer func() { <-sem }() 235 err := s.HandleRequest(ctx, family, addr, b) 236 if err != nil { 237 log.Printf("error handling %v byte request from %v: %v", n, addr, err) 238 } 239 }() 240 } 241 }