github.com/anacrolix/torrent@v1.61.0/webtorrent/transport.go (about)

     1  package webtorrent
     2  
     3  import (
     4  	"context"
     5  	"expvar"
     6  	"fmt"
     7  	"os"
     8  	"strconv"
     9  	"sync"
    10  	"time"
    11  
    12  	g "github.com/anacrolix/generics"
    13  	"github.com/anacrolix/log"
    14  	"github.com/anacrolix/missinggo/v2/panicif"
    15  	"github.com/anacrolix/missinggo/v2/pproffd"
    16  	"github.com/pion/datachannel"
    17  	"github.com/pion/webrtc/v4"
    18  	"go.opentelemetry.io/otel"
    19  	"go.opentelemetry.io/otel/attribute"
    20  	"go.opentelemetry.io/otel/codes"
    21  	"go.opentelemetry.io/otel/trace"
    22  )
    23  
    24  const (
    25  	dataChannelLabel = "webrtc-datachannel"
    26  )
    27  
    28  var (
    29  	metrics = expvar.NewMap("webtorrent")
    30  	api     = func() *webrtc.API {
    31  		// Enable the detach API (since it's non-standard but more idiomatic).
    32  		s.DetachDataChannels()
    33  		return webrtc.NewAPI(webrtc.WithSettingEngine(s))
    34  	}()
    35  	newPeerConnectionMu sync.Mutex
    36  )
    37  
    38  type wrappedPeerConnection struct {
    39  	*webrtc.PeerConnection
    40  	closeMu sync.Mutex
    41  	pproffd.CloseWrapper
    42  	span trace.Span
    43  	ctx  context.Context
    44  
    45  	onCloseHandler func()
    46  }
    47  
    48  func (me *wrappedPeerConnection) Close() error {
    49  	me.closeMu.Lock()
    50  	defer me.closeMu.Unlock()
    51  
    52  	me.onClose()
    53  
    54  	err := me.CloseWrapper.Close()
    55  	me.span.End()
    56  	return err
    57  }
    58  
    59  func (me *wrappedPeerConnection) OnClose(f func()) {
    60  	me.closeMu.Lock()
    61  	defer me.closeMu.Unlock()
    62  	me.onCloseHandler = f
    63  }
    64  
    65  func (me *wrappedPeerConnection) onClose() {
    66  	handler := me.onCloseHandler
    67  
    68  	if handler != nil {
    69  		handler()
    70  	}
    71  }
    72  
    73  func newPeerConnection(logger log.Logger, iceServers []webrtc.ICEServer) (*wrappedPeerConnection, error) {
    74  	newPeerConnectionMu.Lock()
    75  	defer newPeerConnectionMu.Unlock()
    76  	ctx, span := otel.Tracer(tracerName).Start(context.Background(), "PeerConnection")
    77  
    78  	pcConfig := webrtc.Configuration{ICEServers: iceServers}
    79  
    80  	pc, err := api.NewPeerConnection(pcConfig)
    81  	if err != nil {
    82  		span.SetStatus(codes.Error, err.Error())
    83  		span.RecordError(err)
    84  		span.End()
    85  		return nil, err
    86  	}
    87  	wpc := &wrappedPeerConnection{
    88  		PeerConnection: pc,
    89  		CloseWrapper:   pproffd.NewCloseWrapper(pc),
    90  		ctx:            ctx,
    91  		span:           span,
    92  	}
    93  	// If the state change handler intends to call Close, it should call it on the wrapper.
    94  	wpc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) {
    95  		logger.Levelf(log.Debug, "webrtc PeerConnection state changed to %v", state)
    96  		span.AddEvent("connection state changed", trace.WithAttributes(attribute.String("state", state.String())))
    97  	})
    98  	return wpc, nil
    99  }
   100  
   101  func setAndGatherLocalDescription(peerConnection *wrappedPeerConnection, sdp webrtc.SessionDescription) (_ webrtc.SessionDescription, err error) {
   102  	gatherComplete := webrtc.GatheringCompletePromise(peerConnection.PeerConnection)
   103  	peerConnection.span.AddEvent("setting local description")
   104  	err = peerConnection.SetLocalDescription(sdp)
   105  	if err != nil {
   106  		err = fmt.Errorf("setting local description: %w", err)
   107  		return
   108  	}
   109  	<-gatherComplete
   110  	peerConnection.span.AddEvent("gathering complete")
   111  	return *peerConnection.LocalDescription(), nil
   112  }
   113  
   114  // newOffer creates a transport and returns a WebRTC offer to be announced. See
   115  // https://github.com/pion/webrtc/blob/master/examples/data-channels/jsfiddle/main.go for what this is modelled on.
   116  func (tc *TrackerClient) newOffer(
   117  	logger log.Logger,
   118  	offerId string,
   119  	infoHash [20]byte,
   120  ) (
   121  	peerConnection *wrappedPeerConnection,
   122  	dataChannel *webrtc.DataChannel,
   123  	offer webrtc.SessionDescription,
   124  	err error,
   125  ) {
   126  	peerConnection, err = newPeerConnection(logger, tc.ICEServers)
   127  	if err != nil {
   128  		return
   129  	}
   130  
   131  	peerConnection.span.SetAttributes(attribute.String(webrtcConnTypeKey, "offer"))
   132  
   133  	dataChannel, err = peerConnection.CreateDataChannel(dataChannelLabel, nil)
   134  	if err != nil {
   135  		err = fmt.Errorf("creating data channel: %w", err)
   136  		peerConnection.Close()
   137  	}
   138  	initDataChannel(dataChannel, peerConnection, func(dc DataChannelConn, dcCtx context.Context, dcSpan trace.Span) {
   139  		metrics.Add("outbound offers answered with datachannel", 1)
   140  		tc.mu.Lock()
   141  		tc.stats.ConvertedOutboundConns++
   142  		tc.mu.Unlock()
   143  		tc.OnConn(dc, DataChannelContext{
   144  			OfferId:        offerId,
   145  			LocalOffered:   true,
   146  			InfoHash:       infoHash,
   147  			peerConnection: peerConnection,
   148  			Context:        dcCtx,
   149  			Span:           dcSpan,
   150  		})
   151  	})
   152  
   153  	offer, err = peerConnection.CreateOffer(nil)
   154  	if err != nil {
   155  		dataChannel.Close()
   156  		peerConnection.Close()
   157  		return
   158  	}
   159  
   160  	offer, err = setAndGatherLocalDescription(peerConnection, offer)
   161  	if err != nil {
   162  		dataChannel.Close()
   163  		peerConnection.Close()
   164  	}
   165  	return
   166  }
   167  
   168  type onDetachedDataChannelFunc func(detached DataChannelConn, ctx context.Context, span trace.Span)
   169  
   170  func (tc *TrackerClient) initAnsweringPeerConnection(
   171  	peerConn *wrappedPeerConnection,
   172  	offerContext offerContext,
   173  ) (answer webrtc.SessionDescription, err error) {
   174  	peerConn.span.SetAttributes(attribute.String(webrtcConnTypeKey, "answer"))
   175  
   176  	timer := time.AfterFunc(30*time.Second, func() {
   177  		peerConn.span.SetStatus(codes.Error, "answer timeout")
   178  		metrics.Add("answering peer connections timed out", 1)
   179  		peerConn.Close()
   180  	})
   181  	peerConn.OnDataChannel(func(d *webrtc.DataChannel) {
   182  		initDataChannel(d, peerConn, func(detached DataChannelConn, ctx context.Context, span trace.Span) {
   183  			timer.Stop()
   184  			metrics.Add("answering peer connection conversions", 1)
   185  			tc.mu.Lock()
   186  			tc.stats.ConvertedInboundConns++
   187  			tc.mu.Unlock()
   188  			tc.OnConn(detached, DataChannelContext{
   189  				OfferId:        offerContext.Id,
   190  				LocalOffered:   false,
   191  				InfoHash:       offerContext.InfoHash,
   192  				peerConnection: peerConn,
   193  				Context:        ctx,
   194  				Span:           span,
   195  			})
   196  		})
   197  	})
   198  
   199  	err = peerConn.SetRemoteDescription(offerContext.SessDesc)
   200  	if err != nil {
   201  		return
   202  	}
   203  	answer, err = peerConn.CreateAnswer(nil)
   204  	if err != nil {
   205  		return
   206  	}
   207  
   208  	answer, err = setAndGatherLocalDescription(peerConn, answer)
   209  	return
   210  }
   211  
   212  // newAnsweringPeerConnection creates a transport from a WebRTC offer and returns a WebRTC answer to be announced.
   213  func (tc *TrackerClient) newAnsweringPeerConnection(
   214  	offerContext offerContext,
   215  ) (
   216  	peerConn *wrappedPeerConnection, answer webrtc.SessionDescription, err error,
   217  ) {
   218  	peerConn, err = newPeerConnection(tc.Logger, tc.ICEServers)
   219  	if err != nil {
   220  		err = fmt.Errorf("failed to create new connection: %w", err)
   221  		return
   222  	}
   223  	answer, err = tc.initAnsweringPeerConnection(peerConn, offerContext)
   224  	if err != nil {
   225  		peerConn.span.RecordError(err)
   226  		peerConn.Close()
   227  	}
   228  	return
   229  }
   230  
   231  type ioCloserFunc func() error
   232  
   233  func (me ioCloserFunc) Close() error {
   234  	return me()
   235  }
   236  
   237  func initDataChannel(
   238  	dc *webrtc.DataChannel,
   239  	pc *wrappedPeerConnection,
   240  	onOpen onDetachedDataChannelFunc,
   241  ) {
   242  	var span trace.Span
   243  	dc.OnClose(func() {
   244  		span.End()
   245  	})
   246  	dc.OnOpen(func() {
   247  		pc.span.AddEvent("data channel opened")
   248  		var ctx context.Context
   249  		ctx, span = otel.Tracer(tracerName).Start(pc.ctx, "DataChannel")
   250  		raw, err := dc.Detach()
   251  		if err != nil {
   252  			// This shouldn't happen if the API is configured correctly, and we call from OnOpen.
   253  			panic(err)
   254  		}
   255  		onOpen(wrapDataChannel(raw, pc, span, dc), ctx, span)
   256  	})
   257  }
   258  
   259  // WebRTC data channel wrapper that supports operating as a peer conn ReadWriteCloser.
   260  type DataChannelConn struct {
   261  	ioCloserFunc
   262  	rawDataChannel datachannel.ReadWriteCloser
   263  }
   264  
   265  func (d DataChannelConn) Read(p []byte) (int, error) {
   266  	return d.rawDataChannel.Read(p)
   267  }
   268  
   269  // Limit write size for WebRTC data channels. See https://github.com/pion/datachannel/issues/59. The
   270  // default used to be (1<<16)-1. This will be set to the new appropriate value if it's discovered to
   271  // still be a limitation. Set WEBTORRENT_MAX_WRITE_SIZE to experiment with it.
   272  var maxWriteSize = g.None[int]()
   273  
   274  func init() {
   275  	s, ok := os.LookupEnv("WEBTORRENT_MAX_WRITE_SIZE")
   276  	if !ok {
   277  		return
   278  	}
   279  	i64, err := strconv.ParseInt(s, 0, 0)
   280  	panicif.Err(err)
   281  	maxWriteSize = g.Some(int(i64))
   282  }
   283  
   284  func (d DataChannelConn) Write(p []byte) (n int, err error) {
   285  	for {
   286  		p1 := p
   287  		if maxWriteSize.Ok {
   288  			p1 = p1[:min(len(p1), maxWriteSize.Value)]
   289  		}
   290  		var n1 int
   291  		n1, err = d.rawDataChannel.Write(p1)
   292  		n += n1
   293  		p = p[n1:]
   294  		if err != nil {
   295  			return
   296  		}
   297  		if len(p) == 0 {
   298  			return
   299  		}
   300  	}
   301  }
   302  
   303  func wrapDataChannel(
   304  	dcrwc datachannel.ReadWriteCloser,
   305  	pc *wrappedPeerConnection,
   306  	dataChannelSpan trace.Span,
   307  	originalDataChannel *webrtc.DataChannel,
   308  ) DataChannelConn {
   309  	return DataChannelConn{
   310  		ioCloserFunc: ioCloserFunc(func() error {
   311  			dcrwc.Close()
   312  			pc.Close()
   313  			originalDataChannel.Close()
   314  			dataChannelSpan.End()
   315  			return nil
   316  		}),
   317  		rawDataChannel: dcrwc,
   318  	}
   319  }