github.com/m3db/m3@v1.5.0/src/x/debug/debug_test.go (about)

     1  // Copyright (c) 2019 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package debug
    22  
    23  import (
    24  	"archive/zip"
    25  	"bytes"
    26  	"errors"
    27  	"fmt"
    28  	"io"
    29  	"math/rand"
    30  	"net/http"
    31  	"net/http/httptest"
    32  	"testing"
    33  
    34  	"github.com/m3db/m3/src/x/instrument"
    35  
    36  	"github.com/stretchr/testify/assert"
    37  	"github.com/stretchr/testify/require"
    38  )
    39  
    40  const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
    41  
    42  func randStringBytes(n int) string {
    43  	b := make([]byte, n)
    44  	for i := range b {
    45  		b[i] = letterBytes[rand.Intn(len(letterBytes))]
    46  	}
    47  	return string(b)
    48  }
    49  
    50  type fakeSource struct {
    51  	called    bool
    52  	shouldErr bool
    53  	content   string
    54  }
    55  
    56  func (f *fakeSource) Write(w io.Writer, _ *http.Request) error {
    57  	f.called = true
    58  	if f.shouldErr {
    59  		return errors.New("bad write")
    60  	}
    61  	w.Write([]byte(f.content))
    62  	return nil
    63  }
    64  
    65  func TestWriteZip(t *testing.T) {
    66  	zipWriter := NewZipWriter(instrument.NewOptions())
    67  	fs1 := &fakeSource{
    68  		content: "content1",
    69  	}
    70  	fs2 := &fakeSource{
    71  		content: "content2",
    72  	}
    73  	fs3 := &fakeSource{
    74  		content: "",
    75  	}
    76  	zipWriter.RegisterSource("test1", fs1)
    77  	zipWriter.RegisterSource("test2", fs2)
    78  	zipWriter.RegisterSource("test3", fs3)
    79  	buff := bytes.NewBuffer([]byte{})
    80  	err := zipWriter.WriteZip(buff, &http.Request{})
    81  
    82  	bytesReader := bytes.NewReader(buff.Bytes())
    83  	readerCloser, zerr := zip.NewReader(bytesReader, int64(len(buff.Bytes())))
    84  
    85  	require.NoError(t, zerr)
    86  	for _, f := range readerCloser.File {
    87  		var expectedContent string
    88  		if f.Name == "test1" {
    89  			expectedContent = "content1"
    90  		} else if f.Name == "test2" {
    91  			expectedContent = "content2"
    92  		} else if f.Name == "test3" {
    93  			expectedContent = ""
    94  		} else {
    95  			t.Errorf("bad filename from archive %s", f.Name)
    96  		}
    97  
    98  		rc, ferr := f.Open()
    99  		require.NoError(t, ferr)
   100  		content := make([]byte, len(expectedContent))
   101  		rc.Read(content)
   102  		require.Equal(t, expectedContent, string(content))
   103  	}
   104  
   105  	require.True(t, fs1.called)
   106  	require.True(t, fs2.called)
   107  	require.NoError(t, err)
   108  	require.NotZero(t, buff.Len())
   109  }
   110  
   111  func TestWriteZipErr(t *testing.T) {
   112  	zipWriter := NewZipWriter(instrument.NewOptions())
   113  	fs := &fakeSource{
   114  		shouldErr: true,
   115  	}
   116  	zipWriter.RegisterSource("test", fs)
   117  	buff := bytes.NewBuffer([]byte{})
   118  	err := zipWriter.WriteZip(buff, &http.Request{})
   119  	require.Error(t, err)
   120  	require.True(t, fs.called)
   121  }
   122  
   123  func TestRegisterSourceSameName(t *testing.T) {
   124  	zipWriter := NewZipWriter(instrument.NewOptions())
   125  	fs := &fakeSource{}
   126  	err := zipWriter.RegisterSource("test", fs)
   127  	require.NoError(t, err)
   128  	err = zipWriter.RegisterSource("test", fs)
   129  	require.Error(t, err)
   130  }
   131  
   132  func TestHTTPEndpoint(t *testing.T) {
   133  	mux := http.NewServeMux()
   134  
   135  	// Randomizing the path here so we avoid multiple tests
   136  	// registering the same endpoint.
   137  	path := fmt.Sprintf("/debug/%s", randStringBytes(10))
   138  
   139  	zw := NewZipWriter(instrument.NewOptions())
   140  	fs1 := &fakeSource{
   141  		content: "test",
   142  	}
   143  	fs2 := &fakeSource{
   144  		content: "bar",
   145  	}
   146  	err := zw.RegisterSource("test", fs1)
   147  	require.NoError(t, err)
   148  	err = zw.RegisterSource("foo", fs2)
   149  	require.NoError(t, err)
   150  
   151  	err = zw.RegisterHandler(path, mux)
   152  	require.NoError(t, err)
   153  
   154  	buf := bytes.NewBuffer([]byte{})
   155  	req, err := http.NewRequest("GET", path, buf)
   156  	require.NoError(t, err)
   157  
   158  	t.Run("TestDownloadZip", func(t *testing.T) {
   159  		rr := httptest.NewRecorder()
   160  		mux.ServeHTTP(rr, req)
   161  
   162  		require.NotZero(t, rr.Body.Len())
   163  		rawResponse := make([]byte, rr.Body.Len())
   164  		n, err := rr.Body.Read(rawResponse)
   165  		require.NoError(t, err)
   166  		require.NotZero(t, n)
   167  		require.Equal(t, rr.Code, http.StatusOK)
   168  
   169  		bytesReader := bytes.NewReader(rawResponse)
   170  		zipReader, err := zip.NewReader(bytesReader, int64(bytesReader.Len()))
   171  		require.NoError(t, err)
   172  		require.NotNil(t, zipReader)
   173  		for _, f := range zipReader.File {
   174  			f := f
   175  			t.Run(f.Name, func(t *testing.T) {
   176  				var expectedContent string
   177  				switch {
   178  				case f.Name == "test":
   179  					expectedContent = "test"
   180  				case f.Name == "foo":
   181  					expectedContent = "bar"
   182  				default:
   183  					t.Errorf("bad filename from archive %s", f.Name)
   184  				}
   185  
   186  				rc, ferr := f.Open()
   187  				require.NoError(t, ferr)
   188  				defer func() {
   189  					require.NoError(t, rc.Close())
   190  				}()
   191  
   192  				content := make([]byte, len(expectedContent))
   193  				_, err = rc.Read(content)
   194  				if assert.Error(t, err) {
   195  					require.Equal(t, err, io.EOF)
   196  				}
   197  				require.Equal(t, expectedContent, string(content))
   198  			})
   199  		}
   200  	})
   201  
   202  	t.Run("TestDownloadZipFail", func(t *testing.T) {
   203  		fs3 := &fakeSource{
   204  			content:   "oh snap",
   205  			shouldErr: true,
   206  		}
   207  		zw.RegisterSource("test2", fs3)
   208  
   209  		rr := httptest.NewRecorder()
   210  		mux.ServeHTTP(rr, req)
   211  
   212  		require.Equal(t, rr.Code, http.StatusInternalServerError)
   213  	})
   214  }