github.com/google/martian/v3@v3.3.3/martiantest/transport_test.go (about) 1 package martiantest 2 3 import ( 4 "errors" 5 "net/http" 6 "testing" 7 8 "github.com/google/martian/v3/proxyutil" 9 ) 10 11 func TestTransport(t *testing.T) { 12 req, err := http.NewRequest("GET", "http://example.com", nil) 13 if err != nil { 14 t.Fatalf("http.NewRequest(): got %v, want no error", err) 15 } 16 17 tr := NewTransport() 18 19 res, err := tr.RoundTrip(req) 20 if err != nil { 21 t.Fatalf("tr.Roundtrip(): got %v, want no error", err) 22 } 23 res.Body.Close() 24 25 if got, want := res.StatusCode, 200; got != want { 26 t.Errorf("res.StatusCode: got %d, want %d", got, want) 27 } 28 29 // Respond with 301 response. 30 tr.Respond(301) 31 res, err = tr.RoundTrip(req) 32 if err != nil { 33 t.Fatalf("tr.Roundtrip(): got %v, want no error", err) 34 } 35 res.Body.Close() 36 37 if got, want := res.StatusCode, 301; got != want { 38 t.Errorf("res.StatusCode: got %d, want %d", got, want) 39 } 40 41 // Respond with error. 42 trerr := errors.New("transport error") 43 tr.RespondError(trerr) 44 45 if _, err := tr.RoundTrip(req); err != trerr { 46 t.Fatalf("tr.Roundtrip(): got %v, want %v", err, trerr) 47 } 48 49 // Copy headers from request to response. 50 req.Header.Set("First-Header", "first") 51 req.Header.Set("Second-Header", "second") 52 53 tr.CopyHeaders("First-Header", "Second-Header") 54 55 res, err = tr.RoundTrip(req) 56 if err != nil { 57 t.Fatalf("tr.Roundtrip(): got %v, want no error", err) 58 } 59 res.Body.Close() 60 61 if got, want := res.StatusCode, 200; got != want { 62 t.Errorf("res.StatusCode: got %d, want %d", got, want) 63 } 64 if got, want := res.Header.Get("First-Header"), "first"; got != want { 65 t.Errorf("res.Header.Get(%q): got %q, want %q", "First-Header", got, want) 66 } 67 if got, want := res.Header.Get("Second-Header"), "second"; got != want { 68 t.Errorf("res.Header.Get(%q): got %q, want %q", "Second-Header", got, want) 69 } 70 71 // Custom round trip function. 72 tr.Func(func(req *http.Request) (*http.Response, error) { 73 res := proxyutil.NewResponse(200, nil, req) 74 res.Header.Set("Request-Method", req.Method) 75 76 return res, nil 77 }) 78 79 res, err = tr.RoundTrip(req) 80 if err != nil { 81 t.Fatalf("tr.Roundtrip(): got %v, want no error", err) 82 } 83 res.Body.Close() 84 85 if got, want := res.StatusCode, 200; got != want { 86 t.Errorf("res.StatusCode: got %d, want %d", got, want) 87 } 88 if got, want := res.Header.Get("Request-Method"), "GET"; got != want { 89 t.Errorf("res.Header.Get(%q): got %q, want %q", "Request-Method", got, want) 90 } 91 }