github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/http2/client.go (about)

     1  package http2
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"net/url"
    11  	"strings"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	"github.com/Asutorufa/yuhaiin/pkg/log"
    17  	"github.com/Asutorufa/yuhaiin/pkg/net/deadline"
    18  	"github.com/Asutorufa/yuhaiin/pkg/net/nat"
    19  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    20  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/point"
    21  	"github.com/Asutorufa/yuhaiin/pkg/protos/node/protocol"
    22  	"github.com/Asutorufa/yuhaiin/pkg/utils/pool"
    23  	"golang.org/x/net/http2"
    24  )
    25  
    26  type Client struct {
    27  	client *clientConnPool
    28  	netapi.Proxy
    29  }
    30  
    31  func init() {
    32  	point.RegisterProtocol(NewClient)
    33  }
    34  
    35  func NewClient(config *protocol.Protocol_Http2) point.WrapProxy {
    36  	return func(p netapi.Proxy) (netapi.Proxy, error) {
    37  
    38  		if config.Http2.Concurrency < 1 {
    39  			config.Http2.Concurrency = 1
    40  		}
    41  
    42  		cpool := &clientConnPool{
    43  			dialer: p,
    44  			conns:  make([]*entry, config.Http2.Concurrency),
    45  			max:    uint64(config.Http2.Concurrency),
    46  		}
    47  
    48  		for i := range cpool.conns {
    49  			cpool.conns[i] = &entry{}
    50  		}
    51  
    52  		return &Client{
    53  			client: cpool,
    54  			Proxy:  p,
    55  		}, nil
    56  	}
    57  }
    58  
    59  type entry struct {
    60  	mu   sync.Mutex
    61  	raw  net.Conn
    62  	conn *http2.ClientConn
    63  }
    64  
    65  type clientConnPool struct {
    66  	dialer netapi.Proxy
    67  	conns  []*entry
    68  
    69  	max     uint64
    70  	current atomic.Uint64
    71  }
    72  
    73  func (c *clientConnPool) OpenStream(ctx context.Context) (uint64, net.Conn, *http2.ClientConn, error) {
    74  	nowNumber := c.current.Add(1)
    75  
    76  	conn := c.conns[nowNumber%(c.max)]
    77  
    78  	cc := conn.conn
    79  
    80  	if cc != nil {
    81  		state := cc.State()
    82  		if !state.Closed && !state.Closing {
    83  			return nowNumber, conn.raw, cc, nil
    84  		}
    85  	}
    86  
    87  	conn.mu.Lock()
    88  	defer conn.mu.Unlock()
    89  
    90  	if conn.conn != nil {
    91  		state := conn.conn.State()
    92  		if !state.Closed && !state.Closing {
    93  			return nowNumber, conn.raw, conn.conn, nil
    94  		}
    95  		_ = conn.conn.Close()
    96  	}
    97  
    98  	rawConn, err := c.dialer.Conn(ctx, netapi.EmptyAddr)
    99  	if err != nil {
   100  		return nowNumber, nil, nil, err
   101  	}
   102  
   103  	transport := &http2.Transport{
   104  		DisableCompression: true,
   105  		AllowHTTP:          true,
   106  		ReadIdleTimeout:    time.Second * 30,
   107  		MaxReadFrameSize:   pool.DefaultSize,
   108  		IdleConnTimeout:    nat.IdleTimeout,
   109  		DialTLSContext: func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
   110  			return rawConn, nil
   111  		},
   112  	}
   113  
   114  	cc, err = transport.NewClientConn(rawConn)
   115  	if err != nil {
   116  		rawConn.Close()
   117  		return nowNumber, nil, nil, err
   118  	}
   119  
   120  	conn.conn = cc
   121  	conn.raw = rawConn
   122  
   123  	return nowNumber, rawConn, cc, nil
   124  }
   125  
   126  func (c *Client) Conn(ctx context.Context, add netapi.Address) (net.Conn, error) {
   127  	id, raw, clientConn, err := c.client.OpenStream(ctx)
   128  	if err != nil {
   129  		return nil, fmt.Errorf("http2 get client conn failed: %w", err)
   130  	}
   131  
   132  	r, w := io.Pipe()
   133  
   134  	respr := newReadCloser()
   135  
   136  	h2conn := &http2Conn{
   137  		piper:      r,
   138  		pipew:      w,
   139  		r:          respr,
   140  		localAddr:  addr{addr: raw.LocalAddr().String(), id: id},
   141  		remoteAddr: raw.RemoteAddr(),
   142  		deadline: deadline.NewPipe(
   143  			deadline.WithReadClose(func() {
   144  				_ = respr.Close()
   145  			}),
   146  			deadline.WithWriteClose(func() {
   147  				_ = w.CloseWithError(io.EOF)
   148  			}),
   149  		),
   150  	}
   151  
   152  	go func() {
   153  		resp, err := clientConn.RoundTrip(&http.Request{
   154  			Method: http.MethodConnect,
   155  			Body:   &wrapPipeReaderClose{r},
   156  			URL:    &url.URL{Scheme: "https", Host: "localhost"},
   157  		})
   158  		if err != nil {
   159  			r.CloseWithError(err)
   160  			h2conn.Close()
   161  			log.Error("http2 do request failed:", "err", err)
   162  			return
   163  		}
   164  
   165  		respr.SetReadCloser(resp.Body)
   166  	}()
   167  
   168  	return h2conn, nil
   169  }
   170  
   171  type readCloser struct {
   172  	rc   io.ReadCloser
   173  	ctx  context.Context
   174  	done context.CancelFunc
   175  }
   176  
   177  func newReadCloser() *readCloser {
   178  	ctx, cancel := context.WithCancel(context.Background())
   179  	return &readCloser{ctx: ctx, done: cancel}
   180  }
   181  
   182  func (r *readCloser) Close() error {
   183  	if r.rc != nil {
   184  		return r.rc.Close()
   185  	}
   186  
   187  	r.done()
   188  	return nil
   189  }
   190  
   191  func (r *readCloser) SetReadCloser(rc io.ReadCloser) {
   192  	r.rc = rc
   193  	r.done()
   194  }
   195  
   196  func (r *readCloser) Read(b []byte) (int, error) {
   197  	if r.rc == nil {
   198  		<-r.ctx.Done()
   199  		if r.rc == nil {
   200  			return 0, io.EOF
   201  		}
   202  	}
   203  
   204  	n, err := r.rc.Read(b)
   205  	if err != nil {
   206  		if strings.Contains(err.Error(), "http2: response body closed") {
   207  			err = io.EOF
   208  		}
   209  
   210  		return n, err
   211  	}
   212  
   213  	return n, nil
   214  }
   215  
   216  type wrapPipeReaderClose struct {
   217  	*io.PipeReader
   218  }
   219  
   220  func (w *wrapPipeReaderClose) Close() error { return w.PipeReader.CloseWithError(io.EOF) }