github.com/xraypb/xray-core@v1.6.6/transport/internet/udp/dispatcher.go (about)

     1  package udp
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"sync"
     8  	"time"
     9  
    10  	"github.com/xraypb/xray-core/common"
    11  	"github.com/xraypb/xray-core/common/buf"
    12  	"github.com/xraypb/xray-core/common/net"
    13  	"github.com/xraypb/xray-core/common/protocol/udp"
    14  	"github.com/xraypb/xray-core/common/session"
    15  	"github.com/xraypb/xray-core/common/signal"
    16  	"github.com/xraypb/xray-core/common/signal/done"
    17  	"github.com/xraypb/xray-core/features/routing"
    18  	"github.com/xraypb/xray-core/transport"
    19  )
    20  
// ResponseCallback receives each packet read back from a dispatched link.
// Within this file, the packet's Source is set to the destination the
// request was originally dispatched to.
type ResponseCallback func(ctx context.Context, packet *udp.Packet)
    22  
// connEntry tracks the link that a Dispatcher opened for one destination.
type connEntry struct {
	// link is the Reader/Writer pair returned by the routing dispatcher.
	link   *transport.Link
	// timer is refreshed whenever a batch is read from link; when it expires
	// it invokes the same teardown function stored in cancel.
	timer  signal.ActivityUpdater
	// cancel tears the entry down: it cancels the entry's context and removes
	// the entry from the owning Dispatcher.
	cancel context.CancelFunc
}
    28  
// Dispatcher sends UDP payloads over per-destination links obtained from a
// routing.Dispatcher and hands every response packet to a ResponseCallback.
type Dispatcher struct {
	// RWMutex guards conns. Only the exclusive Lock/Unlock is used in this file.
	sync.RWMutex
	// conns maps each destination to its currently open link.
	conns      map[net.Destination]*connEntry
	dispatcher routing.Dispatcher
	callback   ResponseCallback
}
    35  
    36  func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
    37  	return &Dispatcher{
    38  		conns:      make(map[net.Destination]*connEntry),
    39  		dispatcher: dispatcher,
    40  		callback:   callback,
    41  	}
    42  }
    43  
    44  func (v *Dispatcher) RemoveRay(dest net.Destination) {
    45  	v.Lock()
    46  	defer v.Unlock()
    47  	if conn, found := v.conns[dest]; found {
    48  		common.Close(conn.link.Reader)
    49  		common.Close(conn.link.Writer)
    50  		delete(v.conns, dest)
    51  	}
    52  }
    53  
    54  func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) *connEntry {
    55  	v.Lock()
    56  	defer v.Unlock()
    57  
    58  	if entry, found := v.conns[dest]; found {
    59  		return entry
    60  	}
    61  
    62  	newError("establishing new connection for ", dest).WriteToLog()
    63  
    64  	ctx, cancel := context.WithCancel(ctx)
    65  	removeRay := func() {
    66  		cancel()
    67  		v.RemoveRay(dest)
    68  	}
    69  	timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute)
    70  	link, _ := v.dispatcher.Dispatch(ctx, dest)
    71  	entry := &connEntry{
    72  		link:   link,
    73  		timer:  timer,
    74  		cancel: removeRay,
    75  	}
    76  	v.conns[dest] = entry
    77  	go handleInput(ctx, entry, dest, v.callback)
    78  	return entry
    79  }
    80  
    81  func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) {
    82  	// TODO: Add user to destString
    83  	newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx))
    84  
    85  	conn := v.getInboundRay(ctx, destination)
    86  	outputStream := conn.link.Writer
    87  	if outputStream != nil {
    88  		if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
    89  			newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
    90  			conn.cancel()
    91  			return
    92  		}
    93  	}
    94  }
    95  
    96  func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback) {
    97  	defer conn.cancel()
    98  
    99  	input := conn.link.Reader
   100  	timer := conn.timer
   101  
   102  	for {
   103  		select {
   104  		case <-ctx.Done():
   105  			return
   106  		default:
   107  		}
   108  
   109  		mb, err := input.ReadMultiBuffer()
   110  		if err != nil {
   111  			if !errors.Is(err, io.EOF) {
   112  				newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx))
   113  			}
   114  			return
   115  		}
   116  		timer.Update()
   117  		for _, b := range mb {
   118  			callback(ctx, &udp.Packet{
   119  				Payload: b,
   120  				Source:  dest,
   121  			})
   122  		}
   123  	}
   124  }
   125  
// dispatcherConn exposes a Dispatcher through the net.PacketConn interface:
// WriteTo dispatches payloads, and ReadFrom returns the responses.
type dispatcherConn struct {
	dispatcher *Dispatcher
	// cache holds response packets queued by callback until ReadFrom takes them.
	cache      chan *udp.Packet
	// done is signalled by Close; ReadFrom and callback select on done.Wait().
	done       *done.Instance
}
   131  
   132  func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) {
   133  	c := &dispatcherConn{
   134  		cache: make(chan *udp.Packet, 16),
   135  		done:  done.New(),
   136  	}
   137  
   138  	d := NewDispatcher(dispatcher, c.callback)
   139  	c.dispatcher = d
   140  	return c, nil
   141  }
   142  
   143  func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) {
   144  	select {
   145  	case <-c.done.Wait():
   146  		packet.Payload.Release()
   147  		return
   148  	case c.cache <- packet:
   149  	default:
   150  		packet.Payload.Release()
   151  		return
   152  	}
   153  }
   154  
   155  func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) {
   156  	select {
   157  	case <-c.done.Wait():
   158  		return 0, nil, io.EOF
   159  	case packet := <-c.cache:
   160  		n := copy(p, packet.Payload.Bytes())
   161  		return n, &net.UDPAddr{
   162  			IP:   packet.Source.Address.IP(),
   163  			Port: int(packet.Source.Port),
   164  		}, nil
   165  	}
   166  }
   167  
   168  func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) {
   169  	buffer := buf.New()
   170  	raw := buffer.Extend(buf.Size)
   171  	n := copy(raw, p)
   172  	buffer.Resize(0, int32(n))
   173  
   174  	ctx := context.Background()
   175  	c.dispatcher.Dispatch(ctx, net.DestinationFromAddr(addr), buffer)
   176  	return n, nil
   177  }
   178  
   179  func (c *dispatcherConn) Close() error {
   180  	return c.done.Close()
   181  }
   182  
   183  func (c *dispatcherConn) LocalAddr() net.Addr {
   184  	return &net.UDPAddr{
   185  		IP:   []byte{0, 0, 0, 0},
   186  		Port: 0,
   187  	}
   188  }
   189  
   190  func (c *dispatcherConn) SetDeadline(t time.Time) error {
   191  	return nil
   192  }
   193  
   194  func (c *dispatcherConn) SetReadDeadline(t time.Time) error {
   195  	return nil
   196  }
   197  
   198  func (c *dispatcherConn) SetWriteDeadline(t time.Time) error {
   199  	return nil
   200  }