github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/http2/server.go (about)

     1  package http2
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"math"
     8  	"net"
     9  	"net/http"
    10  	"sync"
    11  	"time"
    12  
    13  	"github.com/Asutorufa/yuhaiin/pkg/log"
    14  	"github.com/Asutorufa/yuhaiin/pkg/net/deadline"
    15  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    16  	"github.com/Asutorufa/yuhaiin/pkg/protos/config/listener"
    17  	"github.com/Asutorufa/yuhaiin/pkg/utils/id"
    18  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    19  	"github.com/Asutorufa/yuhaiin/pkg/utils/syncmap"
    20  	"golang.org/x/net/http2"
    21  )
    22  
// Server exposes each HTTP/2 request stream as a net.Conn.
//
// It serves HTTP/2 on connections taken from an underlying listener and
// implements Accept, Addr and Close so it can itself be used as a listener.
type Server struct {
	listener  net.Listener       // underlying listener the HTTP/2 connections come from
	id        id.IDGenerator     // source of the per-stream ids embedded in remote addrs
	closedCtx context.Context    // done once the server is closed
	close     context.CancelFunc // cancels closedCtx

	// connChan hands stream conns from ServeHTTP to Accept.
	connChan chan net.Conn

	// conns tracks the underlying connections currently being served, so
	// Close can close them.
	conns syncmap.SyncMap[string, net.Conn]
}
    33  
    34  func init() {
    35  	listener.RegisterTransport(NewServer)
    36  }
    37  
    38  func NewServer(c *listener.Transport_Http2) func(netapi.Listener) (netapi.Listener, error) {
    39  	return func(ii netapi.Listener) (netapi.Listener, error) {
    40  		lis, err := ii.Stream(context.TODO())
    41  		if err != nil {
    42  			return nil, err
    43  		}
    44  		return netapi.PatchStream(newServer(lis), ii), nil
    45  	}
    46  }
    47  
    48  func newServer(lis net.Listener) *Server {
    49  	ctx, cancel := context.WithCancel(context.Background())
    50  
    51  	h := &Server{
    52  		listener:  lis,
    53  		connChan:  make(chan net.Conn, 20),
    54  		closedCtx: ctx,
    55  		close:     cancel,
    56  	}
    57  
    58  	go func() {
    59  		defer h.Close()
    60  		defer cancel()
    61  
    62  		for {
    63  			conn, err := lis.Accept()
    64  			if err != nil {
    65  				log.Error("accept failed:", "err", err)
    66  				return
    67  			}
    68  
    69  			go func() {
    70  				key := conn.RemoteAddr().String() + conn.LocalAddr().String()
    71  				h.conns.Store(key, conn)
    72  
    73  				defer func() {
    74  					h.conns.Delete(key)
    75  					conn.Close()
    76  				}()
    77  
    78  				(&http2.Server{
    79  					MaxConcurrentStreams: math.MaxUint32,
    80  					IdleTimeout:          time.Minute,
    81  					MaxReadFrameSize:     pool.DefaultSize,
    82  					NewWriteScheduler: func() http2.WriteScheduler {
    83  						return http2.NewRandomWriteScheduler()
    84  					},
    85  				}).ServeConn(conn, &http2.ServeConnOpts{
    86  					Handler: h,
    87  					Context: h.closedCtx,
    88  				})
    89  			}()
    90  		}
    91  
    92  	}()
    93  
    94  	return h
    95  }
    96  
    97  func (h *Server) Accept() (net.Conn, error) {
    98  	select {
    99  	case conn := <-h.connChan:
   100  		return conn, nil
   101  	case <-h.closedCtx.Done():
   102  		return nil, net.ErrClosed
   103  	}
   104  }
   105  
   106  func (g *Server) Addr() net.Addr {
   107  	if g.listener != nil {
   108  		return g.listener.Addr()
   109  	}
   110  
   111  	return netapi.EmptyAddr
   112  }
   113  
   114  func (h *Server) Close() error {
   115  	var err error
   116  	h.close()
   117  	log.Info("start close http2 underlying listener")
   118  	err = h.listener.Close()
   119  	log.Info("closed http2 underlying listener")
   120  
   121  	h.conns.Range(func(key string, conn net.Conn) bool {
   122  		_ = conn.Close()
   123  		return true
   124  	})
   125  
   126  	return err
   127  }
   128  
   129  func (h *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
   130  	w.WriteHeader(http.StatusOK)
   131  	if f, ok := w.(http.Flusher); ok {
   132  		f.Flush()
   133  	}
   134  	fw := newFlushWriter(w)
   135  
   136  	conn := &http2Conn{
   137  		true,
   138  		nil,
   139  		fw,
   140  		r.Body,
   141  		h.Addr(),
   142  		&addr{r.RemoteAddr, h.id.Generate()},
   143  		deadline.NewPipe(
   144  			// deadline.WithReadClose(func() {
   145  			// _ = r.Body.Close()
   146  			// }),
   147  			deadline.WithWriteClose(func() {
   148  				_ = fw.Close()
   149  			}),
   150  		),
   151  	}
   152  	defer conn.Close()
   153  
   154  	select {
   155  	case <-r.Context().Done():
   156  		return
   157  	case <-h.closedCtx.Done():
   158  		return
   159  	case h.connChan <- conn:
   160  	}
   161  
   162  	select {
   163  	case <-r.Context().Done():
   164  	case <-h.closedCtx.Done():
   165  	}
   166  }
   167  
   168  var _ net.Conn = (*http2Conn)(nil)
   169  
   170  type flushWriter struct {
   171  	w      io.Writer
   172  	flush  http.Flusher
   173  	mu     sync.RWMutex
   174  	closed bool
   175  }
   176  
   177  func newFlushWriter(w io.Writer) *flushWriter {
   178  	fw := &flushWriter{
   179  		w: w,
   180  	}
   181  
   182  	if f, ok := w.(http.Flusher); ok {
   183  		fw.flush = f
   184  	}
   185  
   186  	return fw
   187  }
   188  
   189  func (fw *flushWriter) Write(p []byte) (n int, err error) {
   190  	fw.mu.RLock()
   191  	if fw.closed {
   192  		return 0, io.EOF
   193  	}
   194  
   195  	n, err = fw.w.Write(p)
   196  	if err == nil && fw.flush != nil {
   197  		fw.flush.Flush()
   198  	}
   199  	fw.mu.RUnlock()
   200  
   201  	return
   202  }
   203  
   204  func (fw *flushWriter) Close() error {
   205  	fw.mu.Lock()
   206  	defer fw.mu.Unlock()
   207  
   208  	fw.closed = true
   209  	return nil
   210  }
   211  
