github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/quic/server.go (about) 1 package quic 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 "net" 9 "time" 10 11 "github.com/Asutorufa/yuhaiin/pkg/log" 12 "github.com/Asutorufa/yuhaiin/pkg/net/deadline" 13 "github.com/Asutorufa/yuhaiin/pkg/net/dialer" 14 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 15 "github.com/Asutorufa/yuhaiin/pkg/protos/config/listener" 16 "github.com/Asutorufa/yuhaiin/pkg/utils/pool" 17 "github.com/Asutorufa/yuhaiin/pkg/utils/syncmap" 18 "github.com/quic-go/quic-go" 19 ) 20 21 type Server struct { 22 packetConn net.PacketConn 23 *quic.Listener 24 25 ctx context.Context 26 cancel context.CancelFunc 27 connChan chan *interConn 28 29 packetChan chan serverMsg 30 natMap syncmap.SyncMap[string, *ConnectionPacketConn] 31 } 32 33 func init() { 34 listener.RegisterNetwork(NewServer) 35 } 36 37 func NewServer(c *listener.Inbound_Quic) (netapi.Listener, error) { 38 packetConn, err := dialer.ListenPacket("udp", c.Quic.Host) 39 if err != nil { 40 return nil, err 41 } 42 43 tlsConfig, err := listener.ParseTLS(c.Quic.Tls) 44 if err != nil { 45 return nil, err 46 } 47 48 return newServer(packetConn, tlsConfig) 49 } 50 51 func newServer(packetConn net.PacketConn, tlsConfig *tls.Config) (*Server, error) { 52 tr := quic.Transport{ 53 Conn: packetConn, 54 ConnectionIDLength: 12, 55 } 56 57 config := &quic.Config{ 58 MaxIncomingStreams: 1 << 60, 59 KeepAlivePeriod: 0, 60 MaxIdleTimeout: 3 * time.Minute, 61 EnableDatagrams: true, 62 Allow0RTT: true, 63 MaxIncomingUniStreams: -1, 64 } 65 66 lis, err := tr.Listen(tlsConfig, config) 67 if err != nil { 68 return nil, err 69 } 70 71 ctx, cancel := context.WithCancel(context.Background()) 72 73 s := &Server{ 74 packetConn: packetConn, 75 ctx: ctx, 76 cancel: cancel, 77 connChan: make(chan *interConn, 100), 78 packetChan: make(chan serverMsg, 100), 79 Listener: lis, 80 } 81 82 go func() { 83 defer s.Close() 84 if err := s.server(); err != nil { 85 log.Error("quic server failed:", "err", err) 86 } 87 }() 88 89 return s, nil 90 } 91 92 func (s *Server) Close() error { 93 var err error 94 95 s.cancel() 96 if s.Listener != nil { 97 if er := s.Listener.Close(); er != nil { 98 err = errors.Join(err, er) 99 } 100 } 101 if s.packetConn != nil { 102 if er := s.packetConn.Close(); er != nil { 103 err = errors.Join(err, er) 104 } 105 } 106 107 return err 108 } 109 110 func (s *Server) Accept() (net.Conn, error) { 111 select { 112 case conn := <-s.connChan: 113 return conn, nil 114 case <-s.ctx.Done(): 115 return nil, s.ctx.Err() 116 } 117 } 118 119 func (s *Server) Packet(context.Context) (net.PacketConn, error) { 120 return newServerPacketConn(s), nil 121 } 122 123 func (s *Server) Stream(ctx context.Context) (net.Listener, error) { 124 return s, nil 125 } 126 127 func (s *Server) server() error { 128 for { 129 conn, err := s.Listener.Accept(s.ctx) 130 if err != nil { 131 return err 132 } 133 134 go func() { 135 defer conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") // nolint:errcheck 136 137 go func() { 138 if err := s.listenDatagram(conn); err != nil { 139 log.Error("listen datagram failed:", "err", err) 140 } 141 }() 142 143 if err := s.listenStream(conn); err != nil { 144 log.Error("listen quic connection failed:", "err", err) 145 } 146 }() 147 } 148 } 149 150 func (s *Server) listenDatagram(conn quic.Connection) error { 151 raddr := conn.RemoteAddr() 152 153 packetConn := NewConnectionPacketConn(conn) 154 155 s.natMap.Store(raddr.String(), packetConn) 156 defer s.natMap.Delete(raddr.String()) 157 158 for { 159 id, data, err := packetConn.Receive(s.ctx) 160 if err != nil { 161 return err 162 } 163 164 select { 165 case <-s.ctx.Done(): 166 return s.ctx.Err() 167 case s.packetChan <- serverMsg{msg: data, src: raddr, id: id}: 168 } 169 } 170 } 171 func (s *Server) listenStream(conn quic.Connection) error { 172 for { 173 stream, err := conn.AcceptStream(s.ctx) 174 if err != nil { 175 return err 176 } 177 178 select { 179 case <-s.ctx.Done(): 180 return s.ctx.Err() 181 case s.connChan <- &interConn{ 182 Stream: stream, 183 session: conn, 184 }: 185 } 186 } 187 } 188 189 type serverMsg struct { 190 msg *pool.Buffer 191 src net.Addr 192 id uint64 193 } 194 type serverPacketConn struct { 195 *Server 196 197 ctx context.Context 198 cancel context.CancelFunc 199 200 deadline *deadline.PipeDeadline 201 } 202 203 func newServerPacketConn(s *Server) *serverPacketConn { 204 ctx, cancel := context.WithCancel(s.ctx) 205 return &serverPacketConn{ 206 Server: s, 207 ctx: ctx, 208 cancel: cancel, 209 deadline: deadline.NewPipe(), 210 } 211 } 212 213 func (x *serverPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { 214 select { 215 case <-x.Server.ctx.Done(): 216 x.cancel() 217 return 0, nil, x.Server.ctx.Err() 218 case <-x.ctx.Done(): 219 return 0, nil, x.ctx.Err() 220 case <-x.deadline.ReadContext().Done(): 221 return 0, nil, x.deadline.ReadContext().Err() 222 case msg := <-x.packetChan: 223 defer msg.msg.Free() 224 225 n = copy(p, msg.msg.Bytes()) 226 return n, &QuicAddr{Addr: msg.src, ID: quic.StreamID(msg.id)}, nil 227 } 228 } 229 230 func (x *serverPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { 231 select { 232 case <-x.Server.ctx.Done(): 233 return 0, x.Server.ctx.Err() 234 case <-x.ctx.Done(): 235 return 0, x.ctx.Err() 236 case <-x.deadline.WriteContext().Done(): 237 return 0, x.deadline.WriteContext().Err() 238 default: 239 } 240 241 qaddr, ok := addr.(*QuicAddr) 242 if !ok { 243 return 0, errors.New("invalid addr") 244 } 245 246 conn, ok := x.natMap.Load(qaddr.Addr.String()) 247 if !ok { 248 return 0, fmt.Errorf("no such addr: %s", addr.String()) 249 } 250 err = conn.Write(p, uint64(qaddr.ID)) 251 return len(p), err 252 } 253 254 func (x *serverPacketConn) LocalAddr() net.Addr { 255 return x.Addr() 256 } 257 258 func (x *serverPacketConn) SetDeadline(t time.Time) error { 259 select { 260 case <-x.Server.ctx.Done(): 261 return x.Server.ctx.Err() 262 case <-x.ctx.Done(): 263 return x.ctx.Err() 264 default: 265 } 266 267 x.deadline.SetDeadline(t) 268 return nil 269 } 270 271 func (x *serverPacketConn) SetReadDeadline(t time.Time) error { 272 x.deadline.SetReadDeadline(t) 273 return nil 274 } 275 276 func (x *serverPacketConn) SetWriteDeadline(t time.Time) error { 277 x.deadline.SetWriteDeadline(t) 278 return nil 279 } 280 281 func (x *serverPacketConn) Close() error { 282 x.cancel() 283 x.deadline.Close() 284 return x.Server.Close() 285 }