github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/transport/internet/udp/dispatcher.go (about) 1 package udp 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 "sync" 8 "time" 9 10 "github.com/xtls/xray-core/common" 11 "github.com/xtls/xray-core/common/buf" 12 "github.com/xtls/xray-core/common/net" 13 "github.com/xtls/xray-core/common/protocol/udp" 14 "github.com/xtls/xray-core/common/session" 15 "github.com/xtls/xray-core/common/signal" 16 "github.com/xtls/xray-core/common/signal/done" 17 "github.com/xtls/xray-core/features/routing" 18 "github.com/xtls/xray-core/transport" 19 ) 20 21 type ResponseCallback func(ctx context.Context, packet *udp.Packet) 22 23 type connEntry struct { 24 link *transport.Link 25 timer signal.ActivityUpdater 26 cancel context.CancelFunc 27 } 28 29 type Dispatcher struct { 30 sync.RWMutex 31 conn *connEntry 32 dispatcher routing.Dispatcher 33 callback ResponseCallback 34 callClose func() error 35 } 36 37 func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher { 38 return &Dispatcher{ 39 dispatcher: dispatcher, 40 callback: callback, 41 } 42 } 43 44 func (v *Dispatcher) RemoveRay() { 45 v.Lock() 46 defer v.Unlock() 47 if v.conn != nil { 48 common.Close(v.conn.link.Reader) 49 common.Close(v.conn.link.Writer) 50 v.conn = nil 51 } 52 } 53 54 func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*connEntry, error) { 55 v.Lock() 56 defer v.Unlock() 57 58 if v.conn != nil { 59 return v.conn, nil 60 } 61 62 newError("establishing new connection for ", dest).WriteToLog() 63 64 ctx, cancel := context.WithCancel(ctx) 65 removeRay := func() { 66 cancel() 67 v.RemoveRay() 68 } 69 timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute) 70 71 link, err := v.dispatcher.Dispatch(ctx, dest) 72 if err != nil { 73 return nil, newError("failed to dispatch request to ", dest).Base(err) 74 } 75 76 entry := &connEntry{ 77 link: link, 78 timer: timer, 79 cancel: removeRay, 80 } 81 v.conn = entry 82 go handleInput(ctx, entry, dest, v.callback, v.callClose) 83 return entry, nil 84 } 85 86 func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) { 87 // TODO: Add user to destString 88 newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx)) 89 90 conn, err := v.getInboundRay(ctx, destination) 91 if err != nil { 92 newError("failed to get inbound").Base(err).WriteToLog(session.ExportIDToError(ctx)) 93 return 94 } 95 outputStream := conn.link.Writer 96 if outputStream != nil { 97 if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil { 98 newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx)) 99 conn.cancel() 100 return 101 } 102 } 103 } 104 105 func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) { 106 defer func() { 107 conn.cancel() 108 if callClose != nil { 109 callClose() 110 } 111 }() 112 113 input := conn.link.Reader 114 timer := conn.timer 115 116 for { 117 select { 118 case <-ctx.Done(): 119 return 120 default: 121 } 122 123 mb, err := input.ReadMultiBuffer() 124 if err != nil { 125 if !errors.Is(err, io.EOF) { 126 newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx)) 127 } 128 return 129 } 130 timer.Update() 131 for _, b := range mb { 132 if b.UDP != nil { 133 dest = *b.UDP 134 } 135 callback(ctx, &udp.Packet{ 136 Payload: b, 137 Source: dest, 138 }) 139 } 140 } 141 } 142 143 type dispatcherConn struct { 144 dispatcher *Dispatcher 145 cache chan *udp.Packet 146 done *done.Instance 147 ctx context.Context 148 } 149 150 func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) { 151 c := &dispatcherConn{ 152 cache: make(chan *udp.Packet, 16), 153 done: done.New(), 154 ctx: ctx, 155 } 156 157 d := &Dispatcher{ 158 dispatcher: dispatcher, 159 callback: c.callback, 160 callClose: c.Close, 161 } 162 c.dispatcher = d 163 return c, nil 164 } 165 166 func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) { 167 select { 168 case <-c.done.Wait(): 169 packet.Payload.Release() 170 return 171 case c.cache <- packet: 172 default: 173 packet.Payload.Release() 174 return 175 } 176 } 177 178 func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) { 179 var packet *udp.Packet 180 s: 181 select { 182 case <-c.done.Wait(): 183 select { 184 case packet = <-c.cache: 185 break s 186 default: 187 return 0, nil, io.EOF 188 } 189 case packet = <-c.cache: 190 } 191 return copy(p, packet.Payload.Bytes()), &net.UDPAddr{ 192 IP: packet.Source.Address.IP(), 193 Port: int(packet.Source.Port), 194 }, nil 195 } 196 197 func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) { 198 buffer := buf.New() 199 raw := buffer.Extend(buf.Size) 200 n := copy(raw, p) 201 buffer.Resize(0, int32(n)) 202 203 destination := net.DestinationFromAddr(addr) 204 buffer.UDP = &destination 205 c.dispatcher.Dispatch(c.ctx, destination, buffer) 206 return n, nil 207 } 208 209 func (c *dispatcherConn) Close() error { 210 return c.done.Close() 211 } 212 213 func (c *dispatcherConn) LocalAddr() net.Addr { 214 return &net.UDPAddr{ 215 IP: []byte{0, 0, 0, 0}, 216 Port: 0, 217 } 218 } 219 220 func (c *dispatcherConn) SetDeadline(t time.Time) error { 221 return nil 222 } 223 224 func (c *dispatcherConn) SetReadDeadline(t time.Time) error { 225 return nil 226 } 227 228 func (c *dispatcherConn) SetWriteDeadline(t time.Time) error { 229 return nil 230 }