github.com/xraypb/xray-core@v1.6.6/transport/internet/udp/dispatcher.go (about) 1 package udp 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 "sync" 8 "time" 9 10 "github.com/xraypb/xray-core/common" 11 "github.com/xraypb/xray-core/common/buf" 12 "github.com/xraypb/xray-core/common/net" 13 "github.com/xraypb/xray-core/common/protocol/udp" 14 "github.com/xraypb/xray-core/common/session" 15 "github.com/xraypb/xray-core/common/signal" 16 "github.com/xraypb/xray-core/common/signal/done" 17 "github.com/xraypb/xray-core/features/routing" 18 "github.com/xraypb/xray-core/transport" 19 ) 20 21 type ResponseCallback func(ctx context.Context, packet *udp.Packet) 22 23 type connEntry struct { 24 link *transport.Link 25 timer signal.ActivityUpdater 26 cancel context.CancelFunc 27 } 28 29 type Dispatcher struct { 30 sync.RWMutex 31 conns map[net.Destination]*connEntry 32 dispatcher routing.Dispatcher 33 callback ResponseCallback 34 } 35 36 func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher { 37 return &Dispatcher{ 38 conns: make(map[net.Destination]*connEntry), 39 dispatcher: dispatcher, 40 callback: callback, 41 } 42 } 43 44 func (v *Dispatcher) RemoveRay(dest net.Destination) { 45 v.Lock() 46 defer v.Unlock() 47 if conn, found := v.conns[dest]; found { 48 common.Close(conn.link.Reader) 49 common.Close(conn.link.Writer) 50 delete(v.conns, dest) 51 } 52 } 53 54 func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) *connEntry { 55 v.Lock() 56 defer v.Unlock() 57 58 if entry, found := v.conns[dest]; found { 59 return entry 60 } 61 62 newError("establishing new connection for ", dest).WriteToLog() 63 64 ctx, cancel := context.WithCancel(ctx) 65 removeRay := func() { 66 cancel() 67 v.RemoveRay(dest) 68 } 69 timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute) 70 link, _ := v.dispatcher.Dispatch(ctx, dest) 71 entry := &connEntry{ 72 link: link, 73 timer: timer, 74 cancel: removeRay, 75 } 76 v.conns[dest] = entry 77 go handleInput(ctx, entry, dest, v.callback) 78 return entry 79 } 80 81 func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) { 82 // TODO: Add user to destString 83 newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx)) 84 85 conn := v.getInboundRay(ctx, destination) 86 outputStream := conn.link.Writer 87 if outputStream != nil { 88 if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil { 89 newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx)) 90 conn.cancel() 91 return 92 } 93 } 94 } 95 96 func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback) { 97 defer conn.cancel() 98 99 input := conn.link.Reader 100 timer := conn.timer 101 102 for { 103 select { 104 case <-ctx.Done(): 105 return 106 default: 107 } 108 109 mb, err := input.ReadMultiBuffer() 110 if err != nil { 111 if !errors.Is(err, io.EOF) { 112 newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx)) 113 } 114 return 115 } 116 timer.Update() 117 for _, b := range mb { 118 callback(ctx, &udp.Packet{ 119 Payload: b, 120 Source: dest, 121 }) 122 } 123 } 124 } 125 126 type dispatcherConn struct { 127 dispatcher *Dispatcher 128 cache chan *udp.Packet 129 done *done.Instance 130 } 131 132 func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) { 133 c := &dispatcherConn{ 134 cache: make(chan *udp.Packet, 16), 135 done: done.New(), 136 } 137 138 d := NewDispatcher(dispatcher, c.callback) 139 c.dispatcher = d 140 return c, nil 141 } 142 143 func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) { 144 select { 145 case <-c.done.Wait(): 146 packet.Payload.Release() 147 return 148 case c.cache <- packet: 149 default: 150 packet.Payload.Release() 151 return 152 } 153 } 154 155 func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) { 156 select { 157 case <-c.done.Wait(): 158 return 0, nil, io.EOF 159 case packet := <-c.cache: 160 n := copy(p, packet.Payload.Bytes()) 161 return n, &net.UDPAddr{ 162 IP: packet.Source.Address.IP(), 163 Port: int(packet.Source.Port), 164 }, nil 165 } 166 } 167 168 func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) { 169 buffer := buf.New() 170 raw := buffer.Extend(buf.Size) 171 n := copy(raw, p) 172 buffer.Resize(0, int32(n)) 173 174 ctx := context.Background() 175 c.dispatcher.Dispatch(ctx, net.DestinationFromAddr(addr), buffer) 176 return n, nil 177 } 178 179 func (c *dispatcherConn) Close() error { 180 return c.done.Close() 181 } 182 183 func (c *dispatcherConn) LocalAddr() net.Addr { 184 return &net.UDPAddr{ 185 IP: []byte{0, 0, 0, 0}, 186 Port: 0, 187 } 188 } 189 190 func (c *dispatcherConn) SetDeadline(t time.Time) error { 191 return nil 192 } 193 194 func (c *dispatcherConn) SetReadDeadline(t time.Time) error { 195 return nil 196 } 197 198 func (c *dispatcherConn) SetWriteDeadline(t time.Time) error { 199 return nil 200 }