github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/quic/quic.go (about)

     1  package quic
     2  
import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"time"

	"github.com/Asutorufa/yuhaiin/pkg/net/deadline"
	"github.com/Asutorufa/yuhaiin/pkg/net/dialer"
	"github.com/Asutorufa/yuhaiin/pkg/net/nat"
	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
	"github.com/Asutorufa/yuhaiin/pkg/protos/node/point"
	"github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol"
	"github.com/Asutorufa/yuhaiin/pkg/protos/statistic"
	"github.com/Asutorufa/yuhaiin/pkg/utils/id"
	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
	"github.com/Asutorufa/yuhaiin/pkg/utils/syncmap"
	"github.com/quic-go/quic-go"
)
    24  
// Client is a netapi.Proxy that carries stream connections (Conn) as QUIC
// streams and packet connections (PacketConn) as QUIC datagrams, all over one
// shared QUIC connection that initSession dials lazily and re-dials once it
// has ended.
type Client struct {
	netapi.EmptyDispatch

	// tlsConfig is used for every QUIC dial.
	tlsConfig *tls.Config
	// dialer supplies the packet conn the QUIC transport runs on; when nil,
	// initSession opens a local UDP socket via dialer.ListenPacket instead.
	dialer    netapi.Proxy

	// session is the current QUIC connection and underlying the packet conn
	// its transport runs on; both are replaced (and the old ones closed) by
	// initSession while holding sessionMu. sessionUnix is the Unix time in
	// seconds at which session was established.
	session     quic.Connection
	underlying  net.PacketConn
	sessionMu   sync.Mutex
	sessionUnix int64

	// packetConn wraps session's datagrams; natMap routes each received
	// datagram, by its id, to the clientPacketConn registered for that id.
	packetConn *ConnectionPacketConn
	natMap     syncmap.SyncMap[uint64, *clientPacketConn]

	// idg generates the ids assigned to new clientPacketConns.
	idg id.IDGenerator

	// host is the server address QUIC connections are dialed to.
	host *net.UDPAddr
}
    43  
    44  func init() {
    45  	point.RegisterProtocol(NewClient)
    46  }
    47  
    48  func NewClient(config *protocol.Protocol_Quic) point.WrapProxy {
    49  	return func(dialer netapi.Proxy) (netapi.Proxy, error) {
    50  
    51  		var host *net.UDPAddr = &net.UDPAddr{IP: net.IPv4zero}
    52  
    53  		if config.Quic.Host != "" {
    54  			addr, err := netapi.ParseAddress(statistic.Type_udp, config.Quic.Host)
    55  			if err == nil {
    56  				if ur := addr.UDPAddr(context.TODO()); ur.Err == nil {
    57  					host = ur.V
    58  				}
    59  			}
    60  		}
    61  
    62  		tlsConfig := point.ParseTLSConfig(config.Quic.Tls)
    63  		if tlsConfig == nil {
    64  			tlsConfig = &tls.Config{}
    65  		}
    66  
    67  		if point.IsBootstrap(dialer) {
    68  			dialer = nil
    69  		}
    70  
    71  		c := &Client{
    72  			dialer:    dialer,
    73  			tlsConfig: tlsConfig,
    74  			host:      host,
    75  		}
    76  
    77  		return c, nil
    78  	}
    79  }
    80  
    81  func (c *Client) initSession(ctx context.Context) (quic.Connection, error) {
    82  	session := c.session
    83  
    84  	if session != nil {
    85  		select {
    86  		case <-session.Context().Done():
    87  		default:
    88  			return session, nil
    89  		}
    90  	}
    91  
    92  	c.sessionMu.Lock()
    93  	defer c.sessionMu.Unlock()
    94  
    95  	if c.session != nil {
    96  		select {
    97  		case <-c.session.Context().Done():
    98  		default:
    99  			return c.session, nil
   100  		}
   101  	}
   102  
   103  	if c.session != nil {
   104  		_ = c.session.CloseWithError(0, "")
   105  	}
   106  
   107  	if c.underlying != nil {
   108  		_ = c.underlying.Close()
   109  	}
   110  
   111  	var conn net.PacketConn
   112  	var err error
   113  
   114  	if c.dialer == nil {
   115  		conn, err = dialer.ListenPacket("udp", "")
   116  	} else {
   117  		conn, err = c.dialer.PacketConn(ctx, netapi.EmptyAddr)
   118  	}
   119  	if err != nil {
   120  		return nil, err
   121  	}
   122  
   123  	tr := quic.Transport{
   124  		Conn:               conn,
   125  		ConnectionIDLength: 12,
   126  	}
   127  
   128  	config := &quic.Config{
   129  		KeepAlivePeriod: 15 * time.Second,
   130  		MaxIdleTimeout:  nat.IdleTimeout,
   131  		EnableDatagrams: true,
   132  	}
   133  
   134  	session, err = tr.Dial(ctx, c.host, c.tlsConfig, config)
   135  	if err != nil {
   136  		_ = conn.Close()
   137  		return nil, err
   138  	}
   139  
   140  	pconn := NewConnectionPacketConn(session)
   141  
   142  	c.underlying = conn
   143  	c.session = session
   144  	c.sessionUnix = time.Now().Unix()
   145  
   146  	// Datagram
   147  	c.packetConn = pconn
   148  
   149  	go func() {
   150  		defer session.CloseWithError(0, "")
   151  		for {
   152  			id, data, err := pconn.Receive(context.TODO())
   153  			if err != nil {
   154  				return
   155  			}
   156  
   157  			cchan, ok := c.natMap.Load(id)
   158  			if !ok {
   159  				continue
   160  			}
   161  
   162  			select {
   163  			case <-session.Context().Done():
   164  				return
   165  			case <-cchan.ctx.Done():
   166  			case cchan.msg <- data:
   167  			}
   168  		}
   169  	}()
   170  	return session, nil
   171  }
   172  
   173  func (c *Client) Conn(ctx context.Context, s netapi.Address) (net.Conn, error) {
   174  	session, err := c.initSession(ctx)
   175  	if err != nil {
   176  		return nil, err
   177  	}
   178  
   179  	stream, err := session.OpenStream()
   180  	if err != nil {
   181  		_ = session.CloseWithError(0, "")
   182  		return nil, err
   183  	}
   184  
   185  	return &interConn{
   186  		Stream:  stream,
   187  		session: session,
   188  		time:    c.sessionUnix,
   189  	}, nil
   190  }
   191  
   192  func (c *Client) PacketConn(ctx context.Context, host netapi.Address) (net.PacketConn, error) {
   193  	_, err := c.initSession(ctx)
   194  	if err != nil {
   195  		return nil, err
   196  	}
   197  
   198  	ctx, cancel := context.WithCancel(context.TODO())
   199  
   200  	cp := &clientPacketConn{
   201  		c:        c,
   202  		ctx:      ctx,
   203  		cancel:   cancel,
   204  		session:  c.packetConn,
   205  		id:       c.idg.Generate(),
   206  		msg:      make(chan *pool.Buffer, 64),
   207  		deadline: deadline.NewPipe(),
   208  	}
   209  	c.natMap.Store(cp.id, cp)
   210  
   211  	return cp, nil
   212  }
   213  
