github.com/cloudwego/hertz@v0.9.3/pkg/protocol/http1/proxy/proxy.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   *
    16   * Copyright 2016 The Go Authors. All rights reserved.
    17   * Use of this source code is governed by a BSD-style
    18   * license that can be found in the LICENSE file.
    19   *
    20   * This file may have been modified by CloudWeGo authors. All CloudWeGo
    21   * Modifications are Copyright 2022 CloudWeGo Authors.
    22   */
    23  
    24  package proxy
    25  
    26  import (
    27  	"bytes"
    28  	"context"
    29  	"crypto/tls"
    30  	"encoding/base64"
    31  	"time"
    32  
    33  	"github.com/cloudwego/hertz/internal/bytesconv"
    34  	"github.com/cloudwego/hertz/internal/bytestr"
    35  	"github.com/cloudwego/hertz/pkg/common/errors"
    36  	"github.com/cloudwego/hertz/pkg/network"
    37  	"github.com/cloudwego/hertz/pkg/protocol"
    38  	"github.com/cloudwego/hertz/pkg/protocol/consts"
    39  	reqI "github.com/cloudwego/hertz/pkg/protocol/http1/req"
    40  	respI "github.com/cloudwego/hertz/pkg/protocol/http1/resp"
    41  )
    42  
    43  func SetupProxy(conn network.Conn, addr string, proxyURI *protocol.URI, tlsConfig *tls.Config, isTLS bool, dialer network.Dialer) (network.Conn, error) {
    44  	var err error
    45  	if bytes.Equal(proxyURI.Scheme(), bytestr.StrHTTPS) {
    46  		conn, err = dialer.AddTLS(conn, tlsConfig)
    47  		if err != nil {
    48  			return nil, err
    49  		}
    50  	}
    51  
    52  	switch {
    53  	case proxyURI == nil:
    54  		// Do nothing. Not using a proxy.
    55  	case isTLS: // target addr is https
    56  		connectReq, connectResp := protocol.AcquireRequest(), protocol.AcquireResponse()
    57  		defer func() {
    58  			protocol.ReleaseRequest(connectReq)
    59  			protocol.ReleaseResponse(connectResp)
    60  		}()
    61  
    62  		SetProxyAuthHeader(&connectReq.Header, proxyURI)
    63  		connectReq.SetMethod(consts.MethodConnect)
    64  		connectReq.SetHost(addr)
    65  
    66  		// Skip response body when send CONNECT request.
    67  		connectResp.SkipBody = true
    68  
    69  		// If there's no done channel (no deadline or cancellation
    70  		// from the caller possible), at least set some (long)
    71  		// timeout here. This will make sure we don't block forever
    72  		// and leak a goroutine if the connection stops replying
    73  		// after the TCP connect.
    74  		connectCtx, cancel := context.WithTimeout(context.Background(), 1*time.Minute)
    75  		defer cancel()
    76  
    77  		didReadResponse := make(chan struct{}) // closed after CONNECT write+read is done or fails
    78  
    79  		// Write the CONNECT request & read the response.
    80  		go func() {
    81  			defer close(didReadResponse)
    82  
    83  			err = reqI.Write(connectReq, conn)
    84  			if err != nil {
    85  				return
    86  			}
    87  
    88  			err = conn.Flush()
    89  			if err != nil {
    90  				return
    91  			}
    92  
    93  			err = respI.Read(connectResp, conn)
    94  		}()
    95  		select {
    96  		case <-connectCtx.Done():
    97  			conn.Close()
    98  			<-didReadResponse
    99  
   100  			return nil, connectCtx.Err()
   101  		case <-didReadResponse:
   102  		}
   103  
   104  		if err != nil {
   105  			conn.Close()
   106  			return nil, err
   107  		}
   108  
   109  		if connectResp.StatusCode() != consts.StatusOK {
   110  			conn.Close()
   111  
   112  			return nil, errors.NewPublic(consts.StatusMessage(connectResp.StatusCode()))
   113  		}
   114  	}
   115  
   116  	if proxyURI != nil && isTLS {
   117  		conn, err = dialer.AddTLS(conn, tlsConfig)
   118  		if err != nil {
   119  			return nil, err
   120  		}
   121  	}
   122  
   123  	return conn, nil
   124  }
   125  
   126  func SetProxyAuthHeader(h *protocol.RequestHeader, proxyURI *protocol.URI) {
   127  	if username := proxyURI.Username(); username != nil {
   128  		password := proxyURI.Password()
   129  		auth := base64.StdEncoding.EncodeToString(bytesconv.S2b(bytesconv.B2s(username) + ":" + bytesconv.B2s(password)))
   130  		h.Set("Proxy-Authorization", "Basic "+auth)
   131  	}
   132  }