github.com/anacrolix/torrent@v1.61.0/wstracker.go (about)

     1  package torrent
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"net"
     7  	netHttp "net/http"
     8  	"net/url"
     9  	"sync"
    10  
    11  	"github.com/anacrolix/log"
    12  	"github.com/gorilla/websocket"
    13  	"github.com/pion/webrtc/v4"
    14  
    15  	"github.com/anacrolix/torrent/tracker"
    16  	httpTracker "github.com/anacrolix/torrent/tracker/http"
    17  	"github.com/anacrolix/torrent/webtorrent"
    18  )
    19  
    20  type websocketTrackerStatus struct {
    21  	url url.URL
    22  	tc  *webtorrent.TrackerClient
    23  }
    24  
    25  func (me websocketTrackerStatus) statusLine() string {
    26  	return fmt.Sprintf("%+v", me.tc.Stats())
    27  }
    28  
    29  func (me websocketTrackerStatus) URL() *url.URL {
    30  	return &me.url
    31  }
    32  
    33  func (me websocketTrackerStatus) Stop() {
    34  }
    35  
    36  type refCountedWebtorrentTrackerClient struct {
    37  	webtorrent.TrackerClient
    38  	refCount int
    39  }
    40  
    41  type websocketTrackers struct {
    42  	PeerId                     [20]byte
    43  	Logger                     log.Logger
    44  	GetAnnounceRequest         func(event tracker.AnnounceEvent, infoHash [20]byte) (tracker.AnnounceRequest, error)
    45  	OnConn                     func(webtorrent.DataChannelConn, webtorrent.DataChannelContext)
    46  	mu                         sync.Mutex
    47  	clients                    map[string]*refCountedWebtorrentTrackerClient
    48  	Proxy                      httpTracker.ProxyFunc
    49  	DialContext                func(ctx context.Context, network, addr string) (net.Conn, error)
    50  	WebsocketTrackerHttpHeader func() netHttp.Header
    51  	ICEServers                 []webrtc.ICEServer
    52  	callbacks                  *Callbacks
    53  }
    54  
    55  func (me *websocketTrackers) Get(url string, infoHash [20]byte) (*webtorrent.TrackerClient, func()) {
    56  	me.mu.Lock()
    57  	defer me.mu.Unlock()
    58  	value, ok := me.clients[url]
    59  	if !ok {
    60  		dialer := &websocket.Dialer{
    61  			Proxy:            me.Proxy,
    62  			NetDialContext:   me.DialContext,
    63  			HandshakeTimeout: websocket.DefaultDialer.HandshakeTimeout,
    64  		}
    65  		value = &refCountedWebtorrentTrackerClient{
    66  			TrackerClient: webtorrent.TrackerClient{
    67  				Dialer:             dialer,
    68  				Url:                url,
    69  				GetAnnounceRequest: me.GetAnnounceRequest,
    70  				PeerId:             me.PeerId,
    71  				OnConn:             me.OnConn,
    72  				Logger: me.Logger.WithText(
    73  					func(m log.Msg) string {
    74  						return fmt.Sprintf("tracker client for %q: %v", url, m)
    75  					},
    76  				),
    77  				WebsocketTrackerHttpHeader: me.WebsocketTrackerHttpHeader,
    78  				ICEServers:                 me.ICEServers,
    79  				OnConnected: func(err error) {
    80  					for _, cb := range me.callbacks.StatusUpdated {
    81  						cb(StatusUpdatedEvent{
    82  							Event: TrackerConnected,
    83  							Url:   url,
    84  							Error: err,
    85  						})
    86  					}
    87  				},
    88  				OnDisconnected: func(err error) {
    89  					for _, cb := range me.callbacks.StatusUpdated {
    90  						cb(StatusUpdatedEvent{
    91  							Event: TrackerDisconnected,
    92  							Url:   url,
    93  							Error: err,
    94  						})
    95  					}
    96  				},
    97  				OnAnnounceSuccessful: func(ih string) {
    98  					for _, cb := range me.callbacks.StatusUpdated {
    99  						cb(StatusUpdatedEvent{
   100  							Event:    TrackerAnnounceSuccessful,
   101  							Url:      url,
   102  							InfoHash: ih,
   103  						})
   104  					}
   105  				},
   106  				OnAnnounceError: func(ih string, err error) {
   107  					for _, cb := range me.callbacks.StatusUpdated {
   108  						cb(StatusUpdatedEvent{
   109  							Event:    TrackerAnnounceError,
   110  							Url:      url,
   111  							Error:    err,
   112  							InfoHash: ih,
   113  						})
   114  					}
   115  				},
   116  			},
   117  		}
   118  		value.TrackerClient.Start(func(err error) {
   119  			if err != nil {
   120  				me.Logger.Printf("error running tracker client for %q: %v", url, err)
   121  			}
   122  		})
   123  		if me.clients == nil {
   124  			me.clients = make(map[string]*refCountedWebtorrentTrackerClient)
   125  		}
   126  		me.clients[url] = value
   127  	}
   128  	value.refCount++
   129  	return &value.TrackerClient, func() {
   130  		me.mu.Lock()
   131  		defer me.mu.Unlock()
   132  		value.TrackerClient.CloseOffersForInfohash(infoHash)
   133  		value.refCount--
   134  		if value.refCount == 0 {
   135  			value.TrackerClient.Close()
   136  			delete(me.clients, url)
   137  		}
   138  	}
   139  }