gitee.com/zhaochuninhefei/gmgo@v0.0.31-0.20240209061119-069254a02979/grpc/internal/transport/proxy_test.go (about)

     1  //go:build !race
     2  // +build !race
     3  
     4  /*
     5   *
     6   * Copyright 2017 gRPC authors.
     7   *
     8   * Licensed under the Apache License, Version 2.0 (the "License");
     9   * you may not use this file except in compliance with the License.
    10   * You may obtain a copy of the License at
    11   *
    12   *     http://www.apache.org/licenses/LICENSE-2.0
    13   *
    14   * Unless required by applicable law or agreed to in writing, software
    15   * distributed under the License is distributed on an "AS IS" BASIS,
    16   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    17   * See the License for the specific language governing permissions and
    18   * limitations under the License.
    19   *
    20   */
    21  
    22  package transport
    23  
import (
	"bufio"
	"context"
	"encoding/base64"
	"fmt"
	"io"
	"net"
	"net/url"
	"strings"
	"sync"
	"testing"
	"time"

	http "gitee.com/zhaochuninhefei/gmgo/gmhttp"
)
    37  
    38  const (
    39  	envTestAddr  = "1.2.3.4:8080"
    40  	envProxyAddr = "2.3.4.5:7687"
    41  )
    42  
    43  // overwriteAndRestore overwrite function httpProxyFromEnvironment and
    44  // returns a function to restore the default values.
    45  func overwrite(hpfe func(req *http.Request) (*url.URL, error)) func() {
    46  	backHPFE := httpProxyFromEnvironment
    47  	httpProxyFromEnvironment = hpfe
    48  	return func() {
    49  		httpProxyFromEnvironment = backHPFE
    50  	}
    51  }
    52  
    53  type proxyServer struct {
    54  	t   *testing.T
    55  	lis net.Listener
    56  	in  net.Conn
    57  	out net.Conn
    58  
    59  	requestCheck func(*http.Request) error
    60  }
    61  
    62  func (p *proxyServer) run() {
    63  	in, err := p.lis.Accept()
    64  	if err != nil {
    65  		return
    66  	}
    67  	p.in = in
    68  
    69  	req, err := http.ReadRequest(bufio.NewReader(in))
    70  	if err != nil {
    71  		p.t.Errorf("failed to read CONNECT req: %v", err)
    72  		return
    73  	}
    74  	if err := p.requestCheck(req); err != nil {
    75  		resp := http.Response{StatusCode: http.StatusMethodNotAllowed}
    76  		_ = resp.Write(p.in)
    77  		_ = p.in.Close()
    78  		p.t.Errorf("get wrong CONNECT req: %+v, error: %v", req, err)
    79  		return
    80  	}
    81  
    82  	out, err := net.Dial("tcp", req.URL.Host)
    83  	if err != nil {
    84  		p.t.Errorf("failed to dial to server: %v", err)
    85  		return
    86  	}
    87  	resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"}
    88  	_ = resp.Write(p.in)
    89  	p.out = out
    90  	go func() {
    91  		_, _ = io.Copy(p.in, p.out)
    92  	}()
    93  	go func() {
    94  		_, _ = io.Copy(p.out, p.in)
    95  	}()
    96  }
    97  
    98  func (p *proxyServer) stop() {
    99  	_ = p.lis.Close()
   100  	if p.in != nil {
   101  		_ = p.in.Close()
   102  	}
   103  	if p.out != nil {
   104  		_ = p.out.Close()
   105  	}
   106  }
   107  
   108  func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error) {
   109  	plis, err := net.Listen("tcp", "localhost:0")
   110  	if err != nil {
   111  		t.Fatalf("failed to listen: %v", err)
   112  	}
   113  	p := &proxyServer{
   114  		t:            t,
   115  		lis:          plis,
   116  		requestCheck: proxyReqCheck,
   117  	}
   118  	go p.run()
   119  	defer p.stop()
   120  
   121  	blis, err := net.Listen("tcp", "localhost:0")
   122  	if err != nil {
   123  		t.Fatalf("failed to listen: %v", err)
   124  	}
   125  
   126  	msg := []byte{4, 3, 5, 2}
   127  	recvBuf := make([]byte, len(msg))
   128  	done := make(chan error, 1)
   129  	go func() {
   130  		in, err := blis.Accept()
   131  		if err != nil {
   132  			done <- err
   133  			return
   134  		}
   135  		defer func(in net.Conn) {
   136  			_ = in.Close()
   137  		}(in)
   138  		_, _ = in.Read(recvBuf)
   139  		done <- nil
   140  	}()
   141  
   142  	// Overwrite the function in the test and restore them in defer.
   143  	hpfe := func(req *http.Request) (*url.URL, error) {
   144  		return proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
   145  	}
   146  	defer overwrite(hpfe)()
   147  
   148  	// Dial to proxy server.
   149  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   150  	defer cancel()
   151  	c, err := proxyDial(ctx, blis.Addr().String(), "test")
   152  	if err != nil {
   153  		t.Fatalf("http connect Dial failed: %v", err)
   154  	}
   155  	defer func(c net.Conn) {
   156  		_ = c.Close()
   157  	}(c)
   158  
   159  	// Send msg on the connection.
   160  	_, _ = c.Write(msg)
   161  	if err := <-done; err != nil {
   162  		t.Fatalf("failed to accept: %v", err)
   163  	}
   164  
   165  	// Check received msg.
   166  	if string(recvBuf) != string(msg) {
   167  		t.Fatalf("received msg: %v, want %v", recvBuf, msg)
   168  	}
   169  }
   170  
   171  func (s) TestHTTPConnect(t *testing.T) {
   172  	testHTTPConnect(t,
   173  		func(in *url.URL) *url.URL {
   174  			return in
   175  		},
   176  		func(req *http.Request) error {
   177  			if req.Method != http.MethodConnect {
   178  				return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
   179  			}
   180  			return nil
   181  		},
   182  	)
   183  }
   184  
   185  func (s) TestHTTPConnectBasicAuth(t *testing.T) {
   186  	const (
   187  		user     = "notAUser"
   188  		password = "notAPassword"
   189  	)
   190  	testHTTPConnect(t,
   191  		func(in *url.URL) *url.URL {
   192  			in.User = url.UserPassword(user, password)
   193  			return in
   194  		},
   195  		func(req *http.Request) error {
   196  			if req.Method != http.MethodConnect {
   197  				return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
   198  			}
   199  			wantProxyAuthStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
   200  			if got := req.Header.Get(proxyAuthHeaderKey); got != wantProxyAuthStr {
   201  				gotDecoded, _ := base64.StdEncoding.DecodeString(got)
   202  				wantDecoded, _ := base64.StdEncoding.DecodeString(wantProxyAuthStr)
   203  				return fmt.Errorf("unexpected auth %q (%q), want %q (%q)", got, gotDecoded, wantProxyAuthStr, wantDecoded)
   204  			}
   205  			return nil
   206  		},
   207  	)
   208  }
   209  
   210  func (s) TestMapAddressEnv(t *testing.T) {
   211  	// Overwrite the function in the test and restore them in defer.
   212  	hpfe := func(req *http.Request) (*url.URL, error) {
   213  		if req.URL.Host == envTestAddr {
   214  			return &url.URL{
   215  				Scheme: "https",
   216  				Host:   envProxyAddr,
   217  			}, nil
   218  		}
   219  		return nil, nil
   220  	}
   221  	defer overwrite(hpfe)()
   222  
   223  	// envTestAddr should be handled by ProxyFromEnvironment.
   224  	got, err := mapAddress(envTestAddr)
   225  	if err != nil {
   226  		t.Error(err)
   227  	}
   228  	if got.Host != envProxyAddr {
   229  		t.Errorf("want %v, got %v", envProxyAddr, got)
   230  	}
   231  }