github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/quic/server.go (about)

     1  package quic
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"net"
     9  	"time"
    10  
    11  	"github.com/Asutorufa/yuhaiin/pkg/log"
    12  	"github.com/Asutorufa/yuhaiin/pkg/net/deadline"
    13  	"github.com/Asutorufa/yuhaiin/pkg/net/dialer"
    14  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    15  	"github.com/Asutorufa/yuhaiin/pkg/protos/config/listener"
    16  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    17  	"github.com/Asutorufa/yuhaiin/pkg/utils/syncmap"
    18  	"github.com/quic-go/quic-go"
    19  )
    20  
// Server is a QUIC inbound listener. It accepts QUIC connections on a single
// packet socket and exposes their traffic two ways: bidirectional streams via
// Accept (it is returned as the net.Listener from Stream) and datagrams via the
// net.PacketConn returned from Packet.
type Server struct {
	// packetConn is the packet socket the QUIC transport runs on; Close closes it.
	packetConn net.PacketConn
	// Embedded QUIC listener; its methods (e.g. Addr, used by LocalAddr) are
	// promoted to Server.
	*quic.Listener

	// ctx is cancelled by Close; the accept loop and the per-connection
	// goroutines select on it to stop.
	ctx      context.Context
	cancel   context.CancelFunc
	// connChan carries streams accepted by listenStream to Accept.
	connChan chan *interConn

	// packetChan carries datagrams from listenDatagram to serverPacketConn.ReadFrom.
	packetChan chan serverMsg
	// natMap maps a connection's remote address string to its datagram
	// wrapper so serverPacketConn.WriteTo can route replies.
	natMap     syncmap.SyncMap[string, *ConnectionPacketConn]
}
    32  
// init hands NewServer (which builds a listener from a *listener.Inbound_Quic)
// to listener.RegisterNetwork, presumably so QUIC inbound configs are
// constructed by this package — confirm against the listener package.
func init() {
	listener.RegisterNetwork(NewServer)
}
    36  
    37  func NewServer(c *listener.Inbound_Quic) (netapi.Listener, error) {
    38  	packetConn, err := dialer.ListenPacket("udp", c.Quic.Host)
    39  	if err != nil {
    40  		return nil, err
    41  	}
    42  
    43  	tlsConfig, err := listener.ParseTLS(c.Quic.Tls)
    44  	if err != nil {
    45  		return nil, err
    46  	}
    47  
    48  	return newServer(packetConn, tlsConfig)
    49  }
    50  
    51  func newServer(packetConn net.PacketConn, tlsConfig *tls.Config) (*Server, error) {
    52  	tr := quic.Transport{
    53  		Conn:               packetConn,
    54  		ConnectionIDLength: 12,
    55  	}
    56  
    57  	config := &quic.Config{
    58  		MaxIncomingStreams:    1 << 60,
    59  		KeepAlivePeriod:       0,
    60  		MaxIdleTimeout:        3 * time.Minute,
    61  		EnableDatagrams:       true,
    62  		Allow0RTT:             true,
    63  		MaxIncomingUniStreams: -1,
    64  	}
    65  
    66  	lis, err := tr.Listen(tlsConfig, config)
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  
    71  	ctx, cancel := context.WithCancel(context.Background())
    72  
    73  	s := &Server{
    74  		packetConn: packetConn,
    75  		ctx:        ctx,
    76  		cancel:     cancel,
    77  		connChan:   make(chan *interConn, 100),
    78  		packetChan: make(chan serverMsg, 100),
    79  		Listener:   lis,
    80  	}
    81  
    82  	go func() {
    83  		defer s.Close()
    84  		if err := s.server(); err != nil {
    85  			log.Error("quic server failed:", "err", err)
    86  		}
    87  	}()
    88  
    89  	return s, nil
    90  }
    91  
    92  func (s *Server) Close() error {
    93  	var err error
    94  
    95  	s.cancel()
    96  	if s.Listener != nil {
    97  		if er := s.Listener.Close(); er != nil {
    98  			err = errors.Join(err, er)
    99  		}
   100  	}
   101  	if s.packetConn != nil {
   102  		if er := s.packetConn.Close(); er != nil {
   103  			err = errors.Join(err, er)
   104  		}
   105  	}
   106  
   107  	return err
   108  }
   109  
   110  func (s *Server) Accept() (net.Conn, error) {
   111  	select {
   112  	case conn := <-s.connChan:
   113  		return conn, nil
   114  	case <-s.ctx.Done():
   115  		return nil, s.ctx.Err()
   116  	}
   117  }
   118  
   119  func (s *Server) Packet(context.Context) (net.PacketConn, error) {
   120  	return newServerPacketConn(s), nil
   121  }
   122  
   123  func (s *Server) Stream(ctx context.Context) (net.Listener, error) {
   124  	return s, nil
   125  }
   126  
   127  func (s *Server) server() error {
   128  	for {
   129  		conn, err := s.Listener.Accept(s.ctx)
   130  		if err != nil {
   131  			return err
   132  		}
   133  
   134  		go func() {
   135  			defer conn.CloseWithError(quic.ApplicationErrorCode(quic.NoError), "") // nolint:errcheck
   136  
   137  			go func() {
   138  				if err := s.listenDatagram(conn); err != nil {
   139  					log.Error("listen datagram failed:", "err", err)
   140  				}
   141  			}()
   142  
   143  			if err := s.listenStream(conn); err != nil {
   144  				log.Error("listen quic connection failed:", "err", err)
   145  			}
   146  		}()
   147  	}
   148  }
   149  
   150  func (s *Server) listenDatagram(conn quic.Connection) error {
   151  	raddr := conn.RemoteAddr()
   152  
   153  	packetConn := NewConnectionPacketConn(conn)
   154  
   155  	s.natMap.Store(raddr.String(), packetConn)
   156  	defer s.natMap.Delete(raddr.String())
   157  
   158  	for {
   159  		id, data, err := packetConn.Receive(s.ctx)
   160  		if err != nil {
   161  			return err
   162  		}
   163  
   164  		select {
   165  		case <-s.ctx.Done():
   166  			return s.ctx.Err()
   167  		case s.packetChan <- serverMsg{msg: data, src: raddr, id: id}:
   168  		}
   169  	}
   170  }
   171  func (s *Server) listenStream(conn quic.Connection) error {
   172  	for {
   173  		stream, err := conn.AcceptStream(s.ctx)
   174  		if err != nil {
   175  			return err
   176  		}
   177  
   178  		select {
   179  		case <-s.ctx.Done():
   180  			return s.ctx.Err()
   181  		case s.connChan <- &interConn{
   182  			Stream:  stream,
   183  			session: conn,
   184  		}:
   185  		}
   186  	}
   187  }
   188  
