github.com/xxf098/lite-proxy@v0.15.1-0.20230422081941-12c69f323218/transport/gun/gun.go (about)

     1  // License: MIT
     2  
     3  package gun
     4  
     5  import (
     6  	"bufio"
     7  	"context"
     8  	"crypto/tls"
     9  	"encoding/binary"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"net/http"
    15  	"net/url"
    16  	"sync"
    17  	"time"
    18  
    19  	"github.com/xxf098/lite-proxy/common/pool"
    20  
    21  	"go.uber.org/atomic"
    22  	"golang.org/x/net/http2"
    23  )
    24  
    25  var (
    26  	ErrInvalidLength = errors.New("invalid length")
    27  	ErrSmallBuffer   = errors.New("buffer too small")
    28  )
    29  
    30  var defaultHeader = http.Header{
    31  	"content-type": []string{"application/grpc"},
    32  	"user-agent":   []string{"grpc-go/1.36.0"},
    33  }
    34  
    35  type DialFn = func(network, addr string) (net.Conn, error)
    36  
// Conn is a net.Conn carried over one streaming HTTP/2 POST (a "gun" tunnel).
// Write frames data into the request body through an io.Pipe; Read strips the
// framing from the response body.
type Conn struct {
	// response is stored by initRequest after a successful RoundTrip, unless
	// Close has already been called; otherwise it stays nil.
	response *http.Response
	// request is the streaming POST; its Body is the read half of the pipe
	// whose write half is writer.
	request *http.Request
	// transport sends request (see initRequest).
	transport *http2.Transport
	// writer is the write half of the request-body pipe: Write pushes frames
	// into it and Close closes it.
	writer *io.PipeWriter
	// once makes initRequest run a single time, whether it is triggered from
	// the background goroutine or from the first Read.
	once sync.Once
	// close is set by Close; initRequest checks it to discard a response that
	// arrives after the connection was closed.
	close *atomic.Bool
	// err holds the error recorded by initRequest when the request fails;
	// Read returns it and Write substitutes it for io.ErrClosedPipe.
	err error
	// remain counts bytes of the current payload not yet returned by Read.
	remain int
	// br is a buffered reader over response.Body.
	br *bufio.Reader

	// deadlines
	// deadline is the timer armed by SetDeadline; it calls Close when it fires.
	deadline *time.Timer
}
    51  
    52  type Config struct {
    53  	ServiceName string
    54  	Host        string
    55  }
    56  
    57  func (g *Conn) initRequest() {
    58  	response, err := g.transport.RoundTrip(g.request)
    59  	if err != nil {
    60  		g.err = err
    61  		g.writer.Close()
    62  		return
    63  	}
    64  
    65  	if !g.close.Load() {
    66  		g.response = response
    67  		g.br = bufio.NewReader(response.Body)
    68  	} else {
    69  		response.Body.Close()
    70  	}
    71  }
    72  
    73  func (g *Conn) Read(b []byte) (n int, err error) {
    74  	g.once.Do(g.initRequest)
    75  	if g.err != nil {
    76  		return 0, g.err
    77  	}
    78  
    79  	if g.remain > 0 {
    80  		size := g.remain
    81  		if len(b) < size {
    82  			size = len(b)
    83  		}
    84  
    85  		n, err = io.ReadFull(g.br, b[:size])
    86  		g.remain -= n
    87  		return
    88  	} else if g.response == nil {
    89  		return 0, net.ErrClosed
    90  	}
    91  
    92  	// 0x00 grpclength(uint32) 0x0A uleb128 payload
    93  	_, err = g.br.Discard(6)
    94  	if err != nil {
    95  		return 0, err
    96  	}
    97  
    98  	protobufPayloadLen, err := binary.ReadUvarint(g.br)
    99  	if err != nil {
   100  		return 0, ErrInvalidLength
   101  	}
   102  
   103  	size := int(protobufPayloadLen)
   104  	if len(b) < size {
   105  		size = len(b)
   106  	}
   107  
   108  	n, err = io.ReadFull(g.br, b[:size])
   109  	if err != nil {
   110  		return
   111  	}
   112  
   113  	remain := int(protobufPayloadLen) - n
   114  	if remain > 0 {
   115  		g.remain = remain
   116  	}
   117  
   118  	return n, nil
   119  }
   120  
   121  func (g *Conn) Write(b []byte) (n int, err error) {
   122  	protobufHeader := [binary.MaxVarintLen64 + 1]byte{0x0A}
   123  	varuintSize := binary.PutUvarint(protobufHeader[1:], uint64(len(b)))
   124  	grpcHeader := make([]byte, 5)
   125  	grpcPayloadLen := uint32(varuintSize + 1 + len(b))
   126  	binary.BigEndian.PutUint32(grpcHeader[1:5], grpcPayloadLen)
   127  
   128  	buf := pool.GetBuffer()
   129  	defer pool.PutBuffer(buf)
   130  	buf.Write(grpcHeader)
   131  	buf.Write(protobufHeader[:varuintSize+1])
   132  	buf.Write(b)
   133  
   134  	_, err = g.writer.Write(buf.Bytes())
   135  	if err == io.ErrClosedPipe && g.err != nil {
   136  		err = g.err
   137  	}
   138  
   139  	return len(b), err
   140  }
   141  
   142  func (g *Conn) Close() error {
   143  	g.close.Store(true)
   144  	if r := g.response; r != nil {
   145  		r.Body.Close()
   146  	}
   147  
   148  	return g.writer.Close()
   149  }
   150  
   151  func (g *Conn) LocalAddr() net.Addr                { return &net.TCPAddr{IP: net.IPv4zero, Port: 0} }
   152  func (g *Conn) RemoteAddr() net.Addr               { return &net.TCPAddr{IP: net.IPv4zero, Port: 0} }
   153  func (g *Conn) SetReadDeadline(t time.Time) error  { return g.SetDeadline(t) }
   154  func (g *Conn) SetWriteDeadline(t time.Time) error { return g.SetDeadline(t) }
   155  
   156  func (g *Conn) SetDeadline(t time.Time) error {
   157  	d := time.Until(t)
   158  	if g.deadline != nil {
   159  		g.deadline.Reset(d)
   160  		return nil
   161  	}
   162  	g.deadline = time.AfterFunc(d, func() {
   163  		g.Close()
   164  	})
   165  	return nil
   166  }
   167  
   168  func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config) *http2.Transport {
   169  	dialFunc := func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
   170  		pconn, err := dialFn(network, addr)
   171  		if err != nil {
   172  			return nil, err
   173  		}
   174  
   175  		cn := tls.Client(pconn, cfg)
   176  		if err := cn.HandshakeContext(ctx); err != nil {
   177  			pconn.Close()
   178  			return nil, err
   179  		}
   180  		state := cn.ConnectionState()
   181  		if p := state.NegotiatedProtocol; p != http2.NextProtoTLS {
   182  			cn.Close()
   183  			return nil, fmt.Errorf("http2: unexpected ALPN protocol %s, want %s", p, http2.NextProtoTLS)
   184  		}
   185  		return cn, nil
   186  	}
   187  
   188  	return &http2.Transport{
   189  		DialTLSContext:     dialFunc,
   190  		TLSClientConfig:    tlsConfig,
   191  		AllowHTTP:          false,
   192  		DisableCompression: true,
   193  		PingTimeout:        0,
   194  	}
   195  }
   196  
   197  func StreamGunWithTransport(transport *http2.Transport, cfg *Config) (net.Conn, error) {
   198  	serviceName := "GunService"
   199  	if cfg.ServiceName != "" {
   200  		serviceName = cfg.ServiceName
   201  	}
   202  
   203  	reader, writer := io.Pipe()
   204  	request := &http.Request{
   205  		Method: http.MethodPost,
   206  		Body:   reader,
   207  		URL: &url.URL{
   208  			Scheme: "https",
   209  			Host:   cfg.Host,
   210  			Path:   fmt.Sprintf("/%s/Tun", serviceName),
   211  			// for unescape path
   212  			Opaque: fmt.Sprintf("//%s/%s/Tun", cfg.Host, serviceName),
   213  		},
   214  		Proto:      "HTTP/2",
   215  		ProtoMajor: 2,
   216  		ProtoMinor: 0,
   217  		Header:     defaultHeader,
   218  	}
   219  
   220  	conn := &Conn{
   221  		request:   request,
   222  		transport: transport,
   223  		writer:    writer,
   224  		close:     atomic.NewBool(false),
   225  	}
   226  
   227  	go conn.once.Do(conn.initRequest)
   228  	return conn, nil
   229  }
   230  
   231  func StreamGunWithConn(conn net.Conn, tlsConfig *tls.Config, cfg *Config) (net.Conn, error) {
   232  	dialFn := func(network, addr string) (net.Conn, error) {
   233  		return conn, nil
   234  	}
   235  
   236  	transport := NewHTTP2Client(dialFn, tlsConfig)
   237  	return StreamGunWithTransport(transport, cfg)
   238  }