github.com/v2fly/v2ray-core/v4@v4.45.2/transport/internet/udp/dispatcher.go (about)

     1  package udp
     2  
     3  import (
     4  	"context"
     5  	"io"
     6  	"sync"
     7  	"time"
     8  
     9  	"github.com/v2fly/v2ray-core/v4/common"
    10  	"github.com/v2fly/v2ray-core/v4/common/buf"
    11  	"github.com/v2fly/v2ray-core/v4/common/net"
    12  	"github.com/v2fly/v2ray-core/v4/common/protocol/udp"
    13  	"github.com/v2fly/v2ray-core/v4/common/session"
    14  	"github.com/v2fly/v2ray-core/v4/common/signal"
    15  	"github.com/v2fly/v2ray-core/v4/common/signal/done"
    16  	"github.com/v2fly/v2ray-core/v4/features/routing"
    17  	"github.com/v2fly/v2ray-core/v4/transport"
    18  )
    19  
    20  type ResponseCallback func(ctx context.Context, packet *udp.Packet)
    21  
    22  type connEntry struct {
    23  	link   *transport.Link
    24  	timer  signal.ActivityUpdater
    25  	cancel context.CancelFunc
    26  }
    27  
    28  type Dispatcher struct {
    29  	sync.RWMutex
    30  	conns      map[net.Destination]*connEntry
    31  	dispatcher routing.Dispatcher
    32  	callback   ResponseCallback
    33  }
    34  
    35  func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher {
    36  	return &Dispatcher{
    37  		conns:      make(map[net.Destination]*connEntry),
    38  		dispatcher: dispatcher,
    39  		callback:   callback,
    40  	}
    41  }
    42  
    43  func (v *Dispatcher) RemoveRay(dest net.Destination) {
    44  	v.Lock()
    45  	defer v.Unlock()
    46  	if conn, found := v.conns[dest]; found {
    47  		common.Close(conn.link.Reader)
    48  		common.Close(conn.link.Writer)
    49  		delete(v.conns, dest)
    50  	}
    51  }
    52  
    53  func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) *connEntry {
    54  	v.Lock()
    55  	defer v.Unlock()
    56  
    57  	if entry, found := v.conns[dest]; found {
    58  		return entry
    59  	}
    60  
    61  	newError("establishing new connection for ", dest).WriteToLog()
    62  
    63  	ctx, cancel := context.WithCancel(ctx)
    64  	removeRay := func() {
    65  		cancel()
    66  		v.RemoveRay(dest)
    67  	}
    68  	timer := signal.CancelAfterInactivity(ctx, removeRay, time.Second*4)
    69  	link, _ := v.dispatcher.Dispatch(ctx, dest)
    70  	entry := &connEntry{
    71  		link:   link,
    72  		timer:  timer,
    73  		cancel: removeRay,
    74  	}
    75  	v.conns[dest] = entry
    76  	go handleInput(ctx, entry, dest, v.callback)
    77  	return entry
    78  }
    79  
    80  func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) {
    81  	// TODO: Add user to destString
    82  	newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx))
    83  
    84  	conn := v.getInboundRay(ctx, destination)
    85  	outputStream := conn.link.Writer
    86  	if outputStream != nil {
    87  		if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil {
    88  			newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx))
    89  			conn.cancel()
    90  			return
    91  		}
    92  	}
    93  }
    94  
    95  func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback) {
    96  	defer conn.cancel()
    97  
    98  	input := conn.link.Reader
    99  	timer := conn.timer
   100  
   101  	for {
   102  		select {
   103  		case <-ctx.Done():
   104  			return
   105  		default:
   106  		}
   107  
   108  		mb, err := input.ReadMultiBuffer()
   109  		if err != nil {
   110  			newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx))
   111  			return
   112  		}
   113  		timer.Update()
   114  		for _, b := range mb {
   115  			callback(ctx, &udp.Packet{
   116  				Payload: b,
   117  				Source:  dest,
   118  			})
   119  		}
   120  	}
   121  }
   122  
   123  type dispatcherConn struct {
   124  	dispatcher *Dispatcher
   125  	cache      chan *udp.Packet
   126  	done       *done.Instance
   127  }
   128  
   129  func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) {
   130  	c := &dispatcherConn{
   131  		cache: make(chan *udp.Packet, 16),
   132  		done:  done.New(),
   133  	}
   134  
   135  	d := NewDispatcher(dispatcher, c.callback)
   136  	c.dispatcher = d
   137  	return c, nil
   138  }
   139  
   140  func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) {
   141  	select {
   142  	case <-c.done.Wait():
   143  		packet.Payload.Release()
   144  		return
   145  	case c.cache <- packet:
   146  	default:
   147  		packet.Payload.Release()
   148  		return
   149  	}
   150  }
   151  
   152  func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) {
   153  	select {
   154  	case <-c.done.Wait():
   155  		return 0, nil, io.EOF
   156  	case packet := <-c.cache:
   157  		n := copy(p, packet.Payload.Bytes())
   158  		return n, &net.UDPAddr{
   159  			IP:   packet.Source.Address.IP(),
   160  			Port: int(packet.Source.Port),
   161  		}, nil
   162  	}
   163  }
   164  
   165  func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) {
   166  	buffer := buf.New()
   167  	raw := buffer.Extend(buf.Size)
   168  	n := copy(raw, p)
   169  	buffer.Resize(0, int32(n))
   170  
   171  	ctx := context.Background()
   172  	c.dispatcher.Dispatch(ctx, net.DestinationFromAddr(addr), buffer)
   173  	return n, nil
   174  }
   175  
   176  func (c *dispatcherConn) Close() error {
   177  	return c.done.Close()
   178  }
   179  
   180  func (c *dispatcherConn) LocalAddr() net.Addr {
   181  	return &net.UDPAddr{
   182  		IP:   []byte{0, 0, 0, 0},
   183  		Port: 0,
   184  	}
   185  }
   186  
   187  func (c *dispatcherConn) SetDeadline(t time.Time) error {
   188  	return nil
   189  }
   190  
   191  func (c *dispatcherConn) SetReadDeadline(t time.Time) error {
   192  	return nil
   193  }
   194  
   195  func (c *dispatcherConn) SetWriteDeadline(t time.Time) error {
   196  	return nil
   197  }