github.com/google/martian/v3@v3.3.3/proxy_trafficshaping_test.go (about)

     1  package martian
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"io/ioutil"
     7  	"net"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"strings"
    11  	"testing"
    12  	"time"
    13  
    14  	"github.com/google/martian/v3/log"
    15  	"github.com/google/martian/v3/martiantest"
    16  	"github.com/google/martian/v3/trafficshape"
    17  )
    18  
    19  // Tests that sending data of length 600 bytes with max bandwidth of 100 bytes/s takes
    20  // atleast 4.9s. Uses the Close Connection action to immediately close the connection
    21  // upon the proxy writing 600 bytes. (4.9s ~ 5s = 600b /100b/s - 1s)
    22  func TestConstantThrottleAndClose(t *testing.T) {
    23  	log.SetLevel(log.Info)
    24  
    25  	l, err := net.Listen("tcp", "[::]:0")
    26  	if err != nil {
    27  		t.Fatalf("net.Listen(): got %v, want no error", err)
    28  	}
    29  
    30  	tsl := trafficshape.NewListener(l)
    31  	tsh := trafficshape.NewHandler(tsl)
    32  
    33  	// This is the data to be sent.
    34  	testString := strings.Repeat("0", 600)
    35  
    36  	// Traffic shaping config request.
    37  	jsonString :=
    38  		`{
    39  				"trafficshape": {
    40  						"shapes": [
    41  							{
    42  								"url_regex": "http://example/example",
    43  								"throttles": [
    44  									{
    45  										"bytes": "0-",
    46  										"bandwidth": 100
    47  									}
    48  								],
    49  								"close_connections": [
    50  									{
    51  										"byte": 600,
    52  										"count": 1
    53  									}
    54  								]
    55  						}
    56  						]
    57  				}
    58  		}`
    59  
    60  	tsReq, err := http.NewRequest("POST", "test", bytes.NewBufferString(jsonString))
    61  	rw := httptest.NewRecorder()
    62  	tsh.ServeHTTP(rw, tsReq)
    63  	res := rw.Result()
    64  
    65  	if got, want := res.StatusCode, 200; got != want {
    66  		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
    67  	}
    68  
    69  	p := NewProxy()
    70  	defer p.Close()
    71  
    72  	p.SetRequestModifier(nil)
    73  	p.SetResponseModifier(nil)
    74  
    75  	tr := martiantest.NewTransport()
    76  	p.SetRoundTripper(tr)
    77  	p.SetTimeout(15 * time.Second)
    78  
    79  	tm := martiantest.NewModifier()
    80  
    81  	tm.RequestFunc(func(req *http.Request) {
    82  		ctx := NewContext(req)
    83  		ctx.SkipRoundTrip()
    84  	})
    85  
    86  	tm.ResponseFunc(func(res *http.Response) {
    87  		res.StatusCode = http.StatusOK
    88  		res.Body = ioutil.NopCloser(bytes.NewBufferString(testString))
    89  	})
    90  
    91  	p.SetRequestModifier(tm)
    92  	p.SetResponseModifier(tm)
    93  
    94  	go p.Serve(tsl)
    95  
    96  	c1 := make(chan string)
    97  	conn, err := net.Dial("tcp", l.Addr().String())
    98  	defer conn.Close()
    99  	if err != nil {
   100  		t.Fatalf("net.Dial(): got %v, want no error", err)
   101  	}
   102  
   103  	go func() {
   104  		req, err := http.NewRequest("GET", "http://example/example", nil)
   105  		if err != nil {
   106  			t.Fatalf("http.NewRequest(): got %v, want no error", err)
   107  		}
   108  
   109  		if err := req.WriteProxy(conn); err != nil {
   110  			t.Fatalf("req.WriteProxy(): got %v, want no error", err)
   111  		}
   112  
   113  		res, err := http.ReadResponse(bufio.NewReader(conn), req)
   114  		if err != nil {
   115  			t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   116  		}
   117  		body, _ := ioutil.ReadAll(res.Body)
   118  		bodystr := string(body)
   119  		c1 <- bodystr
   120  	}()
   121  
   122  	var bodystr string
   123  	select {
   124  	case bodystringc := <-c1:
   125  		t.Errorf("took < 4.9s, should take at least 4.9s")
   126  		bodystr = bodystringc
   127  	case <-time.After(4900 * time.Millisecond):
   128  		bodystringc := <-c1
   129  		bodystr = bodystringc
   130  	}
   131  
   132  	if bodystr != testString {
   133  		t.Errorf("res.Body: got %s, want %s", bodystr, testString)
   134  	}
   135  }
   136  
   137  // Tests that sleeping for 5s and then closing the connection
   138  // upon reading 200 bytes, with a bandwidth of 5000 bytes/s
   139  // takes at least 4.9s, and results in a correctly trimmed
   140  // response body. (200 0s instead of 500 0s)
   141  func TestSleepAndClose(t *testing.T) {
   142  	log.SetLevel(log.Info)
   143  
   144  	l, err := net.Listen("tcp", "[::]:0")
   145  	if err != nil {
   146  		t.Fatalf("net.Listen(): got %v, want no error", err)
   147  	}
   148  
   149  	tsl := trafficshape.NewListener(l)
   150  	tsh := trafficshape.NewHandler(tsl)
   151  
   152  	// This is the data to be sent.
   153  	testString := strings.Repeat("0", 500)
   154  
   155  	// Traffic shaping config request.
   156  	jsonString :=
   157  		`{
   158  				"trafficshape": {
   159  						"shapes": [
   160  							{
   161  								"url_regex": "http://example/example",
   162  								"throttles": [
   163  									{
   164  										"bytes": "0-",
   165  										"bandwidth": 5000
   166  									}
   167  								],
   168  								"halts": [
   169  									{
   170  										"byte": 100,
   171  										"duration": 5000,
   172  										"count": 1
   173  									}
   174  								],
   175  								"close_connections": [
   176  									{
   177  										"byte": 200,
   178  										"count": 1
   179  									}
   180  								]
   181  						}
   182  						]
   183  				}
   184  		}`
   185  
   186  	tsReq, err := http.NewRequest("POST", "test", bytes.NewBufferString(jsonString))
   187  	rw := httptest.NewRecorder()
   188  	tsh.ServeHTTP(rw, tsReq)
   189  	res := rw.Result()
   190  
   191  	if got, want := res.StatusCode, 200; got != want {
   192  		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
   193  	}
   194  
   195  	p := NewProxy()
   196  	defer p.Close()
   197  
   198  	p.SetRequestModifier(nil)
   199  	p.SetResponseModifier(nil)
   200  
   201  	tr := martiantest.NewTransport()
   202  	p.SetRoundTripper(tr)
   203  	p.SetTimeout(15 * time.Second)
   204  
   205  	tm := martiantest.NewModifier()
   206  
   207  	tm.RequestFunc(func(req *http.Request) {
   208  		ctx := NewContext(req)
   209  		ctx.SkipRoundTrip()
   210  	})
   211  
   212  	tm.ResponseFunc(func(res *http.Response) {
   213  		res.StatusCode = http.StatusOK
   214  		res.Body = ioutil.NopCloser(bytes.NewBufferString(testString))
   215  	})
   216  
   217  	p.SetRequestModifier(tm)
   218  	p.SetResponseModifier(tm)
   219  
   220  	go p.Serve(tsl)
   221  
   222  	c1 := make(chan string)
   223  	conn, err := net.Dial("tcp", l.Addr().String())
   224  	defer conn.Close()
   225  	if err != nil {
   226  		t.Fatalf("net.Dial(): got %v, want no error", err)
   227  	}
   228  
   229  	go func() {
   230  		req, err := http.NewRequest("GET", "http://example/example", nil)
   231  		if err != nil {
   232  			t.Fatalf("http.NewRequest(): got %v, want no error", err)
   233  		}
   234  
   235  		if err := req.WriteProxy(conn); err != nil {
   236  			t.Fatalf("req.WriteProxy(): got %v, want no error", err)
   237  		}
   238  
   239  		res, err := http.ReadResponse(bufio.NewReader(conn), req)
   240  		if err != nil {
   241  			t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   242  		}
   243  		body, _ := ioutil.ReadAll(res.Body)
   244  		bodystr := string(body)
   245  		c1 <- bodystr
   246  	}()
   247  
   248  	var bodystr string
   249  	select {
   250  	case bodystringc := <-c1:
   251  		t.Errorf("took < 4.9s, should take at least 4.9s")
   252  		bodystr = bodystringc
   253  	case <-time.After(4900 * time.Millisecond):
   254  		bodystringc := <-c1
   255  		bodystr = bodystringc
   256  	}
   257  
   258  	if want := strings.Repeat("0", 200); bodystr != want {
   259  		t.Errorf("res.Body: got %s, want %s", bodystr, want)
   260  	}
   261  }
   262  
   263  // Similar to TestConstantThrottleAndClose, except that it applies
   264  // the throttle only in a specific byte range, and modifies the
   265  // the response to lie in the byte range.
   266  func TestConstantThrottleAndCloseByteRange(t *testing.T) {
   267  	log.SetLevel(log.Info)
   268  
   269  	l, err := net.Listen("tcp", "[::]:0")
   270  	if err != nil {
   271  		t.Fatalf("net.Listen(): got %v, want no error", err)
   272  	}
   273  
   274  	tsl := trafficshape.NewListener(l)
   275  	tsh := trafficshape.NewHandler(tsl)
   276  
   277  	// This is the data to be sent.
   278  	testString := strings.Repeat("0", 600)
   279  
   280  	// Traffic shaping config request.
   281  	jsonString :=
   282  		`{
   283  				"trafficshape": {
   284  						"shapes": [
   285  							{
   286  								"url_regex": "http://example/example",
   287  								"throttles": [
   288  									{
   289  										"bytes": "500-",
   290  										"bandwidth": 100
   291  									}
   292  								],
   293  								"close_connections": [
   294  									{
   295  										"byte": 1100,
   296  										"count": 1
   297  									}
   298  								]
   299  						}
   300  						]
   301  				}
   302  		}`
   303  
   304  	tsReq, err := http.NewRequest("POST", "test", bytes.NewBufferString(jsonString))
   305  	rw := httptest.NewRecorder()
   306  	tsh.ServeHTTP(rw, tsReq)
   307  	res := rw.Result()
   308  
   309  	if got, want := res.StatusCode, 200; got != want {
   310  		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
   311  	}
   312  
   313  	p := NewProxy()
   314  	defer p.Close()
   315  
   316  	p.SetRequestModifier(nil)
   317  	p.SetResponseModifier(nil)
   318  
   319  	tr := martiantest.NewTransport()
   320  	p.SetRoundTripper(tr)
   321  	p.SetTimeout(15 * time.Second)
   322  
   323  	tm := martiantest.NewModifier()
   324  
   325  	tm.RequestFunc(func(req *http.Request) {
   326  		ctx := NewContext(req)
   327  		ctx.SkipRoundTrip()
   328  	})
   329  
   330  	tm.ResponseFunc(func(res *http.Response) {
   331  		res.StatusCode = http.StatusPartialContent
   332  		res.Body = ioutil.NopCloser(bytes.NewBufferString(testString))
   333  		res.Header.Set("Content-Range", "bytes 500-1100/1100")
   334  	})
   335  
   336  	p.SetRequestModifier(tm)
   337  	p.SetResponseModifier(tm)
   338  
   339  	go p.Serve(tsl)
   340  
   341  	c1 := make(chan string)
   342  	conn, err := net.Dial("tcp", l.Addr().String())
   343  	defer conn.Close()
   344  	if err != nil {
   345  		t.Fatalf("net.Dial(): got %v, want no error", err)
   346  	}
   347  
   348  	go func() {
   349  		req, err := http.NewRequest("GET", "http://example/example", nil)
   350  		if err != nil {
   351  			t.Fatalf("http.NewRequest(): got %v, want no error", err)
   352  		}
   353  
   354  		if err := req.WriteProxy(conn); err != nil {
   355  			t.Fatalf("req.WriteProxy(): got %v, want no error", err)
   356  		}
   357  
   358  		res, err := http.ReadResponse(bufio.NewReader(conn), req)
   359  		if err != nil {
   360  			t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   361  		}
   362  
   363  		body, _ := ioutil.ReadAll(res.Body)
   364  		bodystr := string(body)
   365  		c1 <- bodystr
   366  	}()
   367  
   368  	var bodystr string
   369  	select {
   370  	case bodystringc := <-c1:
   371  		t.Errorf("took < 4.9s, should take at least 4.9s")
   372  		bodystr = bodystringc
   373  	case <-time.After(4900 * time.Millisecond):
   374  		bodystringc := <-c1
   375  		bodystr = bodystringc
   376  	}
   377  
   378  	if bodystr != testString {
   379  		t.Errorf("res.Body: got %s, want %s", bodystr, testString)
   380  	}
   381  }
   382  
   383  // Opens up 5 concurrent connections, and sets the
   384  // max global bandwidth for the url regex to be 250b/s.
   385  // Every connection tries to read 500b of data, but since
   386  // the global bandwidth for the particular regex is 250,
   387  // it should take at least 5 * 500b / 250b/s -1s = 9s to read
   388  // everything.
   389  func TestMaxBandwidth(t *testing.T) {
   390  	log.SetLevel(log.Info)
   391  
   392  	l, err := net.Listen("tcp", "[::]:0")
   393  	if err != nil {
   394  		t.Fatalf("net.Listen(): got %v, want no error", err)
   395  	}
   396  
   397  	tsl := trafficshape.NewListener(l)
   398  	tsh := trafficshape.NewHandler(tsl)
   399  
   400  	// This is the data to be sent.
   401  	testString := strings.Repeat("0", 500)
   402  
   403  	// Traffic shaping config request.
   404  	jsonString :=
   405  		`{
   406  				"trafficshape": {
   407  						"shapes": [
   408  							{
   409  								"url_regex": "http://example/example",
   410  								"max_global_bandwidth": 250,
   411  								"close_connections": [
   412  									{
   413  										"byte": 500,
   414  										"count": 5
   415  									}
   416  								]
   417  						}
   418  						]
   419  				}
   420  		}`
   421  
   422  	tsReq, err := http.NewRequest("POST", "test", bytes.NewBufferString(jsonString))
   423  	rw := httptest.NewRecorder()
   424  	tsh.ServeHTTP(rw, tsReq)
   425  	res := rw.Result()
   426  
   427  	if got, want := res.StatusCode, 200; got != want {
   428  		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
   429  	}
   430  
   431  	p := NewProxy()
   432  	defer p.Close()
   433  
   434  	p.SetRequestModifier(nil)
   435  	p.SetResponseModifier(nil)
   436  
   437  	tr := martiantest.NewTransport()
   438  	p.SetRoundTripper(tr)
   439  	p.SetTimeout(20 * time.Second)
   440  
   441  	tm := martiantest.NewModifier()
   442  
   443  	tm.RequestFunc(func(req *http.Request) {
   444  		ctx := NewContext(req)
   445  		ctx.SkipRoundTrip()
   446  	})
   447  
   448  	tm.ResponseFunc(func(res *http.Response) {
   449  		res.StatusCode = http.StatusOK
   450  		res.Body = ioutil.NopCloser(bytes.NewBufferString(testString))
   451  	})
   452  
   453  	p.SetRequestModifier(tm)
   454  	p.SetResponseModifier(tm)
   455  
   456  	go p.Serve(tsl)
   457  
   458  	numChannels := 5
   459  
   460  	channels := make([]chan string, numChannels)
   461  
   462  	for i := 0; i < numChannels; i++ {
   463  		channels[i] = make(chan string)
   464  	}
   465  
   466  	for i := 0; i < numChannels; i++ {
   467  		go func(i int) {
   468  			conn, err := net.Dial("tcp", l.Addr().String())
   469  			defer conn.Close()
   470  			if err != nil {
   471  				t.Fatalf("net.Dial(): got %v, want no error", err)
   472  			}
   473  			req, err := http.NewRequest("GET", "http://example/example", nil)
   474  			if err != nil {
   475  				t.Fatalf("http.NewRequest(): got %v, want no error", err)
   476  			}
   477  
   478  			if err := req.WriteProxy(conn); err != nil {
   479  				t.Fatalf("req.WriteProxy(): got %v, want no error", err)
   480  			}
   481  
   482  			res, err := http.ReadResponse(bufio.NewReader(conn), req)
   483  			if err != nil {
   484  				t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   485  			}
   486  
   487  			body, _ := ioutil.ReadAll(res.Body)
   488  			bodystr := string(body)
   489  
   490  			if i != 0 {
   491  				<-channels[i-1]
   492  			}
   493  
   494  			channels[i] <- bodystr
   495  		}(i)
   496  	}
   497  
   498  	var bodystr string
   499  	select {
   500  	case bodystringc := <-channels[numChannels-1]:
   501  		t.Errorf("took < 8.9s, should take at least 8.9s")
   502  		bodystr = bodystringc
   503  	case <-time.After(8900 * time.Millisecond):
   504  		bodystringc := <-channels[numChannels-1]
   505  		bodystr = bodystringc
   506  	}
   507  
   508  	if bodystr != testString {
   509  		t.Errorf("res.Body: got %s, want %s", bodystr, testString)
   510  	}
   511  }
   512  
   513  // Makes 2 requests, with the first one having a byte range starting
   514  // at  byte 250, and adds a close connection action at byte 450.
   515  // The first request should hit the action sooner,
   516  // and delete it. The second request should read the whole
   517  // data (500b)
   518  func TestConcurrentResponseActions(t *testing.T) {
   519  	log.SetLevel(log.Info)
   520  
   521  	l, err := net.Listen("tcp", "[::]:0")
   522  	if err != nil {
   523  		t.Fatalf("net.Listen(): got %v, want no error", err)
   524  	}
   525  
   526  	tsl := trafficshape.NewListener(l)
   527  	tsh := trafficshape.NewHandler(tsl)
   528  
   529  	// This is the data to be sent.
   530  	testString := strings.Repeat("0", 500)
   531  
   532  	// Traffic shaping config request.
   533  	jsonString :=
   534  		`{
   535  				"trafficshape": {
   536  						"shapes": [
   537  							{
   538  								"url_regex": "http://example/example",
   539  								"throttles": [
   540  									{
   541  										"bytes": "-",
   542  										"bandwidth": 250
   543  									}
   544  								],
   545  								"close_connections": [
   546  									{
   547  										"byte": 450,
   548  										"count": 1
   549  									},
   550  									{
   551  										"byte": 500,
   552  										"count": 1
   553  									}
   554  								]
   555  						}
   556  						]
   557  				}
   558  		}`
   559  
   560  	tsReq, err := http.NewRequest("POST", "test", bytes.NewBufferString(jsonString))
   561  	rw := httptest.NewRecorder()
   562  	tsh.ServeHTTP(rw, tsReq)
   563  	res := rw.Result()
   564  
   565  	if got, want := res.StatusCode, 200; got != want {
   566  		t.Fatalf("res.StatusCode: got %d, want %d", got, want)
   567  	}
   568  
   569  	p := NewProxy()
   570  	defer p.Close()
   571  
   572  	p.SetRequestModifier(nil)
   573  	p.SetResponseModifier(nil)
   574  
   575  	tr := martiantest.NewTransport()
   576  	p.SetRoundTripper(tr)
   577  	p.SetTimeout(20 * time.Second)
   578  
   579  	tm := martiantest.NewModifier()
   580  
   581  	tm.RequestFunc(func(req *http.Request) {
   582  		ctx := NewContext(req)
   583  		ctx.SkipRoundTrip()
   584  	})
   585  
   586  	tm.ResponseFunc(func(res *http.Response) {
   587  		cr := res.Request.Header.Get("ContentRange")
   588  		res.StatusCode = http.StatusOK
   589  		res.Body = ioutil.NopCloser(bytes.NewBufferString(testString))
   590  		if cr != "" {
   591  			res.StatusCode = http.StatusPartialContent
   592  			res.Header.Set("Content-Range", cr)
   593  		}
   594  	})
   595  
   596  	p.SetRequestModifier(tm)
   597  	p.SetResponseModifier(tm)
   598  
   599  	go p.Serve(tsl)
   600  
   601  	c1 := make(chan string)
   602  	c2 := make(chan string)
   603  
   604  	go func() {
   605  		conn, err := net.Dial("tcp", l.Addr().String())
   606  		defer conn.Close()
   607  		if err != nil {
   608  			t.Fatalf("net.Dial(): got %v, want no error", err)
   609  		}
   610  		req, err := http.NewRequest("GET", "http://example/example", nil)
   611  		req.Header.Set("ContentRange", "bytes 250-1000/1000")
   612  		if err != nil {
   613  			t.Fatalf("http.NewRequest(): got %v, want no error", err)
   614  		}
   615  
   616  		if err := req.WriteProxy(conn); err != nil {
   617  			t.Fatalf("req.WriteProxy(): got %v, want no error", err)
   618  		}
   619  
   620  		res, err := http.ReadResponse(bufio.NewReader(conn), req)
   621  		if err != nil {
   622  			t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   623  		}
   624  
   625  		body, _ := ioutil.ReadAll(res.Body)
   626  		bodystr := string(body)
   627  		c1 <- bodystr
   628  	}()
   629  
   630  	go func() {
   631  		conn, err := net.Dial("tcp", l.Addr().String())
   632  		defer conn.Close()
   633  		if err != nil {
   634  			t.Fatalf("net.Dial(): got %v, want no error", err)
   635  		}
   636  		req, err := http.NewRequest("GET", "http://example/example", nil)
   637  		if err != nil {
   638  			t.Fatalf("http.NewRequest(): got %v, want no error", err)
   639  		}
   640  
   641  		if err := req.WriteProxy(conn); err != nil {
   642  			t.Fatalf("req.WriteProxy(): got %v, want no error", err)
   643  		}
   644  
   645  		res, err := http.ReadResponse(bufio.NewReader(conn), req)
   646  		if err != nil {
   647  			t.Fatalf("http.ReadResponse(): got %v, want no error", err)
   648  		}
   649  
   650  		body, _ := ioutil.ReadAll(res.Body)
   651  		bodystr := string(body)
   652  		c2 <- bodystr
   653  	}()
   654  
   655  	bodystr1 := <-c1
   656  	bodystr2 := <-c2
   657  
   658  	if want1 := strings.Repeat("0", 200); bodystr1 != want1 {
   659  		t.Errorf("res.Body: got %s, want %s", bodystr1, want1)
   660  	}
   661  	if want2 := strings.Repeat("0", 500); bodystr2 != want2 {
   662  		t.Errorf("res.Body: got %s, want %s", bodystr2, want2)
   663  	}
   664  }