// http2Conn presents one HTTP/2 stream as a net.Conn.
//
// NOTE: Server.ServeHTTP builds this struct with a positional composite
// literal, so the field order below must not change.
type http2Conn struct {
	// server is true for conns created by Server.ServeHTTP. It enables the
	// StreamError-to-io.EOF mapping in Read. In Close it selects closing the
	// deadline instead of closing r.
	server bool

	// piper, when non-nil, is closed with io.EOF by Close.
	// ServeHTTP leaves it nil.
	piper *io.PipeReader

	// pipew receives every Write (on the server side: the flushWriter around
	// the http.ResponseWriter). r supplies every Read (on the server side:
	// the request body).
	pipew io.WriteCloser
	r     io.ReadCloser

	localAddr  net.Addr
	remoteAddr net.Addr

	// deadline is checked at the start of each Read/Write and is updated by
	// the Set*Deadline methods.
	deadline *deadline.PipeDeadline
}
   225  
   226  func (h *http2Conn) Read(b []byte) (int, error) {
   227  	select {
   228  	case <-h.deadline.ReadContext().Done():
   229  		return 0, h.deadline.ReadContext().Err()
   230  	default:
   231  	}
   232  
   233  	n, err := h.r.Read(b)
   234  	if err != nil {
   235  		if he, ok := err.(http2.StreamError); h.server && ok {
   236  			// closed client, will send RSTStreamFrame
   237  			// see https://github.com/golang/net/blob/577e44a5cee023bd639dd2dcc4008644bcb71472/http2/server.go#L1615
   238  			if he.Code == http2.ErrCodeCancel || he.Code == http2.ErrCodeNo {
   239  				err = io.EOF
   240  			}
   241  		}
   242  	}
   243  
   244  	return n, err
   245  }
   246  
   247  func (h *http2Conn) Write(b []byte) (int, error) {
   248  	select {
   249  	case <-h.deadline.WriteContext().Done():
   250  		return 0, h.deadline.WriteContext().Err()
   251  	default:
   252  	}
   253  
   254  	return h.pipew.Write(b)
   255  }
   256  
   257  func (h *http2Conn) Close() error {
   258  	if h.piper != nil {
   259  		h.piper.CloseWithError(io.EOF)
   260  	}
   261  
   262  	h.pipew.Close()
   263  
   264  	if !h.server {
   265  		return h.r.Close()
   266  	}
   267  
   268  	_ = h.deadline.Close()
   269  
   270  	return nil
   271  }
   272  
   273  func (h *http2Conn) LocalAddr() net.Addr  { return h.localAddr }
   274  func (h *http2Conn) RemoteAddr() net.Addr { return h.remoteAddr }
   275  
   276  func (c *http2Conn) SetDeadline(t time.Time) error {
   277  	c.deadline.SetDeadline(t)
   278  	return nil
   279  }
   280  
   281  func (c *http2Conn) SetReadDeadline(t time.Time) error {
   282  	c.deadline.SetReadDeadline(t)
   283  	return nil
   284  }
   285  
   286  func (c *http2Conn) SetWriteDeadline(t time.Time) error {
   287  	c.deadline.SetWriteDeadline(t)
   288  	return nil
   289  }
   290  
   291  type addr struct {
   292  	addr string
   293  	id   uint64
   294  }
   295  
   296  func (addr) Network() string  { return "tcp" }
   297  func (a addr) String() string { return fmt.Sprintf("http2://%s-%d", a.addr, a.id) }