github.com/xmidt-org/webpa-common@v1.11.9/device/drain/start_test.go (about)

     1  package drain
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/xmidt-org/webpa-common/device/devicegate"
    13  
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/xmidt-org/webpa-common/logging"
    16  )
    17  
    18  func testStartServeHTTPDefaultLogger(t *testing.T) {
    19  	var (
    20  		assert = assert.New(t)
    21  
    22  		d                     = new(mockDrainer)
    23  		done  <-chan struct{} = make(chan struct{})
    24  		start                 = Start{d}
    25  
    26  		response = httptest.NewRecorder()
    27  		request  = httptest.NewRequest("POST", "/", nil)
    28  	)
    29  
    30  	d.On("Start", Job{}).Return(done, Job{Count: 126, Percent: 10, Rate: 12, Tick: 5 * time.Minute}, error(nil))
    31  	start.ServeHTTP(response, request)
    32  	assert.Equal(http.StatusOK, response.Code)
    33  	assert.Equal("application/json", response.Header().Get("Content-Type"))
    34  	assert.JSONEq(
    35  		`{"count": 126, "percent": 10, "rate": 12, "tick": "5m0s"}`,
    36  		response.Body.String(),
    37  	)
    38  
    39  	d.AssertExpectations(t)
    40  }
    41  
    42  func testStartServeHTTPValid(t *testing.T) {
    43  	testData := []struct {
    44  		uri      string
    45  		expected Job
    46  	}{
    47  		{
    48  			uri:      "/foo",
    49  			expected: Job{},
    50  		},
    51  		{
    52  			uri:      "/foo?count=100",
    53  			expected: Job{Count: 100},
    54  		},
    55  		{
    56  			uri:      "/foo?rate=10",
    57  			expected: Job{Rate: 10},
    58  		},
    59  		{
    60  			uri:      "/foo?rate=23&tick=1m",
    61  			expected: Job{Rate: 23, Tick: time.Minute},
    62  		},
    63  		{
    64  			uri:      "/foo?count=22&rate=10&tick=20s",
    65  			expected: Job{Count: 22, Rate: 10, Tick: 20 * time.Second},
    66  		},
    67  	}
    68  
    69  	for _, record := range testData {
    70  		t.Run(record.uri, func(t *testing.T) {
    71  			var (
    72  				assert = assert.New(t)
    73  
    74  				d                     = new(mockDrainer)
    75  				done  <-chan struct{} = make(chan struct{})
    76  				start                 = Start{d}
    77  
    78  				ctx      = logging.WithLogger(context.Background(), logging.NewTestLogger(nil, t))
    79  				response = httptest.NewRecorder()
    80  				request  = httptest.NewRequest("POST", record.uri, nil).WithContext(ctx)
    81  			)
    82  
    83  			d.On("Start", record.expected).Return(done, Job{Count: 47192, Percent: 57, Rate: 500, Tick: 37 * time.Second, DrainFilter: record.expected.DrainFilter}, error(nil)).Once()
    84  			start.ServeHTTP(response, request)
    85  			assert.Equal(http.StatusOK, response.Code)
    86  			assert.Equal("application/json", response.Header().Get("Content-Type"))
    87  			assert.JSONEq(
    88  				`{"count": 47192, "percent": 57, "rate": 500, "tick": "37s"}`,
    89  				response.Body.String(),
    90  			)
    91  			d.AssertExpectations(t)
    92  		})
    93  	}
    94  }
    95  
    96  func testStartServeHTTPWithBody(t *testing.T) {
    97  	df := &drainFilter{
    98  		filter: &devicegate.FilterGate{
    99  			FilterStore: devicegate.FilterStore(map[string]devicegate.Set{
   100  				"test": &devicegate.FilterSet{Set: map[interface{}]bool{
   101  					"test1": true,
   102  					"test2": true,
   103  				}},
   104  			}),
   105  		},
   106  		filterRequest: devicegate.FilterRequest{
   107  			Key:    "test",
   108  			Values: []interface{}{"test1", "test2"},
   109  		},
   110  	}
   111  
   112  	testData := []struct {
   113  		description        string
   114  		body               []byte
   115  		expected           Job
   116  		expectedJSON       string
   117  		expectedStatusCode int
   118  	}{
   119  		{
   120  			description:        "Success with body",
   121  			body:               []byte(`{"key": "test", "values": ["test1", "test2"]}`),
   122  			expected:           Job{Count: 22, Rate: 10, Tick: 20 * time.Second, DrainFilter: df},
   123  			expectedJSON:       `{"count": 47192, "percent": 57, "rate": 500, "tick": "37s", "filter":{"key": "test", "values": ["test1", "test2"]}}`,
   124  			expectedStatusCode: http.StatusOK,
   125  		},
   126  		{
   127  			description:        "Unmarshal error",
   128  			body:               []byte(`this is not a filter request`),
   129  			expected:           Job{Count: 22, Rate: 10, Tick: 20 * time.Second},
   130  			expectedStatusCode: http.StatusBadRequest,
   131  		},
   132  		{
   133  			description:        "Empty Body",
   134  			body:               []byte(`{}`),
   135  			expected:           Job{Count: 22, Rate: 10, Tick: 20 * time.Second},
   136  			expectedStatusCode: http.StatusOK,
   137  		},
   138  		{
   139  			description:        "No value field",
   140  			body:               []byte(`{"key": "test"}`),
   141  			expected:           Job{Count: 22, Rate: 10, Tick: 20 * time.Second},
   142  			expectedStatusCode: http.StatusOK,
   143  		},
   144  		{
   145  			description:        "No key",
   146  			body:               []byte(`{"values": ["test1", "test2"]}`),
   147  			expected:           Job{Count: 22, Rate: 10, Tick: 20 * time.Second},
   148  			expectedStatusCode: http.StatusOK,
   149  		},
   150  		{
   151  			description:        "Empty values array",
   152  			body:               []byte(`{"key": "test", "values": []}`),
   153  			expected:           Job{Count: 22, Rate: 10, Tick: 20 * time.Second},
   154  			expectedStatusCode: http.StatusOK,
   155  		},
   156  	}
   157  
   158  	for _, record := range testData {
   159  		t.Run(record.description, func(t *testing.T) {
   160  			var (
   161  				assert = assert.New(t)
   162  
   163  				d                     = new(mockDrainer)
   164  				done  <-chan struct{} = make(chan struct{})
   165  				start                 = Start{d}
   166  
   167  				ctx      = logging.WithLogger(context.Background(), logging.NewTestLogger(nil, t))
   168  				response = httptest.NewRecorder()
   169  				request  = httptest.NewRequest("POST", "/foo?count=22&rate=10&tick=20s", bytes.NewBuffer(record.body)).WithContext(ctx)
   170  			)
   171  
   172  			if record.expectedStatusCode == http.StatusOK {
   173  				d.On("Start", record.expected).Return(done, Job{Count: 47192, Percent: 57, Rate: 500, Tick: 37 * time.Second, DrainFilter: record.expected.DrainFilter}, error(nil)).Once()
   174  			}
   175  			start.ServeHTTP(response, request)
   176  			assert.Equal(record.expectedStatusCode, response.Code)
   177  			assert.Equal("application/json", response.Header().Get("Content-Type"))
   178  			if record.expectedStatusCode == http.StatusOK {
   179  				if len(record.expectedJSON) == 0 {
   180  					assert.JSONEq(
   181  						`{"count": 47192, "percent": 57, "rate": 500, "tick": "37s"}`,
   182  						response.Body.String(),
   183  					)
   184  				} else {
   185  					assert.JSONEq(record.expectedJSON, response.Body.String())
   186  				}
   187  			}
   188  
   189  			d.AssertExpectations(t)
   190  		})
   191  	}
   192  
   193  }
   194  
   195  func testStartServeHTTPParseFormError(t *testing.T) {
   196  	var (
   197  		assert = assert.New(t)
   198  
   199  		d     = new(mockDrainer)
   200  		start = Start{d}
   201  
   202  		ctx      = logging.WithLogger(context.Background(), logging.NewTestLogger(nil, t))
   203  		response = httptest.NewRecorder()
   204  		request  = httptest.NewRequest("POST", "/foo?%TT*&&", nil).WithContext(ctx)
   205  	)
   206  
   207  	start.ServeHTTP(response, request)
   208  	assert.Equal(http.StatusBadRequest, response.Code)
   209  	d.AssertExpectations(t)
   210  }
   211  
   212  func testStartServeHTTPInvalidQuery(t *testing.T) {
   213  	var (
   214  		assert = assert.New(t)
   215  
   216  		d     = new(mockDrainer)
   217  		start = Start{d}
   218  
   219  		ctx      = logging.WithLogger(context.Background(), logging.NewTestLogger(nil, t))
   220  		response = httptest.NewRecorder()
   221  		request  = httptest.NewRequest("POST", "/foo?count=asdf", nil).WithContext(ctx)
   222  	)
   223  
   224  	start.ServeHTTP(response, request)
   225  	assert.Equal(http.StatusBadRequest, response.Code)
   226  	d.AssertExpectations(t)
   227  }
   228  
   229  func testStartServeHTTPStartError(t *testing.T) {
   230  	var (
   231  		assert = assert.New(t)
   232  
   233  		d             = new(mockDrainer)
   234  		done          <-chan struct{}
   235  		start         = Start{d}
   236  		expectedError = errors.New("expected")
   237  
   238  		ctx      = logging.WithLogger(context.Background(), logging.NewTestLogger(nil, t))
   239  		response = httptest.NewRecorder()
   240  		request  = httptest.NewRequest("POST", "/foo?count=100", nil).WithContext(ctx)
   241  	)
   242  
   243  	d.On("Start", Job{Count: 100}).Return(done, Job{}, expectedError).Once()
   244  	start.ServeHTTP(response, request)
   245  	assert.Equal(http.StatusConflict, response.Code)
   246  	d.AssertExpectations(t)
   247  }
   248  
   249  func TestStart(t *testing.T) {
   250  	t.Run("ServeHTTP", func(t *testing.T) {
   251  		t.Run("DefaultLogger", testStartServeHTTPDefaultLogger)
   252  		t.Run("Valid", testStartServeHTTPValid)
   253  		t.Run("WithBody", testStartServeHTTPWithBody)
   254  		t.Run("ParseFormError", testStartServeHTTPParseFormError)
   255  		t.Run("InvalidQuery", testStartServeHTTPInvalidQuery)
   256  		t.Run("StartError", testStartServeHTTPStartError)
   257  	})
   258  }