// Compile-time check that *interConn implements net.Conn.
var _ net.Conn = (*interConn)(nil)

// interConn adapts a single QUIC stream to net.Conn. Reads and writes go to the
// embedded stream; LocalAddr/RemoteAddr are taken from the owning connection.
type interConn struct {
	quic.Stream
	// session is the QUIC connection the stream was opened on.
	session quic.Connection
	// time is the Client's sessionUnix when the stream was opened; it is
	// carried into the QuicAddr values this conn reports.
	time    int64
}
   221  
   222  func (c *interConn) Read(p []byte) (n int, err error) {
   223  	n, err = c.Stream.Read(p)
   224  
   225  	if err != nil && err != io.EOF {
   226  		qe, ok := err.(*quic.StreamError)
   227  		if ok && qe.ErrorCode == quic.StreamErrorCode(quic.NoError) {
   228  			err = io.EOF
   229  		}
   230  	}
   231  	return
   232  }
   233  
   234  func (c *interConn) Write(p []byte) (n int, err error) {
   235  	n, err = c.Stream.Write(p)
   236  	if err != nil && err != io.EOF {
   237  		qe, ok := err.(*quic.StreamError)
   238  		if ok && qe.ErrorCode == quic.StreamErrorCode(quic.NoError) {
   239  			err = io.EOF
   240  		}
   241  	}
   242  	return
   243  }
   244  
   245  func (c *interConn) Close() error {
   246  	c.Stream.CancelRead(0)
   247  	return c.Stream.Close()
   248  }
   249  
   250  func (c *interConn) LocalAddr() net.Addr {
   251  	return &QuicAddr{
   252  		Addr: c.session.LocalAddr(),
   253  		ID:   c.Stream.StreamID(),
   254  		time: c.time,
   255  	}
   256  }
   257  
   258  func (c *interConn) RemoteAddr() net.Addr {
   259  	return &QuicAddr{
   260  		Addr: c.session.RemoteAddr(),
   261  		ID:   c.Stream.StreamID(),
   262  		time: c.time,
   263  	}
   264  }
   265  
