gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/internal/transport/proxy_test.go (about)

     1  //go:build !race
     2  // +build !race
     3  
     4  /*
     5   *
     6   * Copyright 2017 gRPC authors.
     7   *
     8   * Licensed under the Apache License, Version 2.0 (the "License");
     9   * you may not use this file except in compliance with the License.
    10   * You may obtain a copy of the License at
    11   *
    12   *     http://www.apache.org/licenses/LICENSE-2.0
    13   *
    14   * Unless required by applicable law or agreed to in writing, software
    15   * distributed under the License is distributed on an "AS IS" BASIS,
    16   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    17   * See the License for the specific language governing permissions and
    18   * limitations under the License.
    19   *
    20   */
    21  
    22  package transport
    23  
import (
	"bufio"
	"context"
	"encoding/base64"
	"fmt"
	"io"
	"net"
	"net/url"
	"strings"
	"testing"
	"time"

	http "gitee.com/ks-custle/core-gm/gmhttp"
)
    37  
    38  const (
    39  	envTestAddr  = "1.2.3.4:8080"
    40  	envProxyAddr = "2.3.4.5:7687"
    41  )
    42  
    43  // overwriteAndRestore overwrite function httpProxyFromEnvironment and
    44  // returns a function to restore the default values.
    45  func overwrite(hpfe func(req *http.Request) (*url.URL, error)) func() {
    46  	backHPFE := httpProxyFromEnvironment
    47  	httpProxyFromEnvironment = hpfe
    48  	return func() {
    49  		httpProxyFromEnvironment = backHPFE
    50  	}
    51  }
    52  
    53  type proxyServer struct {
    54  	t   *testing.T
    55  	lis net.Listener
    56  	in  net.Conn
    57  	out net.Conn
    58  
    59  	requestCheck func(*http.Request) error
    60  }
    61  
    62  func (p *proxyServer) run() {
    63  	in, err := p.lis.Accept()
    64  	if err != nil {
    65  		return
    66  	}
    67  	p.in = in
    68  
    69  	req, err := http.ReadRequest(bufio.NewReader(in))
    70  	if err != nil {
    71  		p.t.Errorf("failed to read CONNECT req: %v", err)
    72  		return
    73  	}
    74  	if err := p.requestCheck(req); err != nil {
    75  		resp := http.Response{StatusCode: http.StatusMethodNotAllowed}
    76  		resp.Write(p.in)
    77  		p.in.Close()
    78  		p.t.Errorf("get wrong CONNECT req: %+v, error: %v", req, err)
    79  		return
    80  	}
    81  
    82  	out, err := net.Dial("tcp", req.URL.Host)
    83  	if err != nil {
    84  		p.t.Errorf("failed to dial to server: %v", err)
    85  		return
    86  	}
    87  	resp := http.Response{StatusCode: http.StatusOK, Proto: "HTTP/1.0"}
    88  	resp.Write(p.in)
    89  	p.out = out
    90  	go io.Copy(p.in, p.out)
    91  	go io.Copy(p.out, p.in)
    92  }
    93  
    94  func (p *proxyServer) stop() {
    95  	p.lis.Close()
    96  	if p.in != nil {
    97  		p.in.Close()
    98  	}
    99  	if p.out != nil {
   100  		p.out.Close()
   101  	}
   102  }
   103  
   104  func testHTTPConnect(t *testing.T, proxyURLModify func(*url.URL) *url.URL, proxyReqCheck func(*http.Request) error) {
   105  	plis, err := net.Listen("tcp", "localhost:0")
   106  	if err != nil {
   107  		t.Fatalf("failed to listen: %v", err)
   108  	}
   109  	p := &proxyServer{
   110  		t:            t,
   111  		lis:          plis,
   112  		requestCheck: proxyReqCheck,
   113  	}
   114  	go p.run()
   115  	defer p.stop()
   116  
   117  	blis, err := net.Listen("tcp", "localhost:0")
   118  	if err != nil {
   119  		t.Fatalf("failed to listen: %v", err)
   120  	}
   121  
   122  	msg := []byte{4, 3, 5, 2}
   123  	recvBuf := make([]byte, len(msg))
   124  	done := make(chan error, 1)
   125  	go func() {
   126  		in, err := blis.Accept()
   127  		if err != nil {
   128  			done <- err
   129  			return
   130  		}
   131  		defer in.Close()
   132  		in.Read(recvBuf)
   133  		done <- nil
   134  	}()
   135  
   136  	// Overwrite the function in the test and restore them in defer.
   137  	hpfe := func(req *http.Request) (*url.URL, error) {
   138  		return proxyURLModify(&url.URL{Host: plis.Addr().String()}), nil
   139  	}
   140  	defer overwrite(hpfe)()
   141  
   142  	// Dial to proxy server.
   143  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   144  	defer cancel()
   145  	c, err := proxyDial(ctx, blis.Addr().String(), "test")
   146  	if err != nil {
   147  		t.Fatalf("http connect Dial failed: %v", err)
   148  	}
   149  	defer c.Close()
   150  
   151  	// Send msg on the connection.
   152  	c.Write(msg)
   153  	if err := <-done; err != nil {
   154  		t.Fatalf("failed to accept: %v", err)
   155  	}
   156  
   157  	// Check received msg.
   158  	if string(recvBuf) != string(msg) {
   159  		t.Fatalf("received msg: %v, want %v", recvBuf, msg)
   160  	}
   161  }
   162  
   163  func (s) TestHTTPConnect(t *testing.T) {
   164  	testHTTPConnect(t,
   165  		func(in *url.URL) *url.URL {
   166  			return in
   167  		},
   168  		func(req *http.Request) error {
   169  			if req.Method != http.MethodConnect {
   170  				return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
   171  			}
   172  			return nil
   173  		},
   174  	)
   175  }
   176  
   177  func (s) TestHTTPConnectBasicAuth(t *testing.T) {
   178  	const (
   179  		user     = "notAUser"
   180  		password = "notAPassword"
   181  	)
   182  	testHTTPConnect(t,
   183  		func(in *url.URL) *url.URL {
   184  			in.User = url.UserPassword(user, password)
   185  			return in
   186  		},
   187  		func(req *http.Request) error {
   188  			if req.Method != http.MethodConnect {
   189  				return fmt.Errorf("unexpected Method %q, want %q", req.Method, http.MethodConnect)
   190  			}
   191  			wantProxyAuthStr := "Basic " + base64.StdEncoding.EncodeToString([]byte(user+":"+password))
   192  			if got := req.Header.Get(proxyAuthHeaderKey); got != wantProxyAuthStr {
   193  				gotDecoded, _ := base64.StdEncoding.DecodeString(got)
   194  				wantDecoded, _ := base64.StdEncoding.DecodeString(wantProxyAuthStr)
   195  				return fmt.Errorf("unexpected auth %q (%q), want %q (%q)", got, gotDecoded, wantProxyAuthStr, wantDecoded)
   196  			}
   197  			return nil
   198  		},
   199  	)
   200  }
   201  
   202  func (s) TestMapAddressEnv(t *testing.T) {
   203  	// Overwrite the function in the test and restore them in defer.
   204  	hpfe := func(req *http.Request) (*url.URL, error) {
   205  		if req.URL.Host == envTestAddr {
   206  			return &url.URL{
   207  				Scheme: "https",
   208  				Host:   envProxyAddr,
   209  			}, nil
   210  		}
   211  		return nil, nil
   212  	}
   213  	defer overwrite(hpfe)()
   214  
   215  	// envTestAddr should be handled by ProxyFromEnvironment.
   216  	got, err := mapAddress(envTestAddr)
   217  	if err != nil {
   218  		t.Error(err)
   219  	}
   220  	if got.Host != envProxyAddr {
   221  		t.Errorf("want %v, got %v", envProxyAddr, got)
   222  	}
   223  }