github.com/google/martian/v3@v3.3.3/body/body_modifier_test.go (about)

     1  // Copyright 2015 Google Inc. All rights reserved.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package body
    16  
    17  import (
    18  	"bytes"
    19  	"encoding/base64"
    20  	"fmt"
    21  	"io"
    22  	"io/ioutil"
    23  	"mime/multipart"
    24  	"net/http"
    25  	"strings"
    26  	"testing"
    27  
    28  	"github.com/google/martian/v3/messageview"
    29  	"github.com/google/martian/v3/parse"
    30  	"github.com/google/martian/v3/proxyutil"
    31  )
    32  
    33  func TestBodyModifier(t *testing.T) {
    34  	mod := NewModifier([]byte("text"), "text/plain")
    35  
    36  	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
    37  	if err != nil {
    38  		t.Fatalf("NewRequest(): got %v, want no error", err)
    39  	}
    40  	req.Header.Set("Content-Encoding", "gzip")
    41  
    42  	if err := mod.ModifyRequest(req); err != nil {
    43  		t.Fatalf("ModifyRequest(): got %v, want no error", err)
    44  	}
    45  
    46  	if got, want := req.Header.Get("Content-Type"), "text/plain"; got != want {
    47  		t.Errorf("req.Header.Get(%q): got %v, want %v", "Content-Type", got, want)
    48  	}
    49  	if got, want := req.ContentLength, int64(len([]byte("text"))); got != want {
    50  		t.Errorf("req.ContentLength: got %d, want %d", got, want)
    51  	}
    52  	if got, want := req.Header.Get("Content-Encoding"), ""; got != want {
    53  		t.Errorf("req.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
    54  	}
    55  
    56  	got, err := ioutil.ReadAll(req.Body)
    57  	if err != nil {
    58  		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
    59  	}
    60  	req.Body.Close()
    61  
    62  	if want := []byte("text"); !bytes.Equal(got, want) {
    63  		t.Errorf("res.Body: got %q, want %q", got, want)
    64  	}
    65  
    66  	res := proxyutil.NewResponse(200, nil, req)
    67  	res.Header.Set("Content-Encoding", "gzip")
    68  
    69  	if err := mod.ModifyResponse(res); err != nil {
    70  		t.Fatalf("ModifyResponse(): got %v, want no error", err)
    71  	}
    72  
    73  	if got, want := res.Header.Get("Content-Type"), "text/plain"; got != want {
    74  		t.Errorf("res.Header.Get(%q): got %v, want %v", "Content-Type", got, want)
    75  	}
    76  	if got, want := res.ContentLength, int64(len([]byte("text"))); got != want {
    77  		t.Errorf("res.ContentLength: got %d, want %d", got, want)
    78  	}
    79  	if got, want := res.Header.Get("Content-Encoding"), ""; got != want {
    80  		t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
    81  	}
    82  
    83  	got, err = ioutil.ReadAll(res.Body)
    84  	if err != nil {
    85  		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
    86  	}
    87  	res.Body.Close()
    88  
    89  	if want := []byte("text"); !bytes.Equal(got, want) {
    90  		t.Errorf("res.Body: got %q, want %q", got, want)
    91  	}
    92  }
    93  func TestRangeHeaderRequestSingleRange(t *testing.T) {
    94  	mod := NewModifier([]byte("0123456789"), "text/plain")
    95  
    96  	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
    97  	if err != nil {
    98  		t.Fatalf("NewRequest(): got %v, want no error", err)
    99  	}
   100  	req.Header.Set("Range", "bytes=1-4")
   101  
   102  	res := proxyutil.NewResponse(200, nil, req)
   103  
   104  	if err := mod.ModifyResponse(res); err != nil {
   105  		t.Fatalf("ModifyResponse(): got %v, want no error", err)
   106  	}
   107  
   108  	if got, want := res.StatusCode, http.StatusPartialContent; got != want {
   109  		t.Errorf("res.Status: got %v, want %v", got, want)
   110  	}
   111  	if got, want := res.ContentLength, int64(len([]byte("1234"))); got != want {
   112  		t.Errorf("res.ContentLength: got %d, want %d", got, want)
   113  	}
   114  	if got, want := res.Header.Get("Content-Range"), "bytes 1-4/10"; got != want {
   115  		t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
   116  	}
   117  
   118  	got, err := ioutil.ReadAll(res.Body)
   119  	if err != nil {
   120  		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
   121  	}
   122  	res.Body.Close()
   123  
   124  	if want := []byte("1234"); !bytes.Equal(got, want) {
   125  		t.Errorf("res.Body: got %q, want %q", got, want)
   126  	}
   127  }
   128  
   129  func TestRangeHeaderRequestSingleRangeHasAllTheBytes(t *testing.T) {
   130  	mod := NewModifier([]byte("0123456789"), "text/plain")
   131  
   132  	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
   133  	if err != nil {
   134  		t.Fatalf("NewRequest(): got %v, want no error", err)
   135  	}
   136  	req.Header.Set("Range", "bytes=0-")
   137  
   138  	res := proxyutil.NewResponse(200, nil, req)
   139  
   140  	if err := mod.ModifyResponse(res); err != nil {
   141  		t.Fatalf("ModifyResponse(): got %v, want no error", err)
   142  	}
   143  
   144  	if got, want := res.StatusCode, http.StatusPartialContent; got != want {
   145  		t.Errorf("res.Status: got %v, want %v", got, want)
   146  	}
   147  	if got, want := res.ContentLength, int64(len([]byte("0123456789"))); got != want {
   148  		t.Errorf("res.ContentLength: got %d, want %d", got, want)
   149  	}
   150  	if got, want := res.Header.Get("Content-Range"), "bytes 0-9/10"; got != want {
   151  		t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
   152  	}
   153  
   154  	got, err := ioutil.ReadAll(res.Body)
   155  	if err != nil {
   156  		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
   157  	}
   158  	res.Body.Close()
   159  
   160  	if want := []byte("0123456789"); !bytes.Equal(got, want) {
   161  		t.Errorf("res.Body: got %q, want %q", got, want)
   162  	}
   163  }
   164  
   165  func TestRangeNoEndingIndexSpecified(t *testing.T) {
   166  	mod := NewModifier([]byte("0123456789"), "text/plain")
   167  
   168  	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
   169  	if err != nil {
   170  		t.Fatalf("NewRequest(): got %v, want no error", err)
   171  	}
   172  	req.Header.Set("Range", "bytes=8-")
   173  
   174  	res := proxyutil.NewResponse(200, nil, req)
   175  
   176  	if err := mod.ModifyResponse(res); err != nil {
   177  		t.Fatalf("ModifyResponse(): got %v, want no error", err)
   178  	}
   179  
   180  	if got, want := res.StatusCode, http.StatusPartialContent; got != want {
   181  		t.Errorf("res.Status: got %v, want %v", got, want)
   182  	}
   183  	if got, want := res.ContentLength, int64(len([]byte("89"))); got != want {
   184  		t.Errorf("res.ContentLength: got %d, want %d", got, want)
   185  	}
   186  	if got, want := res.Header.Get("Content-Range"), "bytes 8-9/10"; got != want {
   187  		t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Encoding", got, want)
   188  	}
   189  }
   190  
   191  func TestRangeHeaderMultipartRange(t *testing.T) {
   192  	mod := NewModifier([]byte("0123456789"), "text/plain")
   193  	bndry := "3d6b6a416f9b5"
   194  	mod.SetBoundary(bndry)
   195  
   196  	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
   197  	if err != nil {
   198  		t.Fatalf("NewRequest(): got %v, want no error", err)
   199  	}
   200  	req.Header.Set("Range", "bytes=1-4, 7-9")
   201  
   202  	res := proxyutil.NewResponse(200, nil, req)
   203  	if err := mod.ModifyResponse(res); err != nil {
   204  		t.Fatalf("ModifyResponse(): got %v, want no error", err)
   205  	}
   206  
   207  	if got, want := res.StatusCode, http.StatusPartialContent; got != want {
   208  		t.Errorf("res.Status: got %v, want %v", got, want)
   209  	}
   210  
   211  	if got, want := res.Header.Get("Content-Type"), "multipart/byteranges; boundary=3d6b6a416f9b5"; got != want {
   212  		t.Errorf("res.Header.Get(%q): got %q, want %q", "Content-Type", got, want)
   213  	}
   214  
   215  	mv := messageview.New()
   216  	if err := mv.SnapshotResponse(res); err != nil {
   217  		t.Fatalf("mv.SnapshotResponse(res): got %v, want no error", err)
   218  	}
   219  
   220  	br, err := mv.BodyReader()
   221  	if err != nil {
   222  		t.Fatalf("mv.BodyReader(): got %v, want no error", err)
   223  	}
   224  
   225  	mpr := multipart.NewReader(br, bndry)
   226  	prt1, err := mpr.NextPart()
   227  	if err != nil {
   228  		t.Fatalf("mpr.NextPart(): got %v, want no error", err)
   229  	}
   230  	defer prt1.Close()
   231  
   232  	if got, want := prt1.Header.Get("Content-Type"), "text/plain"; got != want {
   233  		t.Errorf("prt1.Header.Get(%q): got %q, want %q", "Content-Type", got, want)
   234  	}
   235  
   236  	if got, want := prt1.Header.Get("Content-Range"), "bytes 1-4/10"; got != want {
   237  		t.Errorf("prt1.Header.Get(%q): got %q, want %q", "Content-Range", got, want)
   238  	}
   239  
   240  	prt1b, err := ioutil.ReadAll(prt1)
   241  	if err != nil {
   242  		t.Errorf("ioutil.Readall(prt1): got %v, want no error", err)
   243  	}
   244  
   245  	if got, want := string(prt1b), "1234"; got != want {
   246  		t.Errorf("prt1 body: got %s, want %s", got, want)
   247  	}
   248  
   249  	prt2, err := mpr.NextPart()
   250  	if err != nil {
   251  		t.Fatalf("mpr.NextPart(): got %v, want no error", err)
   252  	}
   253  	defer prt2.Close()
   254  
   255  	if got, want := prt2.Header.Get("Content-Type"), "text/plain"; got != want {
   256  		t.Errorf("prt2.Header.Get(%q): got %q, want %q", "Content-Type", got, want)
   257  	}
   258  
   259  	if got, want := prt2.Header.Get("Content-Range"), "bytes 7-9/10"; got != want {
   260  		t.Errorf("prt2.Header.Get(%q): got %q, want %q", "Content-Range", got, want)
   261  	}
   262  
   263  	prt2b, err := ioutil.ReadAll(prt2)
   264  	if err != io.ErrUnexpectedEOF && err != nil {
   265  		t.Errorf("ioutil.Readall(prt2): got %v, want no error", err)
   266  	}
   267  
   268  	if got, want := string(prt2b), "789"; got != want {
   269  		t.Errorf("prt2 body: got %s, want %s", got, want)
   270  	}
   271  
   272  	_, err = mpr.NextPart()
   273  	if err == nil {
   274  		t.Errorf("mpr.NextPart: want io.EOF, got no error")
   275  	}
   276  	if err != io.EOF {
   277  		t.Errorf("mpr.NextPart: want io.EOF, got %v", err)
   278  	}
   279  }
   280  
   281  func TestModifierFromJSON(t *testing.T) {
   282  	data := base64.StdEncoding.EncodeToString([]byte("data"))
   283  	msg := fmt.Sprintf(`{
   284  	  "body.Modifier":{
   285  		  "scope": ["response"],
   286    	  "contentType": "text/plain",
   287  	  	"body": %q
   288      }
   289  	}`, data)
   290  
   291  	r, err := parse.FromJSON([]byte(msg))
   292  	if err != nil {
   293  		t.Fatalf("parse.FromJSON(): got %v, want no error", err)
   294  	}
   295  
   296  	resmod := r.ResponseModifier()
   297  
   298  	if resmod == nil {
   299  		t.Fatalf("resmod: got nil, want not nil")
   300  	}
   301  
   302  	req, err := http.NewRequest("GET", "/", strings.NewReader(""))
   303  	if err != nil {
   304  		t.Fatalf("NewRequest(): got %v, want no error", err)
   305  	}
   306  
   307  	res := proxyutil.NewResponse(200, nil, req)
   308  	if err := resmod.ModifyResponse(res); err != nil {
   309  		t.Fatalf("resmod.ModifyResponse(): got %v, want no error", err)
   310  	}
   311  
   312  	if got, want := res.Header.Get("Content-Type"), "text/plain"; got != want {
   313  		t.Errorf("res.Header.Get(%q): got %v, want %v", "Content-Type", got, want)
   314  	}
   315  
   316  	got, err := ioutil.ReadAll(res.Body)
   317  	if err != nil {
   318  		t.Fatalf("ioutil.ReadAll(): got %v, want no error", err)
   319  	}
   320  	res.Body.Close()
   321  
   322  	if want := []byte("data"); !bytes.Equal(got, want) {
   323  		t.Errorf("res.Body: got %q, want %q", got, want)
   324  	}
   325  }