github.com/cloudwego/hertz@v0.9.3/pkg/common/adaptor/request_test.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package adaptor
    18  
    19  import (
    20  	"context"
    21  	"io/ioutil"
    22  	"net/http"
    23  	"net/url"
    24  	"strings"
    25  	"testing"
    26  	"time"
    27  
    28  	"github.com/cloudwego/hertz/pkg/app"
    29  	"github.com/cloudwego/hertz/pkg/app/server"
    30  	"github.com/cloudwego/hertz/pkg/common/test/assert"
    31  	"github.com/cloudwego/hertz/pkg/protocol"
    32  	"github.com/cloudwego/hertz/pkg/protocol/consts"
    33  )
    34  
    35  func TestCompatResponse_WriteHeader(t *testing.T) {
    36  	var testHeader http.Header
    37  	var testBody string
    38  	testUrl1 := "http://127.0.0.1:9000/test1"
    39  	testUrl2 := "http://127.0.0.1:9000/test2"
    40  	testStatusCode := 299
    41  	testCookieValue := "cookie"
    42  
    43  	testHeader = make(map[string][]string)
    44  	testHeader["Key1"] = []string{"value1"}
    45  	testHeader["Key2"] = []string{"value2", "value22"}
    46  	testHeader["Key3"] = []string{"value3", "value33", "value333"}
    47  	testHeader[consts.HeaderSetCookie] = []string{testCookieValue}
    48  
    49  	testBody = "test body"
    50  
    51  	h := server.New(server.WithHostPorts("127.0.0.1:9000"))
    52  	h.POST("/test1", func(c context.Context, ctx *app.RequestContext) {
    53  		req, _ := GetCompatRequest(&ctx.Request)
    54  		resp := GetCompatResponseWriter(&ctx.Response)
    55  		handlerAndCheck(t, resp, req, testHeader, testBody, testStatusCode)
    56  	})
    57  
    58  	h.POST("/test2", func(c context.Context, ctx *app.RequestContext) {
    59  		req, _ := GetCompatRequest(&ctx.Request)
    60  		resp := GetCompatResponseWriter(&ctx.Response)
    61  		handlerAndCheck(t, resp, req, testHeader, testBody)
    62  	})
    63  
    64  	go h.Spin()
    65  	time.Sleep(200 * time.Millisecond)
    66  
    67  	makeACall(t, http.MethodPost, testUrl1, testHeader, testBody, testStatusCode, []byte(testCookieValue))
    68  	makeACall(t, http.MethodPost, testUrl2, testHeader, testBody, consts.StatusOK, []byte(testCookieValue))
    69  }
    70  
    71  func makeACall(t *testing.T, method, url string, header http.Header, body string, expectStatusCode int, expectCookieValue []byte) {
    72  	client := http.Client{}
    73  	req, _ := http.NewRequest(method, url, strings.NewReader(body))
    74  	req.Header = header
    75  	resp, err := client.Do(req)
    76  	if err != nil {
    77  		t.Fatalf("make a call error: %s", err)
    78  	}
    79  
    80  	respHeader := resp.Header
    81  
    82  	for k, v := range header {
    83  		for i := 0; i < len(v); i++ {
    84  			if respHeader[k][i] != v[i] {
    85  				t.Fatalf("Header error: want %s=%s, got %s=%s", respHeader[k], respHeader[k][i], respHeader[k], v[i])
    86  			}
    87  		}
    88  	}
    89  
    90  	b, err := ioutil.ReadAll(resp.Body)
    91  	if err != nil {
    92  		t.Fatalf("Read body error: %s", err)
    93  	}
    94  	assert.DeepEqual(t, body, string(b))
    95  	assert.DeepEqual(t, expectStatusCode, resp.StatusCode)
    96  
    97  	// Parse out the cookie to verify it is correct
    98  	cookie := protocol.Cookie{}
    99  	_ = cookie.Parse(header[consts.HeaderSetCookie][0])
   100  	assert.DeepEqual(t, expectCookieValue, cookie.Value())
   101  }
   102  
   103  // handlerAndCheck is designed to handle the program and check the header
   104  //
   105  // "..." is used in the type of statusCode, which is a syntactic sugar in Go.
   106  // In this way, the statusCode can be made an optional parameter,
   107  // and there is no need to pass in some meaningless numbers to judge some special cases.
   108  func handlerAndCheck(t *testing.T, writer http.ResponseWriter, request *http.Request, wantHeader http.Header, wantBody string, statusCode ...int) {
   109  	reqHeader := request.Header
   110  	for k, v := range wantHeader {
   111  		if reqHeader[k] == nil {
   112  			t.Fatalf("Header error: want %s=%s, got %s=nil", reqHeader[k], reqHeader[k][0], reqHeader[k])
   113  		}
   114  		if reqHeader[k][0] != v[0] {
   115  			t.Fatalf("Header error: want %s=%s, got %s=%s", reqHeader[k], reqHeader[k][0], reqHeader[k], v[0])
   116  		}
   117  	}
   118  
   119  	body, err := ioutil.ReadAll(request.Body)
   120  	if err != nil {
   121  		t.Fatalf("Read body error: %s", err)
   122  	}
   123  	assert.DeepEqual(t, wantBody, string(body))
   124  
   125  	respHeader := writer.Header()
   126  	for k, v := range reqHeader {
   127  		respHeader[k] = v
   128  	}
   129  
   130  	// When the incoming status code is nil, the execution of this code is skipped
   131  	// and the status code is set to 200
   132  	if statusCode != nil {
   133  		writer.WriteHeader(statusCode[0])
   134  	}
   135  
   136  	_, err = writer.Write([]byte("test"))
   137  	if err != nil {
   138  		t.Fatalf("Write body error: %s", err)
   139  	}
   140  	_, err = writer.Write([]byte(" body"))
   141  	if err != nil {
   142  		t.Fatalf("Write body error: %s", err)
   143  	}
   144  }
   145  
   146  func TestCopyToHertzRequest(t *testing.T) {
   147  	req := http.Request{
   148  		Method:     "GET",
   149  		RequestURI: "/test",
   150  		URL: &url.URL{
   151  			Scheme: "http",
   152  			Host:   "test.com",
   153  		},
   154  		Proto:  "HTTP/1.1",
   155  		Header: http.Header{},
   156  	}
   157  	req.Header.Set("key1", "value1")
   158  	req.Header.Add("key2", "value2")
   159  	req.Header.Add("key2", "value22")
   160  	hertzReq := protocol.Request{}
   161  	err := CopyToHertzRequest(&req, &hertzReq)
   162  	assert.Nil(t, err)
   163  	assert.DeepEqual(t, req.Method, string(hertzReq.Method()))
   164  	assert.DeepEqual(t, req.RequestURI, string(hertzReq.Path()))
   165  	assert.DeepEqual(t, req.Proto, hertzReq.Header.GetProtocol())
   166  	assert.DeepEqual(t, req.Header.Get("key1"), hertzReq.Header.Get("key1"))
   167  	valueSlice := make([]string, 0, 2)
   168  	hertzReq.Header.VisitAllCustomHeader(func(key, value []byte) {
   169  		if strings.ToLower(string(key)) == "key2" {
   170  			valueSlice = append(valueSlice, string(value))
   171  		}
   172  	})
   173  
   174  	assert.DeepEqual(t, req.Header.Values("key2"), valueSlice)
   175  
   176  	assert.DeepEqual(t, 3, hertzReq.Header.Len())
   177  }