github.com/sagernet/sing@v0.4.0-beta.19.0.20240518125136-f67a0988a636/common/udpnat/service.go (about) 1 package udpnat 2 3 import ( 4 "context" 5 "io" 6 "net" 7 "os" 8 "time" 9 10 "github.com/sagernet/sing/common" 11 "github.com/sagernet/sing/common/buf" 12 "github.com/sagernet/sing/common/cache" 13 E "github.com/sagernet/sing/common/exceptions" 14 M "github.com/sagernet/sing/common/metadata" 15 N "github.com/sagernet/sing/common/network" 16 ) 17 18 type Handler interface { 19 N.UDPConnectionHandler 20 E.Handler 21 } 22 23 type Service[K comparable] struct { 24 nat *cache.LruCache[K, *conn] 25 handler Handler 26 } 27 28 func New[K comparable](maxAge int64, handler Handler) *Service[K] { 29 return &Service[K]{ 30 nat: cache.New( 31 cache.WithAge[K, *conn](maxAge), 32 cache.WithUpdateAgeOnGet[K, *conn](), 33 cache.WithEvict[K, *conn](func(key K, conn *conn) { 34 conn.Close() 35 }), 36 ), 37 handler: handler, 38 } 39 } 40 41 func (s *Service[T]) WriteIsThreadUnsafe() { 42 } 43 44 func (s *Service[T]) NewPacketDirect(ctx context.Context, key T, conn N.PacketConn, buffer *buf.Buffer, metadata M.Metadata) { 45 s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) { 46 return ctx, &DirectBackWriter{conn, natConn} 47 }) 48 } 49 50 type DirectBackWriter struct { 51 Source N.PacketConn 52 Nat N.PacketConn 53 } 54 55 func (w *DirectBackWriter) WritePacket(buffer *buf.Buffer, addr M.Socksaddr) error { 56 return w.Source.WritePacket(buffer, M.SocksaddrFromNet(w.Nat.LocalAddr())) 57 } 58 59 func (w *DirectBackWriter) Upstream() any { 60 return w.Source 61 } 62 63 func (s *Service[T]) NewPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) N.PacketWriter) { 64 s.NewContextPacket(ctx, key, buffer, metadata, func(natConn N.PacketConn) (context.Context, N.PacketWriter) { 65 return ctx, init(natConn) 66 }) 67 } 68 69 func (s *Service[T]) NewContextPacket(ctx context.Context, key T, buffer *buf.Buffer, metadata M.Metadata, init func(natConn N.PacketConn) (context.Context, N.PacketWriter)) { 70 c, loaded := s.nat.LoadOrStore(key, func() *conn { 71 c := &conn{ 72 data: make(chan packet, 64), 73 localAddr: metadata.Source, 74 remoteAddr: metadata.Destination, 75 } 76 c.ctx, c.cancel = common.ContextWithCancelCause(ctx) 77 return c 78 }) 79 if !loaded { 80 ctx, c.source = init(c) 81 go func() { 82 err := s.handler.NewPacketConnection(ctx, c, metadata) 83 if err != nil { 84 s.handler.NewError(ctx, err) 85 } 86 c.Close() 87 s.nat.Delete(key) 88 }() 89 } else { 90 c.localAddr = metadata.Source 91 } 92 if common.Done(c.ctx) { 93 s.nat.Delete(key) 94 if !common.Done(ctx) { 95 s.NewContextPacket(ctx, key, buffer, metadata, init) 96 } 97 return 98 } 99 c.data <- packet{ 100 data: buffer, 101 destination: metadata.Destination, 102 } 103 } 104 105 type packet struct { 106 data *buf.Buffer 107 destination M.Socksaddr 108 } 109 110 var _ N.PacketConn = (*conn)(nil) 111 112 type conn struct { 113 ctx context.Context 114 cancel common.ContextCancelCauseFunc 115 data chan packet 116 localAddr M.Socksaddr 117 remoteAddr M.Socksaddr 118 source N.PacketWriter 119 readWaitOptions N.ReadWaitOptions 120 } 121 122 func (c *conn) ReadPacket(buffer *buf.Buffer) (addr M.Socksaddr, err error) { 123 select { 124 case p := <-c.data: 125 _, err = buffer.ReadOnceFrom(p.data) 126 p.data.Release() 127 return p.destination, err 128 case <-c.ctx.Done(): 129 return M.Socksaddr{}, io.ErrClosedPipe 130 } 131 } 132 133 func (c *conn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error { 134 return c.source.WritePacket(buffer, destination) 135 } 136 137 func (c *conn) Close() error { 138 select { 139 case <-c.ctx.Done(): 140 default: 141 c.cancel(net.ErrClosed) 142 } 143 if sourceCloser, sourceIsCloser := c.source.(io.Closer); sourceIsCloser { 144 return sourceCloser.Close() 145 } 146 return nil 147 } 148 149 func (c *conn) LocalAddr() net.Addr { 150 return c.localAddr 151 } 152 153 func (c *conn) RemoteAddr() net.Addr { 154 return c.remoteAddr 155 } 156 157 func (c *conn) SetDeadline(t time.Time) error { 158 return os.ErrInvalid 159 } 160 161 func (c *conn) SetReadDeadline(t time.Time) error { 162 return os.ErrInvalid 163 } 164 165 func (c *conn) SetWriteDeadline(t time.Time) error { 166 return os.ErrInvalid 167 } 168 169 func (c *conn) NeedAdditionalReadDeadline() bool { 170 return true 171 } 172 173 func (c *conn) Upstream() any { 174 return c.source 175 }