github.com/xraypb/Xray-core@v1.8.1/transport/internet/udp/dispatcher.go (about) 1 package udp 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 "sync" 8 "time" 9 10 "github.com/xraypb/Xray-core/common" 11 "github.com/xraypb/Xray-core/common/buf" 12 "github.com/xraypb/Xray-core/common/net" 13 "github.com/xraypb/Xray-core/common/protocol/udp" 14 "github.com/xraypb/Xray-core/common/session" 15 "github.com/xraypb/Xray-core/common/signal" 16 "github.com/xraypb/Xray-core/common/signal/done" 17 "github.com/xraypb/Xray-core/features/routing" 18 "github.com/xraypb/Xray-core/transport" 19 ) 20 21 type ResponseCallback func(ctx context.Context, packet *udp.Packet) 22 23 type connEntry struct { 24 link *transport.Link 25 timer signal.ActivityUpdater 26 cancel context.CancelFunc 27 } 28 29 type Dispatcher struct { 30 sync.RWMutex 31 conns map[net.Destination]*connEntry 32 dispatcher routing.Dispatcher 33 callback ResponseCallback 34 callClose func() error 35 } 36 37 func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher { 38 return &Dispatcher{ 39 conns: make(map[net.Destination]*connEntry), 40 dispatcher: dispatcher, 41 callback: callback, 42 } 43 } 44 45 func (v *Dispatcher) RemoveRay(dest net.Destination) { 46 v.Lock() 47 defer v.Unlock() 48 if conn, found := v.conns[dest]; found { 49 common.Close(conn.link.Reader) 50 common.Close(conn.link.Writer) 51 delete(v.conns, dest) 52 } 53 } 54 55 func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*connEntry, error) { 56 v.Lock() 57 defer v.Unlock() 58 59 if entry, found := v.conns[dest]; found { 60 return entry, nil 61 } 62 63 newError("establishing new connection for ", dest).WriteToLog() 64 65 ctx, cancel := context.WithCancel(ctx) 66 removeRay := func() { 67 cancel() 68 v.RemoveRay(dest) 69 } 70 timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute) 71 72 link, err := v.dispatcher.Dispatch(ctx, dest) 73 if err != nil { 74 return nil, newError("failed to dispatch request to ", dest).Base(err) 75 } 76 77 entry := &connEntry{ 78 link: link, 79 timer: timer, 80 cancel: removeRay, 81 } 82 v.conns[dest] = entry 83 go handleInput(ctx, entry, dest, v.callback, v.callClose) 84 return entry, nil 85 } 86 87 func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) { 88 // TODO: Add user to destString 89 newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx)) 90 91 conn, err := v.getInboundRay(ctx, destination) 92 if err != nil { 93 newError("failed to get inbound").Base(err).WriteToLog(session.ExportIDToError(ctx)) 94 return 95 } 96 outputStream := conn.link.Writer 97 if outputStream != nil { 98 if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil { 99 newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx)) 100 conn.cancel() 101 return 102 } 103 } 104 } 105 106 func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) { 107 defer func() { 108 conn.cancel() 109 if callClose != nil { 110 callClose() 111 } 112 }() 113 114 input := conn.link.Reader 115 timer := conn.timer 116 117 for { 118 select { 119 case <-ctx.Done(): 120 return 121 default: 122 } 123 124 mb, err := input.ReadMultiBuffer() 125 if err != nil { 126 if !errors.Is(err, io.EOF) { 127 newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx)) 128 } 129 return 130 } 131 timer.Update() 132 for _, b := range mb { 133 callback(ctx, &udp.Packet{ 134 Payload: b, 135 Source: dest, 136 }) 137 } 138 } 139 } 140 141 type dispatcherConn struct { 142 dispatcher *Dispatcher 143 cache chan *udp.Packet 144 done *done.Instance 145 } 146 147 func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) { 148 c := &dispatcherConn{ 149 cache: make(chan *udp.Packet, 16), 150 done: done.New(), 151 } 152 153 d := &Dispatcher{ 154 conns: make(map[net.Destination]*connEntry), 155 dispatcher: dispatcher, 156 callback: c.callback, 157 callClose: c.Close, 158 } 159 c.dispatcher = d 160 return c, nil 161 } 162 163 func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) { 164 select { 165 case <-c.done.Wait(): 166 packet.Payload.Release() 167 return 168 case c.cache <- packet: 169 default: 170 packet.Payload.Release() 171 return 172 } 173 } 174 175 func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) { 176 var packet *udp.Packet 177 s: 178 select { 179 case <-c.done.Wait(): 180 select { 181 case packet = <-c.cache: 182 break s 183 default: 184 return 0, nil, io.EOF 185 } 186 case packet = <-c.cache: 187 } 188 return copy(p, packet.Payload.Bytes()), &net.UDPAddr{ 189 IP: packet.Source.Address.IP(), 190 Port: int(packet.Source.Port), 191 }, nil 192 } 193 194 func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) { 195 buffer := buf.New() 196 raw := buffer.Extend(buf.Size) 197 n := copy(raw, p) 198 buffer.Resize(0, int32(n)) 199 200 ctx := context.Background() 201 c.dispatcher.Dispatch(ctx, net.DestinationFromAddr(addr), buffer) 202 return n, nil 203 } 204 205 func (c *dispatcherConn) Close() error { 206 return c.done.Close() 207 } 208 209 func (c *dispatcherConn) LocalAddr() net.Addr { 210 return &net.UDPAddr{ 211 IP: []byte{0, 0, 0, 0}, 212 Port: 0, 213 } 214 } 215 216 func (c *dispatcherConn) SetDeadline(t time.Time) error { 217 return nil 218 } 219 220 func (c *dispatcherConn) SetReadDeadline(t time.Time) error { 221 return nil 222 } 223 224 func (c *dispatcherConn) SetWriteDeadline(t time.Time) error { 225 return nil 226 }