github.com/jxgolibs/go-oauth2-server@v1.0.1/test-util/helpers.go (about)

     1  package testutil
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"io"
     7  	"log"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"net/url"
    11  	"strings"
    12  	"testing"
    13  
    14  	"github.com/RichardKnop/go-oauth2-server/util/response"
    15  	"github.com/RichardKnop/jsonhal"
    16  	"github.com/gorilla/mux"
    17  	"github.com/stretchr/testify/assert"
    18  )
    19  
    20  // TestGetErrorExpectedResponse ...
    21  func TestGetErrorExpectedResponse(t *testing.T, router *mux.Router, url, accessToken, msg string, code int, assertExpectations func()) {
    22  	TestErrorExpectedResponse(t, router, "GET", url, nil, accessToken, msg, code, assertExpectations)
    23  }
    24  
    25  // TestPutErrorExpectedResponse ...
    26  func TestPutErrorExpectedResponse(t *testing.T, router *mux.Router, url string, data io.Reader, accessToken, msg string, code int, assertExpectations func()) {
    27  	TestErrorExpectedResponse(t, router, "PUT", url, data, accessToken, msg, code, assertExpectations)
    28  }
    29  
    30  // TestPostErrorExpectedResponse ...
    31  func TestPostErrorExpectedResponse(t *testing.T, router *mux.Router, url string, data io.Reader, accessToken, msg string, code int, assertExpectations func()) {
    32  	TestErrorExpectedResponse(t, router, "POST", url, data, accessToken, msg, code, assertExpectations)
    33  }
    34  
    35  // TestErrorExpectedResponse is the generic test code for testing for a bad response
    36  func TestErrorExpectedResponse(t *testing.T, router *mux.Router, method, url string, data io.Reader, accessToken, msg string, code int, assertExpectations func()) {
    37  	// Prepare a request
    38  	r, err := http.NewRequest(
    39  		method,
    40  		url,
    41  		data,
    42  	)
    43  	assert.NoError(t, err)
    44  
    45  	// Optionally add a bearer token to headers
    46  	if accessToken != "" {
    47  		r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
    48  	}
    49  
    50  	// And serve the request
    51  	w := httptest.NewRecorder()
    52  	router.ServeHTTP(w, r)
    53  
    54  	TestResponseForError(t, w, msg, code)
    55  
    56  	assertExpectations()
    57  }
    58  
    59  // TestResponseForError tests a response w to see if it returned an error msg with http code
    60  func TestResponseForError(t *testing.T, w *httptest.ResponseRecorder, msg string, code int) {
    61  	if code != w.Code {
    62  		log.Print(w.Body.String())
    63  	}
    64  	assert.Equal(
    65  		t,
    66  		code,
    67  		w.Code,
    68  		fmt.Sprintf("Expected a %d response but got %d", code, w.Code),
    69  	)
    70  	assert.NotNil(t, w)
    71  	TestResponseBody(t, w, getErrorJSON(msg))
    72  }
    73  
    74  // TestEmptyResponse tests an empty 204 response
    75  func TestEmptyResponse(t *testing.T, w *httptest.ResponseRecorder) {
    76  	assert.Equal(t, 204, w.Code)
    77  	TestResponseBody(t, w, "")
    78  }
    79  
    80  // TestResponseObject tests response body is equal to expected object in JSON form
    81  func TestResponseObject(t *testing.T, w *httptest.ResponseRecorder, expected interface{}, code int) {
    82  	if code != w.Code {
    83  		log.Print(w.Body.String())
    84  	}
    85  	assert.Equal(
    86  		t,
    87  		code,
    88  		w.Code,
    89  		fmt.Sprintf("Expected a %d response but got %d", code, w.Code),
    90  	)
    91  	jsonBytes, err := json.Marshal(expected)
    92  	assert.NoError(t, err)
    93  	assert.NotNil(t, w)
    94  	assert.Equal(
    95  		t,
    96  		string(jsonBytes),
    97  		strings.TrimRight(w.Body.String(), "\n"),
    98  		"Should have returned correct body text",
    99  	)
   100  }
   101  
   102  // TestResponseBody tests response body is equal to expected string
   103  func TestResponseBody(t *testing.T, w *httptest.ResponseRecorder, expected string) {
   104  	assert.Equal(
   105  		t,
   106  		expected,
   107  		strings.TrimRight(w.Body.String(), "\n"),
   108  		"Should have returned correct body text",
   109  	)
   110  }
   111  
   112  // TestListValidResponse ...
   113  func TestListValidResponse(t *testing.T, router *mux.Router, path, entity, accessToken string, items []interface{}, assertExpectations func()) {
   114  	TestListValidResponseWithParams(t, router, path, entity, accessToken, items, assertExpectations, nil)
   115  }
   116  
   117  // TestListValidResponseWithParams tests a list endpoint for a valid response with default settings
   118  func TestListValidResponseWithParams(t *testing.T, router *mux.Router, path, entity, accessToken string, items []interface{}, assertExpectations func(), params map[string]string) {
   119  	u, err := url.Parse(fmt.Sprintf("http://1.2.3.4/v1/%s", path))
   120  	assert.NoError(t, err)
   121  
   122  	// add any params
   123  	for k, v := range params {
   124  		q := u.Query()
   125  		q.Set(k, v)
   126  		u.RawQuery = q.Encode()
   127  	}
   128  
   129  	// Prepare a request
   130  	r, err := http.NewRequest(
   131  		"GET",
   132  		u.String(),
   133  		nil,
   134  	)
   135  	assert.NoError(t, err)
   136  
   137  	if accessToken != "" {
   138  		r.Header.Set("Authorization", fmt.Sprintf("Bearer %s", accessToken))
   139  	}
   140  
   141  	// And serve the request
   142  	w := httptest.NewRecorder()
   143  	router.ServeHTTP(w, r)
   144  
   145  	// Check that the mock object expectations were met
   146  	assertExpectations()
   147  
   148  	// Check the status code
   149  	assert.Equal(t, http.StatusOK, w.Code)
   150  
   151  	baseURI := u.RequestURI()
   152  
   153  	q := u.Query()
   154  	q.Set("page", "1")
   155  	u.RawQuery = q.Encode()
   156  
   157  	pagedURI := u.RequestURI()
   158  
   159  	expected := &response.ListResponse{
   160  		Hal: jsonhal.Hal{
   161  			Links: map[string]*jsonhal.Link{
   162  				"self": {
   163  					Href: baseURI,
   164  				},
   165  				"first": {
   166  					Href: pagedURI,
   167  				},
   168  				"last": {
   169  					Href: pagedURI,
   170  				},
   171  				"prev": new(jsonhal.Link),
   172  				"next": new(jsonhal.Link),
   173  			},
   174  			Embedded: map[string]jsonhal.Embedded{
   175  				entity: jsonhal.Embedded(items),
   176  			},
   177  		},
   178  		Count: uint(len(items)),
   179  		Page:  1,
   180  	}
   181  	expectedJSON, err := json.Marshal(expected)
   182  
   183  	if assert.NoError(t, err, "JSON marshalling failed") {
   184  		TestResponseBody(t, w, string(expectedJSON))
   185  	}
   186  }
   187  
   188  func getErrorJSON(msg string) string {
   189  	return fmt.Sprintf("{\"error\":\"%s\"}", msg)
   190  }