github.com/xmidt-org/webpa-common@v1.11.9/device/drain/cancel_test.go (about)

     1  package drain
     2  
     3  import (
     4  	"context"
     5  	"net/http"
     6  	"net/http/httptest"
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/stretchr/testify/assert"
    11  )
    12  
    13  func testCancelNotActive(t *testing.T) {
    14  	var (
    15  		assert = assert.New(t)
    16  
    17  		d      = new(mockDrainer)
    18  		cancel = Cancel{d}
    19  
    20  		response = httptest.NewRecorder()
    21  		request  = httptest.NewRequest("GET", "/", nil)
    22  	)
    23  
    24  	d.On("Cancel").Return((<-chan struct{})(nil), ErrNotActive).Once()
    25  	cancel.ServeHTTP(response, request)
    26  	assert.Equal(http.StatusConflict, response.Code)
    27  
    28  	d.AssertExpectations(t)
    29  }
    30  
    31  func testCancelSuccess(t *testing.T) {
    32  	var (
    33  		assert = assert.New(t)
    34  
    35  		d          = new(mockDrainer)
    36  		cancel     = Cancel{d}
    37  		done       = make(chan struct{})
    38  		cancelWait = make(chan time.Time)
    39  		serveHTTP  = make(chan struct{})
    40  
    41  		response = httptest.NewRecorder()
    42  		request  = httptest.NewRequest("GET", "/", nil)
    43  	)
    44  
    45  	d.On("Cancel").WaitUntil(cancelWait).Return((<-chan struct{})(done), error(nil)).Once()
    46  
    47  	go func() {
    48  		defer close(serveHTTP)
    49  		cancel.ServeHTTP(response, request)
    50  	}()
    51  
    52  	cancelWait <- time.Time{}
    53  	close(done)
    54  	select {
    55  	case <-serveHTTP:
    56  		// passing
    57  	case <-time.After(5 * time.Second):
    58  		assert.Fail("ServeHTTP did not return")
    59  		return
    60  	}
    61  
    62  	assert.Equal(http.StatusOK, response.Code)
    63  	d.AssertExpectations(t)
    64  }
    65  
    66  func testCancelTimeout(t *testing.T) {
    67  	var (
    68  		assert = assert.New(t)
    69  
    70  		d          = new(mockDrainer)
    71  		cancel     = Cancel{d}
    72  		done       = make(chan struct{})
    73  		cancelWait = make(chan time.Time)
    74  		serveHTTP  = make(chan struct{})
    75  
    76  		ctx, ctxCancel = context.WithCancel(context.Background())
    77  		response       = httptest.NewRecorder()
    78  		request        = httptest.NewRequest("GET", "/", nil).WithContext(ctx)
    79  	)
    80  
    81  	d.On("Cancel").WaitUntil(cancelWait).Return((<-chan struct{})(done), error(nil)).Once()
    82  
    83  	go func() {
    84  		defer close(serveHTTP)
    85  		cancel.ServeHTTP(response, request)
    86  	}()
    87  
    88  	cancelWait <- time.Time{}
    89  	ctxCancel()
    90  	select {
    91  	case <-serveHTTP:
    92  		// passing
    93  	case <-time.After(5 * time.Second):
    94  		assert.Fail("ServeHTTP did not return")
    95  		return
    96  	}
    97  
    98  	assert.Equal(http.StatusOK, response.Code)
    99  	d.AssertExpectations(t)
   100  }
   101  
   102  func TestCancel(t *testing.T) {
   103  	t.Run("NotActive", testCancelNotActive)
   104  	t.Run("Success", testCancelSuccess)
   105  	t.Run("Timeout", testCancelTimeout)
   106  }