github.com/moqsien/xraycore@v1.8.5/transport/internet/udp/dispatcher.go (about) 1 package udp 2 3 import ( 4 "context" 5 "errors" 6 "io" 7 "sync" 8 "time" 9 10 "github.com/moqsien/xraycore/common" 11 "github.com/moqsien/xraycore/common/buf" 12 "github.com/moqsien/xraycore/common/net" 13 "github.com/moqsien/xraycore/common/protocol/udp" 14 "github.com/moqsien/xraycore/common/session" 15 "github.com/moqsien/xraycore/common/signal" 16 "github.com/moqsien/xraycore/common/signal/done" 17 "github.com/moqsien/xraycore/features/routing" 18 "github.com/moqsien/xraycore/transport" 19 ) 20 21 type ResponseCallback func(ctx context.Context, packet *udp.Packet) 22 23 type connEntry struct { 24 link *transport.Link 25 timer signal.ActivityUpdater 26 cancel context.CancelFunc 27 } 28 29 type Dispatcher struct { 30 sync.RWMutex 31 conns map[net.Destination]*connEntry 32 dispatcher routing.Dispatcher 33 callback ResponseCallback 34 callClose func() error 35 } 36 37 func NewDispatcher(dispatcher routing.Dispatcher, callback ResponseCallback) *Dispatcher { 38 return &Dispatcher{ 39 conns: make(map[net.Destination]*connEntry), 40 dispatcher: dispatcher, 41 callback: callback, 42 } 43 } 44 45 func (v *Dispatcher) RemoveRay(dest net.Destination) { 46 v.Lock() 47 defer v.Unlock() 48 if conn, found := v.conns[dest]; found { 49 common.Close(conn.link.Reader) 50 common.Close(conn.link.Writer) 51 delete(v.conns, dest) 52 } 53 } 54 55 func (v *Dispatcher) getInboundRay(ctx context.Context, dest net.Destination) (*connEntry, error) { 56 v.Lock() 57 defer v.Unlock() 58 59 if entry, found := v.conns[dest]; found { 60 return entry, nil 61 } 62 63 newError("establishing new connection for ", dest).WriteToLog() 64 65 ctx, cancel := context.WithCancel(ctx) 66 removeRay := func() { 67 cancel() 68 v.RemoveRay(dest) 69 } 70 timer := signal.CancelAfterInactivity(ctx, removeRay, time.Minute) 71 72 link, err := v.dispatcher.Dispatch(ctx, dest) 73 if err != nil { 74 return nil, newError("failed to dispatch request to ", dest).Base(err) 75 } 76 77 entry := &connEntry{ 78 link: link, 79 timer: timer, 80 cancel: removeRay, 81 } 82 v.conns[dest] = entry 83 go handleInput(ctx, entry, dest, v.callback, v.callClose) 84 return entry, nil 85 } 86 87 func (v *Dispatcher) Dispatch(ctx context.Context, destination net.Destination, payload *buf.Buffer) { 88 // TODO: Add user to destString 89 newError("dispatch request to: ", destination).AtDebug().WriteToLog(session.ExportIDToError(ctx)) 90 91 conn, err := v.getInboundRay(ctx, destination) 92 if err != nil { 93 newError("failed to get inbound").Base(err).WriteToLog(session.ExportIDToError(ctx)) 94 return 95 } 96 outputStream := conn.link.Writer 97 if outputStream != nil { 98 if err := outputStream.WriteMultiBuffer(buf.MultiBuffer{payload}); err != nil { 99 newError("failed to write first UDP payload").Base(err).WriteToLog(session.ExportIDToError(ctx)) 100 conn.cancel() 101 return 102 } 103 } 104 } 105 106 func handleInput(ctx context.Context, conn *connEntry, dest net.Destination, callback ResponseCallback, callClose func() error) { 107 defer func() { 108 conn.cancel() 109 if callClose != nil { 110 callClose() 111 } 112 }() 113 114 input := conn.link.Reader 115 timer := conn.timer 116 117 for { 118 select { 119 case <-ctx.Done(): 120 return 121 default: 122 } 123 124 mb, err := input.ReadMultiBuffer() 125 if err != nil { 126 if !errors.Is(err, io.EOF) { 127 newError("failed to handle UDP input").Base(err).WriteToLog(session.ExportIDToError(ctx)) 128 } 129 return 130 } 131 timer.Update() 132 for _, b := range mb { 133 callback(ctx, &udp.Packet{ 134 Payload: b, 135 Source: dest, 136 }) 137 } 138 } 139 } 140 141 type dispatcherConn struct { 142 dispatcher *Dispatcher 143 cache chan *udp.Packet 144 done *done.Instance 145 ctx context.Context 146 } 147 148 func DialDispatcher(ctx context.Context, dispatcher routing.Dispatcher) (net.PacketConn, error) { 149 c := &dispatcherConn{ 150 cache: make(chan *udp.Packet, 16), 151 done: done.New(), 152 ctx: ctx, 153 } 154 155 d := &Dispatcher{ 156 conns: make(map[net.Destination]*connEntry), 157 dispatcher: dispatcher, 158 callback: c.callback, 159 callClose: c.Close, 160 } 161 c.dispatcher = d 162 return c, nil 163 } 164 165 func (c *dispatcherConn) callback(ctx context.Context, packet *udp.Packet) { 166 select { 167 case <-c.done.Wait(): 168 packet.Payload.Release() 169 return 170 case c.cache <- packet: 171 default: 172 packet.Payload.Release() 173 return 174 } 175 } 176 177 func (c *dispatcherConn) ReadFrom(p []byte) (int, net.Addr, error) { 178 var packet *udp.Packet 179 s: 180 select { 181 case <-c.done.Wait(): 182 select { 183 case packet = <-c.cache: 184 break s 185 default: 186 return 0, nil, io.EOF 187 } 188 case packet = <-c.cache: 189 } 190 return copy(p, packet.Payload.Bytes()), &net.UDPAddr{ 191 IP: packet.Source.Address.IP(), 192 Port: int(packet.Source.Port), 193 }, nil 194 } 195 196 func (c *dispatcherConn) WriteTo(p []byte, addr net.Addr) (int, error) { 197 buffer := buf.New() 198 raw := buffer.Extend(buf.Size) 199 n := copy(raw, p) 200 buffer.Resize(0, int32(n)) 201 202 c.dispatcher.Dispatch(c.ctx, net.DestinationFromAddr(addr), buffer) 203 return n, nil 204 } 205 206 func (c *dispatcherConn) Close() error { 207 return c.done.Close() 208 } 209 210 func (c *dispatcherConn) LocalAddr() net.Addr { 211 return &net.UDPAddr{ 212 IP: []byte{0, 0, 0, 0}, 213 Port: 0, 214 } 215 } 216 217 func (c *dispatcherConn) SetDeadline(t time.Time) error { 218 return nil 219 } 220 221 func (c *dispatcherConn) SetReadDeadline(t time.Time) error { 222 return nil 223 } 224 225 func (c *dispatcherConn) SetWriteDeadline(t time.Time) error { 226 return nil 227 }