// QuicAddr is a net.Addr naming one stream (or, for packet conns, one datagram
// id) on a QUIC connection.
type QuicAddr struct {
	// Addr is the connection's local or remote address.
	Addr net.Addr
	// ID is the QUIC stream ID, or a packet conn's id converted to a StreamID.
	ID   quic.StreamID
	// time is a Unix timestamp (seconds) of the connection; zero means it is
	// left out of String.
	time int64
}
   271  
   272  func (q *QuicAddr) String() string {
   273  	if q.time == 0 {
   274  		return fmt.Sprintf("quic://%d@%v", q.ID, q.Addr)
   275  	}
   276  	return fmt.Sprintf("quic://%d-%d@%v", q.time, q.ID, q.Addr)
   277  }
   278  
   279  func (q *QuicAddr) Network() string { return "udp" }
   280  
// clientPacketConn is a net.PacketConn carried over the QUIC datagrams of a
// Client's connection; its datagrams are tagged with id.
type clientPacketConn struct {
	// c is the owning Client (used to unregister from natMap on Close).
	c       *Client
	// session is the datagram wrapper this conn reads from and writes to.
	session *ConnectionPacketConn
	// id tags outgoing datagrams and keys this conn in c.natMap.
	id      uint64

	// ctx is cancelled by Close.
	ctx    context.Context
	cancel context.CancelFunc

	// msg receives datagrams routed to this conn by the Client's receive
	// goroutine; readers take ownership and Free each buffer.
	msg chan *pool.Buffer

	// deadline backs the Set*Deadline methods.
	deadline *deadline.PipeDeadline
}
   293  
   294  func (x *clientPacketConn) ReadFrom(p []byte) (n int, _ net.Addr, err error) {
   295  	select {
   296  	case <-x.session.Context().Done():
   297  		return x.read(p, func() error {
   298  			x.Close()
   299  			return x.session.Context().Err()
   300  		})
   301  	case <-x.deadline.ReadContext().Done():
   302  		return x.read(p, x.deadline.ReadContext().Err)
   303  	case <-x.ctx.Done():
   304  		return x.read(p, x.ctx.Err)
   305  	case msg := <-x.msg:
   306  		defer msg.Free()
   307  
   308  		n = copy(p, msg.Bytes())
   309  		return n, x.session.conn.RemoteAddr(), nil
   310  	}
   311  }
   312  
   313  func (x *clientPacketConn) read(p []byte, err func() error) (n int, _ net.Addr, _ error) {
   314  	if len(x.msg) > 0 {
   315  		select {
   316  		case msg := <-x.msg:
   317  			defer msg.Free()
   318  
   319  			n = copy(p, msg.Bytes())
   320  			return n, x.session.conn.RemoteAddr(), nil
   321  		default:
   322  		}
   323  	}
   324  
   325  	return 0, nil, err()
   326  }
   327  
   328  func (x *clientPacketConn) WriteTo(p []byte, _ net.Addr) (n int, err error) {
   329  	select {
   330  	case <-x.ctx.Done():
   331  		return 0, x.ctx.Err()
   332  	case <-x.deadline.WriteContext().Done():
   333  		return 0, x.deadline.WriteContext().Err()
   334  	case <-x.session.Context().Done():
   335  		return 0, x.session.Context().Err()
   336  	default:
   337  	}
   338  
   339  	err = x.session.Write(p, x.id)
   340  	if err != nil {
   341  		return 0, err
   342  	}
   343  	return len(p), nil
   344  }
   345  
   346  func (x *clientPacketConn) Close() error {
   347  	x.cancel()
   348  	x.deadline.Close()
   349  	x.c.natMap.Delete(x.id)
   350  	return nil
   351  }
   352  
   353  func (x *clientPacketConn) LocalAddr() net.Addr {
   354  	return &QuicAddr{
   355  		Addr: x.session.conn.LocalAddr(),
   356  		ID:   quic.StreamID(x.id),
   357  	}
   358  }
   359  
   360  func (x *clientPacketConn) SetDeadline(t time.Time) error {
   361  	select {
   362  	case <-x.ctx.Done():
   363  		return io.EOF
   364  	default:
   365  	}
   366  
   367  	x.deadline.SetDeadline(t)
   368  	return nil
   369  }
   370  
   371  func (x *clientPacketConn) SetReadDeadline(t time.Time) error {
   372  	x.deadline.SetReadDeadline(t)
   373  	return nil
   374  }
   375  
   376  func (x *clientPacketConn) SetWriteDeadline(t time.Time) error {
   377  	x.deadline.SetWriteDeadline(t)
   378  	return nil
   379  }