github.com/cloudwego/hertz@v0.9.3/pkg/app/client/middleware_test.go (about)

     1  /*
     2   * Copyright 2022 CloudWeGo Authors
     3   *
     4   * Licensed under the Apache License, Version 2.0 (the "License");
     5   * you may not use this file except in compliance with the License.
     6   * You may obtain a copy of the License at
     7   *
     8   *     http://www.apache.org/licenses/LICENSE-2.0
     9   *
    10   * Unless required by applicable law or agreed to in writing, software
    11   * distributed under the License is distributed on an "AS IS" BASIS,
    12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13   * See the License for the specific language governing permissions and
    14   * limitations under the License.
    15   */
    16  
    17  package client
    18  
    19  import (
    20  	"context"
    21  	"testing"
    22  
    23  	"github.com/cloudwego/hertz/pkg/protocol"
    24  )
    25  
    26  var (
    27  	biz       = "Biz"
    28  	beforeMW0 = "BeforeMiddleware0"
    29  	afterMW0  = "AfterMiddleware0"
    30  	beforeMW1 = "BeforeMiddleware1"
    31  	afterMW1  = "AfterMiddleware1"
    32  )
    33  
    34  func invoke(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) {
    35  	req.BodyBuffer().WriteString(biz)
    36  	return nil
    37  }
    38  
    39  func mockMW0(next Endpoint) Endpoint {
    40  	return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) {
    41  		req.BodyBuffer().WriteString(beforeMW0)
    42  		err = next(ctx, req, resp)
    43  		if err != nil {
    44  			return err
    45  		}
    46  		req.BodyBuffer().WriteString(afterMW0)
    47  		return nil
    48  	}
    49  }
    50  
    51  func mockMW1(next Endpoint) Endpoint {
    52  	return func(ctx context.Context, req *protocol.Request, resp *protocol.Response) (err error) {
    53  		req.BodyBuffer().WriteString(beforeMW1)
    54  		err = next(ctx, req, resp)
    55  		if err != nil {
    56  			return err
    57  		}
    58  		req.BodyBuffer().WriteString(afterMW1)
    59  		return nil
    60  	}
    61  }
    62  
    63  func TestChain(t *testing.T) {
    64  	mws := chain(mockMW0, mockMW1)
    65  	req := protocol.AcquireRequest()
    66  	mws(invoke)(context.Background(), req, nil)
    67  	final := beforeMW0 + beforeMW1 + biz + afterMW1 + afterMW0
    68  	if req.BodyBuffer().String() != final {
    69  		t.Errorf("unexpected %#v, expected %#v", req.BodyBuffer().String(), final)
    70  	}
    71  }