github.com/metacubex/mihomo@v1.18.5/transport/gun/gun.go (about)

     1  // Modified from: https://github.com/Qv2ray/gun-lite
     2  // License: MIT
     3  
     4  package gun
     5  
     6  import (
     7  	"bufio"
     8  	"context"
     9  	"crypto/tls"
    10  	"encoding/binary"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"net"
    15  	"net/http"
    16  	"net/url"
    17  	"sync"
    18  	"time"
    19  
    20  	"github.com/metacubex/mihomo/common/atomic"
    21  	"github.com/metacubex/mihomo/common/buf"
    22  	"github.com/metacubex/mihomo/common/pool"
    23  	tlsC "github.com/metacubex/mihomo/component/tls"
    24  
    25  	"golang.org/x/net/http2"
    26  )
    27  
    28  var (
    29  	ErrInvalidLength = errors.New("invalid length")
    30  	ErrSmallBuffer   = errors.New("buffer too small")
    31  )
    32  
    33  var defaultHeader = http.Header{
    34  	"content-type": []string{"application/grpc"},
    35  	"user-agent":   []string{"grpc-go/1.36.0"},
    36  }
    37  
    38  type DialFn = func(network, addr string) (net.Conn, error)
    39  
    40  type Conn struct {
    41  	response  *http.Response
    42  	request   *http.Request
    43  	transport *TransportWrap
    44  	writer    *io.PipeWriter
    45  	once      sync.Once
    46  	close     atomic.Bool
    47  	err       error
    48  	remain    int
    49  	br        *bufio.Reader
    50  	// deadlines
    51  	deadline *time.Timer
    52  }
    53  
    54  type Config struct {
    55  	ServiceName       string
    56  	Host              string
    57  	ClientFingerprint string
    58  }
    59  
    60  func (g *Conn) initRequest() {
    61  	response, err := g.transport.RoundTrip(g.request)
    62  	if err != nil {
    63  		g.err = err
    64  		g.writer.Close()
    65  		return
    66  	}
    67  
    68  	if !g.close.Load() {
    69  		g.response = response
    70  		g.br = bufio.NewReader(response.Body)
    71  	} else {
    72  		response.Body.Close()
    73  	}
    74  }
    75  
    76  func (g *Conn) Read(b []byte) (n int, err error) {
    77  	g.once.Do(g.initRequest)
    78  	if g.err != nil {
    79  		return 0, g.err
    80  	}
    81  
    82  	if g.remain > 0 {
    83  		size := g.remain
    84  		if len(b) < size {
    85  			size = len(b)
    86  		}
    87  
    88  		n, err = io.ReadFull(g.br, b[:size])
    89  		g.remain -= n
    90  		return
    91  	} else if g.response == nil {
    92  		return 0, net.ErrClosed
    93  	}
    94  
    95  	// 0x00 grpclength(uint32) 0x0A uleb128 payload
    96  	_, err = g.br.Discard(6)
    97  	if err != nil {
    98  		return 0, err
    99  	}
   100  
   101  	protobufPayloadLen, err := binary.ReadUvarint(g.br)
   102  	if err != nil {
   103  		return 0, ErrInvalidLength
   104  	}
   105  
   106  	size := int(protobufPayloadLen)
   107  	if len(b) < size {
   108  		size = len(b)
   109  	}
   110  
   111  	n, err = io.ReadFull(g.br, b[:size])
   112  	if err != nil {
   113  		return
   114  	}
   115  
   116  	remain := int(protobufPayloadLen) - n
   117  	if remain > 0 {
   118  		g.remain = remain
   119  	}
   120  
   121  	return n, nil
   122  }
   123  
   124  func (g *Conn) Write(b []byte) (n int, err error) {
   125  	protobufHeader := [binary.MaxVarintLen64 + 1]byte{0x0A}
   126  	varuintSize := binary.PutUvarint(protobufHeader[1:], uint64(len(b)))
   127  	var grpcHeader [5]byte
   128  	grpcPayloadLen := uint32(varuintSize + 1 + len(b))
   129  	binary.BigEndian.PutUint32(grpcHeader[1:5], grpcPayloadLen)
   130  
   131  	buf := pool.GetBuffer()
   132  	defer pool.PutBuffer(buf)
   133  	buf.Write(grpcHeader[:])
   134  	buf.Write(protobufHeader[:varuintSize+1])
   135  	buf.Write(b)
   136  
   137  	_, err = g.writer.Write(buf.Bytes())
   138  	if err == io.ErrClosedPipe && g.err != nil {
   139  		err = g.err
   140  	}
   141  
   142  	return len(b), err
   143  }
   144  
   145  func (g *Conn) WriteBuffer(buffer *buf.Buffer) error {
   146  	defer buffer.Release()
   147  	dataLen := buffer.Len()
   148  	varLen := UVarintLen(uint64(dataLen))
   149  	header := buffer.ExtendHeader(6 + varLen)
   150  	_ = header[6] // bounds check hint to compiler
   151  	header[0] = 0x00
   152  	binary.BigEndian.PutUint32(header[1:5], uint32(1+varLen+dataLen))
   153  	header[5] = 0x0A
   154  	binary.PutUvarint(header[6:], uint64(dataLen))
   155  	_, err := g.writer.Write(buffer.Bytes())
   156  
   157  	if err == io.ErrClosedPipe && g.err != nil {
   158  		err = g.err
   159  	}
   160  
   161  	return err
   162  }
   163  
   164  func (g *Conn) FrontHeadroom() int {
   165  	return 6 + binary.MaxVarintLen64
   166  }
   167  
   168  func (g *Conn) Close() error {
   169  	g.close.Store(true)
   170  	if r := g.response; r != nil {
   171  		r.Body.Close()
   172  	}
   173  
   174  	return g.writer.Close()
   175  }
   176  
   177  func (g *Conn) LocalAddr() net.Addr                { return g.transport.LocalAddr() }
   178  func (g *Conn) RemoteAddr() net.Addr               { return g.transport.RemoteAddr() }
   179  func (g *Conn) SetReadDeadline(t time.Time) error  { return g.SetDeadline(t) }
   180  func (g *Conn) SetWriteDeadline(t time.Time) error { return g.SetDeadline(t) }
   181  
   182  func (g *Conn) SetDeadline(t time.Time) error {
   183  	d := time.Until(t)
   184  	if g.deadline != nil {
   185  		g.deadline.Reset(d)
   186  		return nil
   187  	}
   188  	g.deadline = time.AfterFunc(d, func() {
   189  		g.Close()
   190  	})
   191  	return nil
   192  }
   193  
   194  func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config, Fingerprint string, realityConfig *tlsC.RealityConfig) *TransportWrap {
   195  	wrap := TransportWrap{}
   196  
   197  	dialFunc := func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
   198  		pconn, err := dialFn(network, addr)
   199  		if err != nil {
   200  			return nil, err
   201  		}
   202  		wrap.remoteAddr = pconn.RemoteAddr()
   203  
   204  		if tlsConfig == nil {
   205  			return pconn, nil
   206  		}
   207  
   208  		if len(Fingerprint) != 0 {
   209  			if realityConfig == nil {
   210  				if fingerprint, exists := tlsC.GetFingerprint(Fingerprint); exists {
   211  					utlsConn := tlsC.UClient(pconn, cfg, fingerprint)
   212  					if err := utlsConn.HandshakeContext(ctx); err != nil {
   213  						pconn.Close()
   214  						return nil, err
   215  					}
   216  					state := utlsConn.ConnectionState()
   217  					if p := state.NegotiatedProtocol; p != http2.NextProtoTLS {
   218  						utlsConn.Close()
   219  						return nil, fmt.Errorf("http2: unexpected ALPN protocol %s, want %s", p, http2.NextProtoTLS)
   220  					}
   221  					return utlsConn, nil
   222  				}
   223  			} else {
   224  				realityConn, err := tlsC.GetRealityConn(ctx, pconn, Fingerprint, cfg, realityConfig)
   225  				if err != nil {
   226  					pconn.Close()
   227  					return nil, err
   228  				}
   229  				//state := realityConn.(*utls.UConn).ConnectionState()
   230  				//if p := state.NegotiatedProtocol; p != http2.NextProtoTLS {
   231  				//	realityConn.Close()
   232  				//	return nil, fmt.Errorf("http2: unexpected ALPN protocol %s, want %s", p, http2.NextProtoTLS)
   233  				//}
   234  				return realityConn, nil
   235  			}
   236  		}
   237  		if realityConfig != nil {
   238  			return nil, errors.New("REALITY is based on uTLS, please set a client-fingerprint")
   239  		}
   240  
   241  		conn := tls.Client(pconn, cfg)
   242  		if err := conn.HandshakeContext(ctx); err != nil {
   243  			pconn.Close()
   244  			return nil, err
   245  		}
   246  		state := conn.ConnectionState()
   247  		if p := state.NegotiatedProtocol; p != http2.NextProtoTLS {
   248  			conn.Close()
   249  			return nil, fmt.Errorf("http2: unexpected ALPN protocol %s, want %s", p, http2.NextProtoTLS)
   250  		}
   251  		return conn, nil
   252  	}
   253  
   254  	wrap.Transport = &http2.Transport{
   255  		DialTLSContext:     dialFunc,
   256  		TLSClientConfig:    tlsConfig,
   257  		AllowHTTP:          false,
   258  		DisableCompression: true,
   259  		PingTimeout:        0,
   260  	}
   261  
   262  	return &wrap
   263  }
   264  
   265  func StreamGunWithTransport(transport *TransportWrap, cfg *Config) (net.Conn, error) {
   266  	serviceName := "GunService"
   267  	if cfg.ServiceName != "" {
   268  		serviceName = cfg.ServiceName
   269  	}
   270  
   271  	reader, writer := io.Pipe()
   272  	request := &http.Request{
   273  		Method: http.MethodPost,
   274  		Body:   reader,
   275  		URL: &url.URL{
   276  			Scheme: "https",
   277  			Host:   cfg.Host,
   278  			Path:   fmt.Sprintf("/%s/Tun", serviceName),
   279  			// for unescape path
   280  			Opaque: fmt.Sprintf("//%s/%s/Tun", cfg.Host, serviceName),
   281  		},
   282  		Proto:      "HTTP/2",
   283  		ProtoMajor: 2,
   284  		ProtoMinor: 0,
   285  		Header:     defaultHeader,
   286  	}
   287  
   288  	conn := &Conn{
   289  		request:   request,
   290  		transport: transport,
   291  		writer:    writer,
   292  		close:     atomic.NewBool(false),
   293  	}
   294  
   295  	go conn.once.Do(conn.initRequest)
   296  	return conn, nil
   297  }
   298  
   299  func StreamGunWithConn(conn net.Conn, tlsConfig *tls.Config, cfg *Config, realityConfig *tlsC.RealityConfig) (net.Conn, error) {
   300  	dialFn := func(network, addr string) (net.Conn, error) {
   301  		return conn, nil
   302  	}
   303  
   304  	transport := NewHTTP2Client(dialFn, tlsConfig, cfg.ClientFingerprint, realityConfig)
   305  	return StreamGunWithTransport(transport, cfg)
   306  }