github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/transport/internet/udp/dispatcher_split.go (about)

     1  package udp
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/v2fly/v2ray-core/v5/common"
    10  	"github.com/v2fly/v2ray-core/v5/common/buf"
    11  	"github.com/v2fly/v2ray-core/v5/common/net"
    12  	"github.com/v2fly/v2ray-core/v5/common/protocol/udp"
    13  	"github.com/v2fly/v2ray-core/v5/common/session"
    14  	"github.com/v2fly/v2ray-core/v5/common/signal"
    15  	"github.com/v2fly/v2ray-core/v5/common/signal/done"
    16  	"github.com/v2fly/v2ray-core/v5/features/routing"
    17  	"github.com/v2fly/v2ray-core/v5/transport"
    18  )
    19  
    20  type ResponseCallback func(ctx context.Context, packet *udp.Packet)
    21  
    22  type connEntry struct {
    23  	link   *transport.Link
    24  	timer  signal.ActivityUpdater
    25  	cancel context.CancelFunc
    26  }
    27  
    28  type Dispatcher struct {
    29  	sync.RWMutex
    30  	conns      map[net.Destination]*connEntry
    31  	dispatcher routing.Dispatcher
    32  	callback   ResponseCallback
    33  }
    34  
    35  func (v *Dispatcher) Close() error {
    36  	return nil
    37  }
    38  
    39  func NewSplitDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) DispatcherI {
    40  	return &Dispatcher{
    41  		conns:      make(map[net.Destination]*connEntry),
    42  		dispatcher: dispatcher,
    43  		callback:   callback,
    44  	}
    45  }
    46  
    47  func (v *Dispatcher) RemoveRay(dest net.Destination) {
    48  	v.Lock()
    49  	defer v.Unlock()
    50  	if conn, found := v.conns[dest]; found {
    51  		common.Close(conn.link.Reader)
    52  		common.Close(conn.link.Writer)
    53  		delete(v.conns, dest)
    54  	}
    55  }
    56  
    57  func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) *connEntry {
    58  	v.Lock()
    59  	defer v.Unlock()
    60  
    61  	if entry, found := v.conns[dest]; found {
    62  		return entry
    63  	}
    64  
    65  	newError("establishing new connection for ", dest).WriteToLog()
    66  
    67  	ctx, cancel := context.WithCancel(ctx)
    68  	removeRay := func() {
    69  		cancel()
    70  		v.RemoveRay(dest)
    71  	}
    72  	timer := signal.CancelAfterInactivity(ctx, removeRay, time.Second*300)
    73  	link, _ := v.dispatcher.Dispatch(ctx, dest)
    74  	entry := &connEntry{
    75  		link:   link,
    76  		timer:  timer,
    77  		cancel: removeRay,
    78  	}
    79  	v.conns[dest] = entry
    80  	go handleInput(ctx, entry, dest, v.callback)
    81  	return entry
    82  }
    83  
    84  func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) {
    85  	// TODO: Add user to destString
    86  	newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx))
    87  
    88  	conn := v.getInboundRay(ctx, destination)
    89  	outputStream := conn.link.Writer
    90  	if outputStream != nil {
    91  		if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
    92  			newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
    93  			conn.cancel()
    94  			return
    95  		}
    96  	}
    97  }
    98  
    99  func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback) {
   100  	defer conn.cancel()
   101  
   102  	input := conn.link.Reader
   103  	timer := conn.timer
   104  
   105  	for {
   106  		select {
   107  		case <-ctx.Done():
   108  			return
   109  		default:
   110  		}
   111  
   112  		mb, err := input.ReadMultiBuffer()
   113  		if err != nil {
   114  			newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx))
   115  			return
   116  		}
   117  		timer.Update()
   118  		for _, b := range mb {
   119  			callback(ctx, &udp.Packet{
   120  				Payload: b,
   121  				Source:  dest,
   122  			})
   123  		}
   124  	}
   125  }
   126  
   127  type dispatcherConn struct {
   128  	dispatcher *Dispatcher
   129  	cache      chan *udp.Packet
   130  	done       *done.Instance
   131  }
   132  
   133  func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) {
   134  	c := &dispatcherConn{
   135  		cache: make(chan *udp.Packet, 16),
   136  		done:  done.New(),
   137  	}
   138  
   139  	d := NewSplitDispatcher(dispatcher, c.callback)
   140  	c.dispatcher = d.(*Dispatcher)
   141  	return c, nil
   142  }
   143  
   144  func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) {
   145  	select {
   146  	case <-c.done.Wait():
   147  		packet.Payload.Release()
   148  		return
   149  	case c.cache <- packet:
   150  	default:
   151  		packet.Payload.Release()
   152  		return
   153  	}
   154  }
   155  
   156  func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) {
   157  	select {
   158  	case <-c.done.Wait():
   159  		return 0, nil, io.EOF
   160  	case packet := <-c.cache:
   161  		n := copy(p, packet.Payload.Bytes())
   162  		return n, &net.UDPAddr{
   163  			IP:   packet.Source.Address.IP(),
   164  			Port: int(packet.Source.Port),
   165  		}, nil
   166  	}
   167  }
   168  
   169  func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) {
   170  	buffer := buf.New()
   171  	raw := buffer.Extend(buf.Size)
   172  	n := copy(raw, p)
   173  	buffer.Resize(0, int32(n))
   174  
   175  	ctx := context.Background()
   176  	c.dispatcher.Dispatch(ctx, net.DestinationFromAddr(addr), buffer)
   177  	return n, nil
   178  }
   179  
   180  func (c *dispatcherConn) Close() error {
   181  	return c.done.Close()
   182  }
   183  
   184  func (c *dispatcherConn) LocalAddr() net.Addr {
   185  	return &net.UDPAddr{
   186  		IP:   []byte{0, 0, 0, 0},
   187  		Port: 0,
   188  	}
   189  }
   190  
   191  func (c *dispatcherConn) SetDeadline(t time.Time) error {
   192  	return nil
   193  }
   194  
   195  func (c *dispatcherConn) SetReadDeadline(t time.Time) error {
   196  	return nil
   197  }
   198  
   199  func (c *dispatcherConn) SetWriteDeadline(t time.Time) error {
   200  	return nil
   201  }