github.com/cloudwego/hertz@v0.9.3/pkg/common/ut/request_test.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package ut
    18  
    19  import (
    20  	"bytes"
    21  	"context"
    22  	"fmt"
    23  	"testing"
    24  
    25  	"github.com/cloudwego/hertz/pkg/app"
    26  	"github.com/cloudwego/hertz/pkg/common/config"
    27  	"github.com/cloudwego/hertz/pkg/common/test/assert"
    28  	"github.com/cloudwego/hertz/pkg/protocol/consts"
    29  	"github.com/cloudwego/hertz/pkg/route"
    30  )
    31  
    32  func newTestEngine() *route.Engine {
    33  	opt := config.NewOptions([]config.Option{})
    34  	return route.NewEngine(opt)
    35  }
    36  
    37  func TestPerformRequest(t *testing.T) {
    38  	router := newTestEngine()
    39  	router.PUT("/hey/:user", func(ctx context.Context, c *app.RequestContext) {
    40  		user := c.Param("user")
    41  		if string(c.Request.Body()) == "1" {
    42  			assert.DeepEqual(t, "close", c.Request.Header.Get("Connection"))
    43  			c.Response.SetConnectionClose()
    44  			c.JSON(consts.StatusCreated, map[string]string{"hi": user})
    45  		} else if string(c.Request.Body()) == "" {
    46  			c.AbortWithMsg("unauthorized", consts.StatusUnauthorized)
    47  		} else {
    48  			assert.DeepEqual(t, "PUT /hey/dy HTTP/1.1\r\nContent-Type: application/x-www-form-urlencoded\r\nTransfer-Encoding: chunked\r\n\r\n", string(c.Request.Header.Header()))
    49  			c.String(consts.StatusAccepted, "body:%v", string(c.Request.Body()))
    50  		}
    51  	})
    52  	router.GET("/her/header", func(ctx context.Context, c *app.RequestContext) {
    53  		assert.DeepEqual(t, "application/json", string(c.GetHeader("Content-Type")))
    54  		assert.DeepEqual(t, 1, c.Request.Header.ContentLength())
    55  		assert.DeepEqual(t, "a", c.Request.Header.Get("dummy"))
    56  	})
    57  
    58  	// valid user
    59  	w := PerformRequest(router, "PUT", "/hey/dy", &Body{bytes.NewBufferString("1"), 1},
    60  		Header{"Connection", "close"})
    61  	resp := w.Result()
    62  	assert.DeepEqual(t, consts.StatusCreated, resp.StatusCode())
    63  	assert.DeepEqual(t, "{\"hi\":\"dy\"}", string(resp.Body()))
    64  	assert.DeepEqual(t, "application/json; charset=utf-8", string(resp.Header.ContentType()))
    65  	assert.DeepEqual(t, true, resp.Header.ConnectionClose())
    66  
    67  	// unauthorized user
    68  	w = PerformRequest(router, "PUT", "/hey/dy", nil)
    69  	_ = w.Result()
    70  	resp = w.Result()
    71  	assert.DeepEqual(t, consts.StatusUnauthorized, resp.StatusCode())
    72  	assert.DeepEqual(t, "unauthorized", string(resp.Body()))
    73  	assert.DeepEqual(t, "text/plain; charset=utf-8", string(resp.Header.ContentType()))
    74  	assert.DeepEqual(t, 12, resp.Header.ContentLength())
    75  
    76  	// special header
    77  	PerformRequest(router, "GET", "/hey/header", nil,
    78  		Header{"content-type", "application/json"},
    79  		Header{"content-length", "1"},
    80  		Header{"dummy", "a"},
    81  		Header{"dummy", "b"},
    82  	)
    83  
    84  	// not found
    85  	w = PerformRequest(router, "GET", "/hey", nil)
    86  	resp = w.Result()
    87  	assert.DeepEqual(t, consts.StatusNotFound, resp.StatusCode())
    88  
    89  	// fake body
    90  	w = PerformRequest(router, "GET", "/hey", nil)
    91  	_, err := w.WriteString(", faker")
    92  	resp = w.Result()
    93  	assert.Nil(t, err)
    94  	assert.DeepEqual(t, consts.StatusNotFound, resp.StatusCode())
    95  	assert.DeepEqual(t, "404 page not found, faker", string(resp.Body()))
    96  
    97  	// chunked body
    98  	body := bytes.NewReader(createChunkedBody([]byte("hello world!")))
    99  	w = PerformRequest(router, "PUT", "/hey/dy", &Body{body, -1})
   100  	resp = w.Result()
   101  	assert.DeepEqual(t, consts.StatusAccepted, resp.StatusCode())
   102  	assert.DeepEqual(t, "body:1\r\nh\r\n2\r\nel\r\n3\r\nlo \r\n4\r\nworl\r\n2\r\nd!\r\n0\r\n\r\n", string(resp.Body()))
   103  }
   104  
   105  func createChunkedBody(body []byte) []byte {
   106  	var b []byte
   107  	chunkSize := 1
   108  	for len(body) > 0 {
   109  		if chunkSize > len(body) {
   110  			chunkSize = len(body)
   111  		}
   112  		b = append(b, []byte(fmt.Sprintf("%x\r\n", chunkSize))...)
   113  		b = append(b, body[:chunkSize]...)
   114  		b = append(b, []byte("\r\n")...)
   115  		body = body[chunkSize:]
   116  		chunkSize++
   117  	}
   118  	return append(b, []byte("0\r\n\r\n")...)
   119  }