github.com/laof/lite-speed-test@v0.0.0-20230930011949-1f39b7037845/transport/gun/gun.go (about) 1 // License: MIT 2 3 package gun 4 5 import ( 6 "bufio" 7 "context" 8 "crypto/tls" 9 "encoding/binary" 10 "errors" 11 "fmt" 12 "io" 13 "net" 14 "net/http" 15 "net/url" 16 "sync" 17 "time" 18 19 "github.com/laof/lite-speed-test/common/pool" 20 21 "go.uber.org/atomic" 22 "golang.org/x/net/http2" 23 ) 24 25 var ( 26 ErrInvalidLength = errors.New("invalid length") 27 ErrSmallBuffer = errors.New("buffer too small") 28 ) 29 30 var defaultHeader = http.Header{ 31 "content-type": []string{"application/grpc"}, 32 "user-agent": []string{"grpc-go/1.36.0"}, 33 } 34 35 type DialFn = func(network, addr string) (net.Conn, error) 36 37 type Conn struct { 38 response *http.Response 39 request *http.Request 40 transport *http2.Transport 41 writer *io.PipeWriter 42 once sync.Once 43 close *atomic.Bool 44 err error 45 remain int 46 br *bufio.Reader 47 48 // deadlines 49 deadline *time.Timer 50 } 51 52 type Config struct { 53 ServiceName string 54 Host string 55 } 56 57 func (g *Conn) initRequest() { 58 response, err := g.transport.RoundTrip(g.request) 59 if err != nil { 60 g.err = err 61 g.writer.Close() 62 return 63 } 64 65 if !g.close.Load() { 66 g.response = response 67 g.br = bufio.NewReader(response.Body) 68 } else { 69 response.Body.Close() 70 } 71 } 72 73 func (g *Conn) Read(b []byte) (n int, err error) { 74 g.once.Do(g.initRequest) 75 if g.err != nil { 76 return 0, g.err 77 } 78 79 if g.remain > 0 { 80 size := g.remain 81 if len(b) < size { 82 size = len(b) 83 } 84 85 n, err = io.ReadFull(g.br, b[:size]) 86 g.remain -= n 87 return 88 } else if g.response == nil { 89 return 0, net.ErrClosed 90 } 91 92 // 0x00 grpclength(uint32) 0x0A uleb128 payload 93 _, err = g.br.Discard(6) 94 if err != nil { 95 return 0, err 96 } 97 98 protobufPayloadLen, err := binary.ReadUvarint(g.br) 99 if err != nil { 100 return 0, ErrInvalidLength 101 } 102 103 size := int(protobufPayloadLen) 104 if len(b) < size { 105 size = len(b) 106 } 107 108 n, err = io.ReadFull(g.br, b[:size]) 109 if err != nil { 110 return 111 } 112 113 remain := int(protobufPayloadLen) - n 114 if remain > 0 { 115 g.remain = remain 116 } 117 118 return n, nil 119 } 120 121 func (g *Conn) Write(b []byte) (n int, err error) { 122 protobufHeader := [binary.MaxVarintLen64 + 1]byte{0x0A} 123 varuintSize := binary.PutUvarint(protobufHeader[1:], uint64(len(b))) 124 grpcHeader := make([]byte, 5) 125 grpcPayloadLen := uint32(varuintSize + 1 + len(b)) 126 binary.BigEndian.PutUint32(grpcHeader[1:5], grpcPayloadLen) 127 128 buf := pool.GetBuffer() 129 defer pool.PutBuffer(buf) 130 buf.Write(grpcHeader) 131 buf.Write(protobufHeader[:varuintSize+1]) 132 buf.Write(b) 133 134 _, err = g.writer.Write(buf.Bytes()) 135 if err == io.ErrClosedPipe && g.err != nil { 136 err = g.err 137 } 138 139 return len(b), err 140 } 141 142 func (g *Conn) Close() error { 143 g.close.Store(true) 144 if r := g.response; r != nil { 145 r.Body.Close() 146 } 147 148 return g.writer.Close() 149 } 150 151 func (g *Conn) LocalAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4zero, Port: 0} } 152 func (g *Conn) RemoteAddr() net.Addr { return &net.TCPAddr{IP: net.IPv4zero, Port: 0} } 153 func (g *Conn) SetReadDeadline(t time.Time) error { return g.SetDeadline(t) } 154 func (g *Conn) SetWriteDeadline(t time.Time) error { return g.SetDeadline(t) } 155 156 func (g *Conn) SetDeadline(t time.Time) error { 157 d := time.Until(t) 158 if g.deadline != nil { 159 g.deadline.Reset(d) 160 return nil 161 } 162 g.deadline = time.AfterFunc(d, func() { 163 g.Close() 164 }) 165 return nil 166 } 167 168 func NewHTTP2Client(dialFn DialFn, tlsConfig *tls.Config) *http2.Transport { 169 dialFunc := func(ctx context.Context, network, addr string, cfg *tls.Config) (net.Conn, error) { 170 pconn, err := dialFn(network, addr) 171 if err != nil { 172 return nil, err 173 } 174 175 cn := tls.Client(pconn, cfg) 176 if err := cn.HandshakeContext(ctx); err != nil { 177 pconn.Close() 178 return nil, err 179 } 180 state := cn.ConnectionState() 181 if p := state.NegotiatedProtocol; p != http2.NextProtoTLS { 182 cn.Close() 183 return nil, fmt.Errorf("http2: unexpected ALPN protocol %s, want %s", p, http2.NextProtoTLS) 184 } 185 return cn, nil 186 } 187 188 return &http2.Transport{ 189 DialTLSContext: dialFunc, 190 TLSClientConfig: tlsConfig, 191 AllowHTTP: false, 192 DisableCompression: true, 193 PingTimeout: 0, 194 } 195 } 196 197 func StreamGunWithTransport(transport *http2.Transport, cfg *Config) (net.Conn, error) { 198 serviceName := "GunService" 199 if cfg.ServiceName != "" { 200 serviceName = cfg.ServiceName 201 } 202 203 reader, writer := io.Pipe() 204 request := &http.Request{ 205 Method: http.MethodPost, 206 Body: reader, 207 URL: &url.URL{ 208 Scheme: "https", 209 Host: cfg.Host, 210 Path: fmt.Sprintf("/%s/Tun", serviceName), 211 // for unescape path 212 Opaque: fmt.Sprintf("//%s/%s/Tun", cfg.Host, serviceName), 213 }, 214 Proto: "HTTP/2", 215 ProtoMajor: 2, 216 ProtoMinor: 0, 217 Header: defaultHeader, 218 } 219 220 conn := &Conn{ 221 request: request, 222 transport: transport, 223 writer: writer, 224 close: atomic.NewBool(false), 225 } 226 227 go conn.once.Do(conn.initRequest) 228 return conn, nil 229 } 230 231 func StreamGunWithConn(conn net.Conn, tlsConfig *tls.Config, cfg *Config) (net.Conn, error) { 232 dialFn := func(network, addr string) (net.Conn, error) { 233 return conn, nil 234 } 235 236 transport := NewHTTP2Client(dialFn, tlsConfig) 237 return StreamGunWithTransport(transport, cfg) 238 }