github.com/gofunct/common@v0.0.0-20190131174352-fd058c7fbf22/pkg/transport/middleware/passing_header_middleware_test.go (about)

     1  package middleware
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"strings"
     7  	"testing"
     8  
     9  	"github.com/google/go-cmp/cmp"
    10  )
    11  
    12  func Test_passingHeaderMiddleware(t *testing.T) {
    13  	type Case struct {
    14  		test    string
    15  		decider PassedHeaderDeciderFunc
    16  		in      http.Header
    17  		out     http.Header
    18  	}
    19  
    20  	cases := []Case{
    21  		{
    22  			test:    "passing 1 header",
    23  			decider: func(k string) bool { return strings.HasPrefix(k, "X-Debug-") },
    24  			in: http.Header{
    25  				"X-Debug-User-Id": []string{"100"},
    26  				"X-User-Id":       []string{"100"},
    27  			},
    28  			out: http.Header{
    29  				"X-Debug-User-Id":               []string{"100"},
    30  				"Grpc-Metadata-X-Debug-User-Id": []string{"100"},
    31  				"X-User-Id":                     []string{"100"},
    32  			},
    33  		},
    34  	}
    35  
    36  	getDefaultHeader := func() http.Header {
    37  		return http.Header{
    38  			"Accept-Encoding": []string{"gzip"},
    39  			"User-Agent":      []string{"Go-http-client/1.1"},
    40  		}
    41  	}
    42  
    43  	for _, c := range cases {
    44  		t.Run(c.test, func(t *testing.T) {
    45  			var wantHeader, gotHeader http.Header
    46  			wantHeader = getDefaultHeader()
    47  			for k, v := range c.out {
    48  				wantHeader.Set(k, v[0])
    49  			}
    50  
    51  			wrap := CreatePassingHeaderMiddleware(c.decider)
    52  			s := httptest.NewServer(wrap(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    53  				gotHeader = r.Header
    54  				w.WriteHeader(200)
    55  			})))
    56  			defer s.Close()
    57  
    58  			req, _ := http.NewRequest("GET", s.URL, nil)
    59  			req.Header = c.in
    60  
    61  			(&http.Client{}).Do(req)
    62  			if diff := cmp.Diff(gotHeader, wantHeader); diff != "" {
    63  				t.Errorf("Received header differs: (-got +want)\n%s", diff)
    64  			}
    65  
    66  			(&http.Client{}).Do(req)
    67  			if diff := cmp.Diff(gotHeader, wantHeader); diff != "" {
    68  				t.Errorf("Received header differs: (-got +want)\n%s", diff)
    69  			}
    70  		})
    71  	}
    72  }