github.com/sagernet/sing@v0.2.6/common/udpnat/service.go (about) 1 package udpnat 2 3 import ( 4 "context" 5 "io" 6 "net" 7 "os" 8 "time" 9 10 "github.com/sagernet/sing/common" 11 "github.com/sagernet/sing/common/buf" 12 "github.com/sagernet/sing/common/cache" 13 E "github.com/sagernet/sing/common/exceptions" 14 M "github.com/sagernet/sing/common/metadata" 15 N "github.com/sagernet/sing/common/network" 16 ) 17 18 type Handler interface { 19 N.UDPConnectionHandler 20 E.Handler 21 } 22 23 type Service[K comparable] struct { 24 nat *cache.LruCache[K, *conn] 25 handler Handler 26 } 27 28 func New[K comparable](maxAge int64, handler Handler) *Service[K] { 29 return &Service[K]{ 30 nat: cache.New( 31 cache.WithAge[K, *conn](maxAge), 32 cache.WithUpdateAgeOnGet[K, *conn](), 33 cache.WithEvict[K, *conn](func(key K, conn *conn) { 34 conn.Close() 35 }), 36 ), 37 handler: handler, 38 } 39 } 40 41 func (s *Service[T]) WriteIsThreadUnsafe() { 42 } 43 44 func (s *Service[T]) NewPacketDirect(ctx context.Context, key T, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) { 45 s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) { 46 return ctx, &DirectBackWriter{conn, natConn} 47 }) 48 } 49 50 type DirectBackWriter struct { 51 Source N.PacketConn 52 Nat N.PacketConn 53 } 54 55 func (w *DirectBackWriter) WritePacket(buffer *buf.Buffer, addr M.Socksaddr) error { 56 return w.Source.WritePacket(buffer, M.SocksaddrFromNet(w.Nat.LocalAddr())) 57 } 58 59 func (w *DirectBackWriter) Upstream() any { 60 return w.Source 61 } 62 63 func (s *Service[T]) NewPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) N.PacketWriter) { 64 s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) { 65 return ctx, init(natConn) 66 }) 67 } 68 69 func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) { 70 c, loaded := s.nat.LoadOrStore(key, func() *conn { 71 c := &conn{ 72 data: make(chan packet, 64), 73 localAddr: metadata.Source, 74 remoteAddr: metadata.Destination, 75 } 76 c.ctx, c.cancel = common.ContextWithCancelCause(ctx) 77 return c 78 }) 79 if !loaded { 80 ctx, c.source = init(c) 81 go func() { 82 err := s.handler.NewPacketConnection(ctx, c, metadata) 83 if err != nil { 84 s.handler.NewError(ctx, err) 85 } 86 c.Close() 87 s.nat.Delete(key) 88 }() 89 } else { 90 c.localAddr = metadata.Source 91 } 92 if common.Done(c.ctx) { 93 s.nat.Delete(key) 94 if !common.Done(ctx) { 95 s.NewContextPacket(ctx, key, buffer, metadata, init) 96 } 97 return 98 } 99 c.data <- packet{ 100 data: buffer, 101 destination: metadata.Destination, 102 } 103 } 104 105 type packet struct { 106 data *buf.Buffer 107 destination M.Socksaddr 108 } 109 110 type conn struct { 111 ctx context.Context 112 cancel common.ContextCancelCauseFunc 113 data chan packet 114 localAddr M.Socksaddr 115 remoteAddr M.Socksaddr 116 source N.PacketWriter 117 } 118 119 func (c *conn) ReadPacketThreadSafe() (buffer *buf.Buffer, addr M.Socksaddr, err error) { 120 select { 121 case p := <-c.data: 122 return p.data, p.destination, nil 123 case <-c.ctx.Done(): 124 return nil, M.Socksaddr{}, io.ErrClosedPipe 125 } 126 } 127 128 func (c *conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { 129 select { 130 case p := <-c.data: 131 _, err = buffer.ReadOnceFrom(p.data) 132 p.data.Release() 133 return p.destination, err 134 case <-c.ctx.Done(): 135 return M.Socksaddr{}, io.ErrClosedPipe 136 } 137 } 138 139 func (c *conn) WaitReadPacket(newBuffer func() *buf.Buffer) (destination M.Socksaddr, err error) { 140 select { 141 case p := <-c.data: 142 _, err = newBuffer().ReadOnceFrom(p.data) 143 p.data.Release() 144 return p.destination, err 145 case <-c.ctx.Done(): 146 return M.Socksaddr{}, io.ErrClosedPipe 147 } 148 } 149 150 func (c *conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 151 return c.source.WritePacket(buffer, destination) 152 } 153 154 func (c *conn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 155 select { 156 case pkt := <-c.data: 157 n = copy(p, pkt.data.Bytes()) 158 pkt.data.Release() 159 addr = pkt.destination.UDPAddr() 160 return n, addr, nil 161 case <-c.ctx.Done(): 162 return 0, nil, io.ErrClosedPipe 163 } 164 } 165 166 func (c *conn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 167 return len(p), c.source.WritePacket(buf.As(p).ToOwned(), M.SocksaddrFromNet(addr)) 168 } 169 170 func (c *conn) Close() error { 171 select { 172 case <-c.ctx.Done(): 173 default: 174 c.cancel(net.ErrClosed) 175 } 176 if sourceCloser, sourceIsCloser := c.source.(io.Closer); sourceIsCloser { 177 return sourceCloser.Close() 178 } 179 return nil 180 } 181 182 func (c *conn) LocalAddr() net.Addr { 183 return c.localAddr 184 } 185 186 func (c *conn) RemoteAddr() net.Addr { 187 return c.remoteAddr 188 } 189 190 func (c *conn) SetDeadline(t time.Time) error { 191 return os.ErrInvalid 192 } 193 194 func (c *conn) SetReadDeadline(t time.Time) error { 195 return os.ErrInvalid 196 } 197 198 func (c *conn) SetWriteDeadline(t time.Time) error { 199 return os.ErrInvalid 200 } 201 202 func (c *conn) NeedAdditionalReadDeadline() bool { 203 return true 204 } 205 206 func (c *conn) Upstream() any { 207 return c.source 208 }