// github.com/xraypb/Xray-core@v1.8.1/transport/internet/udp/dispatcher.go

     1  package udp
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/xraypb/Xray-core/common"
    11  	"github.com/xraypb/Xray-core/common/buf"
    12  	"github.com/xraypb/Xray-core/common/net"
    13  	"github.com/xraypb/Xray-core/common/protocol/udp"
    14  	"github.com/xraypb/Xray-core/common/session"
    15  	"github.com/xraypb/Xray-core/common/signal"
    16  	"github.com/xraypb/Xray-core/common/signal/done"
    17  	"github.com/xraypb/Xray-core/features/routing"
    18  	"github.com/xraypb/Xray-core/transport"
    19  )
    20  
    21  type ResponseCallback func(ctx context.Context, packet *udp.Packet)
    22  
    23  type connEntry struct {
    24  	link   *transport.Link
    25  	timer  signal.ActivityUpdater
    26  	cancel context.CancelFunc
    27  }
    28  
    29  type Dispatcher struct {
    30  	sync.RWMutex
    31  	conns      map[net.Destination]*connEntry
    32  	dispatcher routing.Dispatcher
    33  	callback   ResponseCallback
    34  	callClose  func() error
    35  }
    36  
    37  func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
    38  	return &Dispatcher{
    39  		conns:      make(map[net.Destination]*connEntry),
    40  		dispatcher: dispatcher,
    41  		callback:   callback,
    42  	}
    43  }
    44  
    45  func (v *Dispatcher) RemoveRay(dest net.Destination) {
    46  	v.Lock()
    47  	defer v.Unlock()
    48  	if conn, found := v.conns[dest]; found {
    49  		common.Close(conn.link.Reader)
    50  		common.Close(conn.link.Writer)
    51  		delete(v.conns, dest)
    52  	}
    53  }
    54  
    55  func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*connEntry, error) {
    56  	v.Lock()
    57  	defer v.Unlock()
    58  
    59  	if entry, found := v.conns[dest]; found {
    60  		return entry, nil
    61  	}
    62  
    63  	newError("establishing new connection for ", dest).WriteToLog()
    64  
    65  	ctx, cancel := context.WithCancel(ctx)
    66  	removeRay := func() {
    67  		cancel()
    68  		v.RemoveRay(dest)
    69  	}
    70  	timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute)
    71  
    72  	link, err := v.dispatcher.Dispatch(ctx, dest)
    73  	if err != nil {
    74  		return nil, newError("failed to dispatch request to ", dest).Base(err)
    75  	}
    76  
    77  	entry := &connEntry{
    78  		link:   link,
    79  		timer:  timer,
    80  		cancel: removeRay,
    81  	}
    82  	v.conns[dest] = entry
    83  	go handleInput(ctx, entry, dest, v.callback, v.callClose)
    84  	return entry, nil
    85  }
    86  
    87  func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) {
    88  	// TODO: Add user to destString
    89  	newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx))
    90  
    91  	conn, err := v.getInboundRay(ctx, destination)
    92  	if err != nil {
    93  		newError("failed to get inbound").Base(err).WriteToLog(session.ExportIDToError(ctx))
    94  		return
    95  	}
    96  	outputStream := conn.link.Writer
    97  	if outputStream != nil {
    98  		if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
    99  			newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
   100  			conn.cancel()
   101  			return
   102  		}
   103  	}
   104  }
   105  
   106  func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) {
   107  	defer func() {
   108  		conn.cancel()
   109  		if callClose != nil {
   110  			callClose()
   111  		}
   112  	}()
   113  
   114  	input := conn.link.Reader
   115  	timer := conn.timer
   116  
   117  	for {
   118  		select {
   119  		case <-ctx.Done():
   120  			return
   121  		default:
   122  		}
   123  
   124  		mb, err := input.ReadMultiBuffer()
   125  		if err != nil {
   126  			if !errors.Is(err, io.EOF) {
   127  				newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx))
   128  			}
   129  			return
   130  		}
   131  		timer.Update()
   132  		for _, b := range mb {
   133  			callback(ctx, &udp.Packet{
   134  				Payload: b,
   135  				Source:  dest,
   136  			})
   137  		}
   138  	}
   139  }
   140  
   141  type dispatcherConn struct {
   142  	dispatcher *Dispatcher
   143  	cache      chan *udp.Packet
   144  	done       *done.Instance
   145  }
   146  
   147  func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) {
   148  	c := &dispatcherConn{
   149  		cache: make(chan *udp.Packet, 16),
   150  		done:  done.New(),
   151  	}
   152  
   153  	d := &Dispatcher{
   154  		conns:      make(map[net.Destination]*connEntry),
   155  		dispatcher: dispatcher,
   156  		callback:   c.callback,
   157  		callClose:  c.Close,
   158  	}
   159  	c.dispatcher = d
   160  	return c, nil
   161  }
   162  
   163  func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) {
   164  	select {
   165  	case <-c.done.Wait():
   166  		packet.Payload.Release()
   167  		return
   168  	case c.cache <- packet:
   169  	default:
   170  		packet.Payload.Release()
   171  		return
   172  	}
   173  }
   174  
   175  func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) {
   176  	var packet *udp.Packet
   177  s:
   178  	select {
   179  	case <-c.done.Wait():
   180  		select {
   181  		case packet = <-c.cache:
   182  			break s
   183  		default:
   184  			return 0, nil, io.EOF
   185  		}
   186  	case packet = <-c.cache:
   187  	}
   188  	return copy(p, packet.Payload.Bytes()), &net.UDPAddr{
   189  		IP:   packet.Source.Address.IP(),
   190  		Port: int(packet.Source.Port),
   191  	}, nil
   192  }
   193  
   194  func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) {
   195  	buffer := buf.New()
   196  	raw := buffer.Extend(buf.Size)
   197  	n := copy(raw, p)
   198  	buffer.Resize(0, int32(n))
   199  
   200  	ctx := context.Background()
   201  	c.dispatcher.Dispatch(ctx, net.DestinationFromAddr(addr), buffer)
   202  	return n, nil
   203  }
   204  
   205  func (c *dispatcherConn) Close() error {
   206  	return c.done.Close()
   207  }
   208  
   209  func (c *dispatcherConn) LocalAddr() net.Addr {
   210  	return &net.UDPAddr{
   211  		IP:   []byte{0, 0, 0, 0},
   212  		Port: 0,
   213  	}
   214  }
   215  
   216  func (c *dispatcherConn) SetDeadline(t time.Time) error {
   217  	return nil
   218  }
   219  
   220  func (c *dispatcherConn) SetReadDeadline(t time.Time) error {
   221  	return nil
   222  }
   223  
   224  func (c *dispatcherConn) SetWriteDeadline(t time.Time) error {
   225  	return nil
   226  }