github.com/sagernet/sing-mux@v0.2.1-0.20240124034317-9bfb33698bb6/h2mux.go (about)

     1  package mux
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"io"
     7  	"net"
     8  	"net/http"
     9  	"net/url"
    10  	"os"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/sagernet/sing/common/atomic"
    15  	"github.com/sagernet/sing/common/buf"
    16  	"github.com/sagernet/sing/common/bufio"
    17  	E "github.com/sagernet/sing/common/exceptions"
    18  	N "github.com/sagernet/sing/common/network"
    19  
    20  	"golang.org/x/net/http2"
    21  )
    22  
    23  const idleTimeout = 30 * time.Second
    24  
    25  var _ abstractSession = (*h2MuxServerSession)(nil)
    26  
    27  type h2MuxServerSession struct {
    28  	server  http2.Server
    29  	active  atomic.Int32
    30  	conn    net.Conn
    31  	inbound chan net.Conn
    32  	done    chan struct{}
    33  }
    34  
    35  func newH2MuxServer(conn net.Conn) *h2MuxServerSession {
    36  	session := &h2MuxServerSession{
    37  		conn:    conn,
    38  		inbound: make(chan net.Conn),
    39  		done:    make(chan struct{}),
    40  		server: http2.Server{
    41  			IdleTimeout:      idleTimeout,
    42  			MaxReadFrameSize: buf.BufferSize,
    43  		},
    44  	}
    45  	go func() {
    46  		session.server.ServeConn(conn, &http2.ServeConnOpts{
    47  			Handler: session,
    48  		})
    49  		_ = session.Close()
    50  	}()
    51  	return session
    52  }
    53  
    54  func (s *h2MuxServerSession) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
    55  	s.active.Add(1)
    56  	defer s.active.Add(-1)
    57  	writer.WriteHeader(http.StatusOK)
    58  	conn := newHTTP2Wrapper(newHTTPConn(request.Body, writer), writer.(http.Flusher))
    59  	s.inbound <- conn
    60  	select {
    61  	case <-conn.done:
    62  	case <-s.done:
    63  		_ = conn.Close()
    64  	}
    65  }
    66  
    67  func (s *h2MuxServerSession) Open() (net.Conn, error) {
    68  	return nil, os.ErrInvalid
    69  }
    70  
    71  func (s *h2MuxServerSession) Accept() (net.Conn, error) {
    72  	select {
    73  	case conn := <-s.inbound:
    74  		return conn, nil
    75  	case <-s.done:
    76  		return nil, os.ErrClosed
    77  	}
    78  }
    79  
    80  func (s *h2MuxServerSession) NumStreams() int {
    81  	return int(s.active.Load())
    82  }
    83  
    84  func (s *h2MuxServerSession) Close() error {
    85  	select {
    86  	case <-s.done:
    87  	default:
    88  		close(s.done)
    89  	}
    90  	return s.conn.Close()
    91  }
    92  
    93  func (s *h2MuxServerSession) IsClosed() bool {
    94  	select {
    95  	case <-s.done:
    96  		return true
    97  	default:
    98  		return false
    99  	}
   100  }
   101  
   102  func (s *h2MuxServerSession) CanTakeNewRequest() bool {
   103  	return false
   104  }
   105  
   106  type h2MuxConnWrapper struct {
   107  	N.ExtendedConn
   108  	flusher http.Flusher
   109  	access  sync.Mutex
   110  	closed  bool
   111  	done    chan struct{}
   112  }
   113  
   114  func newHTTP2Wrapper(conn net.Conn, flusher http.Flusher) *h2MuxConnWrapper {
   115  	return &h2MuxConnWrapper{
   116  		ExtendedConn: bufio.NewExtendedConn(conn),
   117  		flusher:      flusher,
   118  		done:         make(chan struct{}),
   119  	}
   120  }
   121  
   122  func (w *h2MuxConnWrapper) Write(p []byte) (n int, err error) {
   123  	w.access.Lock()
   124  	defer w.access.Unlock()
   125  	if w.closed {
   126  		return 0, net.ErrClosed
   127  	}
   128  	n, err = w.ExtendedConn.Write(p)
   129  	if err == nil {
   130  		w.flusher.Flush()
   131  	}
   132  	return
   133  }
   134  
   135  func (w *h2MuxConnWrapper) WriteBuffer(buffer *buf.Buffer) error {
   136  	w.access.Lock()
   137  	defer w.access.Unlock()
   138  	if w.closed {
   139  		return net.ErrClosed
   140  	}
   141  	err := w.ExtendedConn.WriteBuffer(buffer)
   142  	if err == nil {
   143  		w.flusher.Flush()
   144  	}
   145  	return err
   146  }
   147  
   148  func (w *h2MuxConnWrapper) Close() error {
   149  	w.access.Lock()
   150  	select {
   151  	case <-w.done:
   152  	default:
   153  		close(w.done)
   154  	}
   155  	w.closed = true
   156  	w.access.Unlock()
   157  	return w.ExtendedConn.Close()
   158  }
   159  
   160  func (w *h2MuxConnWrapper) Upstream() any {
   161  	return w.ExtendedConn
   162  }
   163  
   164  var _ abstractSession = (*h2MuxClientSession)(nil)
   165  
   166  type h2MuxClientSession struct {
   167  	transport  *http2.Transport
   168  	clientConn *http2.ClientConn
   169  	access     sync.RWMutex
   170  	closed     bool
   171  }
   172  
   173  func newH2MuxClient(conn net.Conn) (*h2MuxClientSession, error) {
   174  	session := &h2MuxClientSession{
   175  		transport: &http2.Transport{
   176  			DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
   177  				return conn, nil
   178  			},
   179  			ReadIdleTimeout:  idleTimeout,
   180  			MaxReadFrameSize: buf.BufferSize,
   181  		},
   182  	}
   183  	session.transport.ConnPool = session
   184  	clientConn, err := session.transport.NewClientConn(conn)
   185  	if err != nil {
   186  		return nil, err
   187  	}
   188  	session.clientConn = clientConn
   189  	return session, nil
   190  }
   191  
   192  func (s *h2MuxClientSession) GetClientConn(req *http.Request, addr string) (*http2.ClientConn, error) {
   193  	return s.clientConn, nil
   194  }
   195  
   196  func (s *h2MuxClientSession) MarkDead(conn *http2.ClientConn) {
   197  	s.Close()
   198  }
   199  
   200  func (s *h2MuxClientSession) Open() (net.Conn, error) {
   201  	pipeInReader, pipeInWriter := io.Pipe()
   202  	request := &http.Request{
   203  		Method: http.MethodConnect,
   204  		Body:   pipeInReader,
   205  		URL:    &url.URL{Scheme: "https", Host: "localhost"},
   206  	}
   207  	connCtx, cancel := context.WithCancel(context.Background())
   208  	request = request.WithContext(connCtx)
   209  	conn := newLateHTTPConn(pipeInWriter, cancel)
   210  	requestDone := make(chan struct{})
   211  	go func() {
   212  		select {
   213  		case <-requestDone:
   214  			return
   215  		case <-time.After(TCPTimeout):
   216  			cancel()
   217  		}
   218  	}()
   219  	go func() {
   220  		response, err := s.transport.RoundTrip(request)
   221  		close(requestDone)
   222  		if err != nil {
   223  			conn.setup(nil, err)
   224  		} else if response.StatusCode != 200 {
   225  			response.Body.Close()
   226  			conn.setup(nil, E.New("unexpected status: ", response.StatusCode, " ", response.Status))
   227  		} else {
   228  			conn.setup(response.Body, nil)
   229  		}
   230  	}()
   231  	return conn, nil
   232  }
   233  
   234  func (s *h2MuxClientSession) Accept() (net.Conn, error) {
   235  	return nil, os.ErrInvalid
   236  }
   237  
   238  func (s *h2MuxClientSession) NumStreams() int {
   239  	return s.clientConn.State().StreamsActive
   240  }
   241  
   242  func (s *h2MuxClientSession) Close() error {
   243  	s.access.Lock()
   244  	defer s.access.Unlock()
   245  	if s.closed {
   246  		return os.ErrClosed
   247  	}
   248  	s.closed = true
   249  	return s.clientConn.Close()
   250  }
   251  
   252  func (s *h2MuxClientSession) IsClosed() bool {
   253  	s.access.RLock()
   254  	defer s.access.RUnlock()
   255  	return s.closed || s.clientConn.State().Closed
   256  }
   257  
   258  func (s *h2MuxClientSession) CanTakeNewRequest() bool {
   259  	return s.clientConn.CanTakeNewRequest()
   260  }