// serverMsg is one datagram handed from listenDatagram to ReadFrom.
type serverMsg struct {
	// msg holds the payload; ReadFrom frees it after copying it out.
	msg *pool.Buffer
	// src is the remote address of the QUIC connection the datagram arrived on.
	src net.Addr
	// id is the value returned alongside the payload by
	// ConnectionPacketConn.Receive. ReadFrom exposes it as QuicAddr.ID so that
	// WriteTo can reply with the same id.
	id  uint64
}
// serverPacketConn is the net.PacketConn view returned by Server.Packet.
// Each Packet call creates a new view, and all views read from the same shared
// Server.packetChan.
type serverPacketConn struct {
	*Server

	// ctx is a child of Server.ctx and is cancelled when this view is closed.
	ctx    context.Context
	cancel context.CancelFunc

	// deadline backs SetDeadline, SetReadDeadline and SetWriteDeadline.
	deadline *deadline.PipeDeadline
}
   202  
   203  func newServerPacketConn(s *Server) *serverPacketConn {
   204  	ctx, cancel := context.WithCancel(s.ctx)
   205  	return &serverPacketConn{
   206  		Server:   s,
   207  		ctx:      ctx,
   208  		cancel:   cancel,
   209  		deadline: deadline.NewPipe(),
   210  	}
   211  }
   212  
   213  func (x *serverPacketConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) {
   214  	select {
   215  	case <-x.Server.ctx.Done():
   216  		x.cancel()
   217  		return 0, nil, x.Server.ctx.Err()
   218  	case <-x.ctx.Done():
   219  		return 0, nil, x.ctx.Err()
   220  	case <-x.deadline.ReadContext().Done():
   221  		return 0, nil, x.deadline.ReadContext().Err()
   222  	case msg := <-x.packetChan:
   223  		defer msg.msg.Free()
   224  
   225  		n = copy(p, msg.msg.Bytes())
   226  		return n, &QuicAddr{Addr: msg.src, ID: quic.StreamID(msg.id)}, nil
   227  	}
   228  }
   229  
   230  func (x *serverPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
   231  	select {
   232  	case <-x.Server.ctx.Done():
   233  		return 0, x.Server.ctx.Err()
   234  	case <-x.ctx.Done():
   235  		return 0, x.ctx.Err()
   236  	case <-x.deadline.WriteContext().Done():
   237  		return 0, x.deadline.WriteContext().Err()
   238  	default:
   239  	}
   240  
   241  	qaddr, ok := addr.(*QuicAddr)
   242  	if !ok {
   243  		return 0, errors.New("invalid addr")
   244  	}
   245  
   246  	conn, ok := x.natMap.Load(qaddr.Addr.String())
   247  	if !ok {
   248  		return 0, fmt.Errorf("no such addr: %s", addr.String())
   249  	}
   250  	err = conn.Write(p, uint64(qaddr.ID))
   251  	return len(p), err
   252  }
   253  
   254  func (x *serverPacketConn) LocalAddr() net.Addr {
   255  	return x.Addr()
   256  }
   257  
   258  func (x *serverPacketConn) SetDeadline(t time.Time) error {
   259  	select {
   260  	case <-x.Server.ctx.Done():
   261  		return x.Server.ctx.Err()
   262  	case <-x.ctx.Done():
   263  		return x.ctx.Err()
   264  	default:
   265  	}
   266  
   267  	x.deadline.SetDeadline(t)
   268  	return nil
   269  }
   270  
   271  func (x *serverPacketConn) SetReadDeadline(t time.Time) error {
   272  	x.deadline.SetReadDeadline(t)
   273  	return nil
   274  }
   275  
   276  func (x *serverPacketConn) SetWriteDeadline(t time.Time) error {
   277  	x.deadline.SetWriteDeadline(t)
   278  	return nil
   279  }
   280  
   281  func (x *serverPacketConn) Close() error {
   282  	x.cancel()
   283  	x.deadline.Close()
   284  	return x.Server.Close()
   285  }