github.com/yaling888/clash@v1.53.0/transport/gun/gun.go (about)

     1  // Modified from: https://github.com/Qv2ray/gun-lite
     2  // License: MIT
     3  
     4  package gun
     5  
     6  import (
     7  	"bufio"
     8  	"context"
     9  	"crypto/tls"
    10  	"encoding/binary"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"net"
    15  	"net/http"
    16  	"net/url"
    17  	"sync"
    18  	"time"
    19  
    20  	"go.uber.org/atomic"
    21  	"golang.org/x/net/http2"
    22  
    23  	"github.com/yaling888/clash/common/pool"
    24  )
    25  
    26  var (
    27  	ErrInvalidLength = errors.New("invalid length")
    28  
    29  	defaultHeader = http.Header{
    30  		"content-type": []string{"application/grpc"},
    31  		"user-agent":   []string{"grpc-go/1.36.0"},
    32  	}
    33  )
    34  
    35  type DialFn = func(network, addr string) (net.Conn, error)
    36  
    37  type Conn struct {
    38  	response  *http.Response
    39  	request   *http.Request
    40  	transport *http2.Transport
    41  	writer    *io.PipeWriter
    42  	once      sync.Once
    43  	close     *atomic.Bool
    44  	err       error
    45  	remain    int
    46  	br        *bufio.Reader
    47  
    48  	// deadlines
    49  	deadline *time.Timer
    50  }
    51  
    52  type Config struct {
    53  	ServiceName string
    54  	Host        string
    55  }
    56  
    57  func (g *Conn) initRequest() {
    58  	response, err := g.transport.RoundTrip(g.request)
    59  	if err != nil {
    60  		g.err = err
    61  		_ = g.writer.Close()
    62  		return
    63  	}
    64  
    65  	if !g.close.Load() {
    66  		g.response = response
    67  		g.br = bufio.NewReader(response.Body)
    68  	} else {
    69  		_ = response.Body.Close()
    70  	}
    71  }
    72  
    73  func (g *Conn) Read(b []byte) (n int, err error) {
    74  	g.once.Do(g.initRequest)
    75  	if g.err != nil {
    76  		return 0, g.err
    77  	}
    78  
    79  	if g.remain > 0 {
    80  		size := g.remain
    81  		if len(b) < size {
    82  			size = len(b)
    83  		}
    84  
    85  		n, err = io.ReadFull(g.br, b[:size])
    86  		g.remain -= n
    87  		return
    88  	} else if g.response == nil {
    89  		return 0, net.ErrClosed
    90  	}
    91  
    92  	// 0x00 grpclength(uint32) 0x0A uleb128 payload
    93  	_, err = g.br.Discard(6)
    94  	if err != nil {
    95  		return 0, err
    96  	}
    97  
    98  	protobufPayloadLen, err := binary.ReadUvarint(g.br)
    99  	if err != nil {
   100  		return 0, ErrInvalidLength
   101  	}
   102  
   103  	size := int(protobufPayloadLen)
   104  	if len(b) < size {
   105  		size = len(b)
   106  	}
   107  
   108  	n, err = io.ReadFull(g.br, b[:size])
   109  	if err != nil {
   110  		return
   111  	}
   112  
   113  	remain := int(protobufPayloadLen) - n
   114  	if remain > 0 {
   115  		g.remain = remain
   116  	}
   117  
   118  	return n, nil
   119  }
   120  
   121  const bufSize = pool.NetBufferSize + 16
   122  
   123  var bufPool = sync.Pool{
   124  	New: func() any {
   125  		b := make([]byte, bufSize)
   126  		return &b
   127  	},
   128  }
   129  
   130  func (g *Conn) Write(b []byte) (n int, err error) {
   131  	n = len(b)
   132  
   133  	bufP := bufPool.Get().(*[]byte)
   134  	defer bufPool.Put(bufP)
   135  
   136  	varuintSize := binary.PutUvarint((*bufP)[6:], uint64(n))
   137  	grpcPayloadLen := uint32(varuintSize + 1 + n)
   138  
   139  	(*bufP)[0] = byte(0)
   140  	(*bufP)[5] = byte(0x0A)
   141  	binary.BigEndian.PutUint32((*bufP)[1:], grpcPayloadLen)
   142  
   143  	t := 6 + varuintSize
   144  	t1 := copy((*bufP)[t:], b)
   145  
   146  	_, err = g.writer.Write((*bufP)[:t+t1])
   147  	if err == io.ErrClosedPipe && g.err != nil {
   148  		err = g.err
   149  	}
   150  	if n > t1 {
   151  		n = t1
   152  		if err == nil {
   153  			err = io.ErrShortWrite
   154  		}
   155  	}
   156  	return
   157  }
   158  
   159  func (g *Conn) Close() error {
   160  	g.close.Store(true)
   161  	if r := g.response; r != nil {
   162  		_ = r.Body.Close()
   163  	}
   164  
   165  	return g.writer.Close()
   166  }
   167  
   168  func (g *Conn) LocalAddr() net.Addr                { return &net.TCPAddr{IP: net.IPv4zero, Port: 0} }
   169  func (g *Conn) RemoteAddr() net.Addr               { return &net.TCPAddr{IP: net.IPv4zero, Port: 0} }
   170  func (g *Conn) SetReadDeadline(t time.Time) error  { return g.SetDeadline(t) }
   171  func (g *Conn) SetWriteDeadline(t time.Time) error { return g.SetDeadline(t) }
   172  
   173  func (g *Conn) SetDeadline(t time.Time) error {
   174  	d := time.Until(t)
   175  	if g.deadline != nil {
   176  		g.deadline.Reset(d)
   177  		return nil
   178  	}
   179  	g.deadline = time.AfterFunc(d, func() {
   180  		_ = g.Close()
   181  	})
   182  	return nil
   183  }
   184  
   185  func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config) *http2.Transport {
   186  	dialFunc := func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) {
   187  		pconn, err := dialFn(network, addr)
   188  		if err != nil {
   189  			return nil, err
   190  		}
   191  
   192  		cn := tls.Client(pconn, cfg)
   193  		if err = cn.HandshakeContext(ctx); err != nil {
   194  			_ = pconn.Close()
   195  			return nil, err
   196  		}
   197  		state := cn.ConnectionState()
   198  		if p := state.NegotiatedProtocol; p != http2.NextProtoTLS {
   199  			_ = cn.Close()
   200  			return nil, fmt.Errorf("http2: unexpected ALPN protocol %s, want %s", p, http2.NextProtoTLS)
   201  		}
   202  		return cn, nil
   203  	}
   204  
   205  	return &http2.Transport{
   206  		DialTLSContext:     dialFunc,
   207  		TLSClientConfig:    tlsConfig,
   208  		AllowHTTP:          false,
   209  		DisableCompression: true,
   210  		PingTimeout:        0,
   211  	}
   212  }
   213  
   214  func StreamGunWithTransport(transport *http2.Transport, cfg *Config) (net.Conn, error) {
   215  	serviceName := "GunService"
   216  	if cfg.ServiceName != "" {
   217  		serviceName = cfg.ServiceName
   218  	}
   219  
   220  	reader, writer := io.Pipe()
   221  	request := &http.Request{
   222  		Method: http.MethodPost,
   223  		Body:   reader,
   224  		URL: &url.URL{
   225  			Scheme: "https",
   226  			Host:   cfg.Host,
   227  			Path:   fmt.Sprintf("/%s/Tun", serviceName),
   228  			// for unescape path
   229  			Opaque: fmt.Sprintf("//%s/%s/Tun", cfg.Host, serviceName),
   230  		},
   231  		Proto:      "HTTP/2",
   232  		ProtoMajor: 2,
   233  		ProtoMinor: 0,
   234  		Header:     defaultHeader,
   235  	}
   236  
   237  	conn := &Conn{
   238  		request:   request,
   239  		transport: transport,
   240  		writer:    writer,
   241  		close:     atomic.NewBool(false),
   242  	}
   243  
   244  	go conn.once.Do(conn.initRequest)
   245  	return conn, nil
   246  }
   247  
   248  func StreamGunWithConn(conn net.Conn, tlsConfig *tls.Config, cfg *Config) (net.Conn, error) {
   249  	dialFn := func(network, addr string) (net.Conn, error) {
   250  		return conn, nil
   251  	}
   252  
   253  	transport := NewHTTP2Client(dialFn, tlsConfig)
   254  	return StreamGunWithTransport(transport, cfg)
   255  }