github.com/wangkui503/aero@v1.0.0/Context_test.go (about)

     1  package aero_test
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"io"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"net/http/httptest"
    10  	"strconv"
    11  	"strings"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/aerogo/session"
    16  	jsoniter "github.com/json-iterator/go"
    17  
    18  	"github.com/aerogo/aero"
    19  	"github.com/stretchr/testify/assert"
    20  )
    21  
    22  func TestContextResponseHeader(t *testing.T) {
    23  	app := aero.New()
    24  
    25  	// Register route
    26  	app.Get("/", func(ctx *aero.Context) string {
    27  		ctx.Response().Header().Set("X-Custom", "42")
    28  		return ctx.Text(helloWorld)
    29  	})
    30  
    31  	// Get response
    32  	response := getResponse(app, "/")
    33  
    34  	// Verify response
    35  	assert.Equal(t, http.StatusOK, response.Code)
    36  	assert.Equal(t, helloWorld, response.Body.String())
    37  	assert.Equal(t, "42", response.Header().Get("X-Custom"))
    38  }
    39  
    40  func TestContextError(t *testing.T) {
    41  	app := aero.New()
    42  
    43  	// Register route
    44  	app.Get("/", func(ctx *aero.Context) string {
    45  		return ctx.Error(http.StatusUnauthorized, "Not authorized", errors.New("Not logged in"))
    46  	})
    47  
    48  	app.Get("/explanation-only", func(ctx *aero.Context) string {
    49  		return ctx.Error(http.StatusUnauthorized, "Not authorized", nil)
    50  	})
    51  
    52  	app.Get("/unknown-error", func(ctx *aero.Context) string {
    53  		return ctx.Error(http.StatusUnauthorized)
    54  	})
    55  
    56  	// Verify response with known error
    57  	response := getResponse(app, "/")
    58  	assert.Equal(t, http.StatusUnauthorized, response.Code)
    59  	assert.Contains(t, response.Body.String(), "Not logged in")
    60  
    61  	// Verify response with explanation only
    62  	response = getResponse(app, "/explanation-only")
    63  	assert.Equal(t, http.StatusUnauthorized, response.Code)
    64  	assert.Contains(t, response.Body.String(), "Not authorized")
    65  
    66  	// Verify response with unknown error
    67  	response = getResponse(app, "/unknown-error")
    68  	assert.Equal(t, http.StatusUnauthorized, response.Code)
    69  	assert.Contains(t, response.Body.String(), "Unknown error")
    70  }
    71  
    72  func TestContextURI(t *testing.T) {
    73  	app := aero.New()
    74  
    75  	// Register route
    76  	app.Get("/uri", func(ctx *aero.Context) string {
    77  		return ctx.URI()
    78  	})
    79  
    80  	app.Get("/set-uri", func(ctx *aero.Context) string {
    81  		ctx.SetURI("/hello")
    82  		return ctx.URI()
    83  	})
    84  
    85  	// Verify response with read-only URI
    86  	response := getResponse(app, "/uri")
    87  	assert.Equal(t, http.StatusOK, response.Code)
    88  	assert.Contains(t, response.Body.String(), "/uri")
    89  
    90  	// Verify response with modified URI
    91  	response = getResponse(app, "/set-uri")
    92  	assert.Equal(t, http.StatusOK, response.Code)
    93  	assert.Contains(t, response.Body.String(), "/hello")
    94  }
    95  
    96  func TestContextRealIP(t *testing.T) {
    97  	app := aero.New()
    98  
    99  	// Register route
   100  	app.Get("/ip", func(ctx *aero.Context) string {
   101  		return ctx.RealIP()
   102  	})
   103  
   104  	// Get response
   105  	response := getResponse(app, "/ip")
   106  
   107  	// Verify response
   108  	assert.Equal(t, http.StatusOK, response.Code)
   109  	assert.Contains(t, response.Body.String(), "")
   110  }
   111  
   112  func TestContextSession(t *testing.T) {
   113  	app := aero.New()
   114  
   115  	// Register route
   116  	app.Get("/", func(ctx *aero.Context) string {
   117  		assert.Equal(t, false, ctx.HasSession())
   118  		ctx.Session().Set("custom", helloWorld)
   119  		assert.Equal(t, true, ctx.HasSession())
   120  
   121  		return ctx.Text(ctx.Session().GetString("custom"))
   122  	})
   123  
   124  	// Get response
   125  	response := getResponse(app, "/")
   126  
   127  	// Verify response
   128  	assert.Equal(t, http.StatusOK, response.Code)
   129  	assert.Equal(t, helloWorld, response.Body.String())
   130  }
   131  
   132  func TestContextSessionInvalidCookie(t *testing.T) {
   133  	app := aero.New()
   134  
   135  	// Register route
   136  	app.Get("/", func(ctx *aero.Context) string {
   137  		assert.Equal(t, false, ctx.HasSession())
   138  		ctx.Session().Set("custom", helloWorld)
   139  		assert.Equal(t, true, ctx.HasSession())
   140  
   141  		return ctx.Text(ctx.Session().GetString("custom"))
   142  	})
   143  
   144  	// Create request
   145  	request, _ := http.NewRequest("GET", "/", nil)
   146  	request.Header.Set("Accept-Encoding", "gzip")
   147  	request.Header.Set("Cookie", "sid=invalid")
   148  
   149  	// Get response
   150  	response := httptest.NewRecorder()
   151  	app.Handler().ServeHTTP(response, request)
   152  
   153  	// Verify response
   154  	assert.Equal(t, http.StatusOK, response.Code)
   155  	assert.Equal(t, helloWorld, response.Body.String())
   156  }
   157  
   158  func TestContextSessionValidCookie(t *testing.T) {
   159  	app := aero.New()
   160  
   161  	// Register routes
   162  	app.Get("/1", func(ctx *aero.Context) string {
   163  		assert.Equal(t, false, ctx.HasSession())
   164  		ctx.Session().Set("custom", helloWorld)
   165  		assert.Equal(t, true, ctx.HasSession())
   166  		assert.Equal(t, ctx.Session().GetString("sid"), ctx.Session().ID())
   167  
   168  		return ctx.Text(ctx.Session().GetString("custom"))
   169  	})
   170  
   171  	app.Get("/2", func(ctx *aero.Context) string {
   172  		assert.Equal(t, true, ctx.HasSession())
   173  		assert.Equal(t, ctx.Session().GetString("sid"), ctx.Session().ID())
   174  
   175  		return ctx.Text(ctx.Session().GetString("custom"))
   176  	})
   177  
   178  	app.Get("/3", func(ctx *aero.Context) string {
   179  		assert.Equal(t, ctx.Session().GetString("sid"), ctx.Session().ID())
   180  
   181  		return ctx.Text(ctx.Session().GetString("custom"))
   182  	})
   183  
   184  	// Create request 1
   185  	request1, _ := http.NewRequest("GET", "/1", nil)
   186  
   187  	// Get response 1
   188  	response1 := httptest.NewRecorder()
   189  	app.Handler().ServeHTTP(response1, request1)
   190  
   191  	// Verify response 1
   192  	assert.Equal(t, http.StatusOK, response1.Code)
   193  	assert.Equal(t, helloWorld, response1.Body.String())
   194  
   195  	setCookie := response1.Header().Get("Set-Cookie")
   196  	assert.NotEmpty(t, setCookie)
   197  	assert.Contains(t, setCookie, "sid=")
   198  
   199  	cookieParts := strings.Split(setCookie, ";")
   200  	sidLine := strings.TrimSpace(cookieParts[0])
   201  	sidParts := strings.Split(sidLine, "=")
   202  	sid := sidParts[1]
   203  	assert.True(t, session.IsValidID(sid))
   204  
   205  	// Create request 2
   206  	request2, _ := http.NewRequest("GET", "/2", nil)
   207  	request2.AddCookie(&http.Cookie{
   208  		Name:  "sid",
   209  		Value: sid,
   210  	})
   211  
   212  	// Get response 2
   213  	response2 := httptest.NewRecorder()
   214  	app.Handler().ServeHTTP(response2, request2)
   215  
   216  	// Verify response 2
   217  	assert.Equal(t, http.StatusOK, response2.Code)
   218  	assert.Equal(t, helloWorld, response2.Body.String())
   219  
   220  	// Create request 3
   221  	request3, _ := http.NewRequest("GET", "/3", nil)
   222  	request3.AddCookie(&http.Cookie{
   223  		Name:  "sid",
   224  		Value: sid,
   225  	})
   226  
   227  	// Get response 3
   228  	response3 := httptest.NewRecorder()
   229  	app.Handler().ServeHTTP(response3, request3)
   230  
   231  	// Verify response 3
   232  	assert.Equal(t, http.StatusOK, response3.Code)
   233  	assert.Equal(t, helloWorld, response3.Body.String())
   234  }
   235  
   236  func TestContextContentTypes(t *testing.T) {
   237  	app := aero.New()
   238  
   239  	// Register routes
   240  	app.Get("/json", func(ctx *aero.Context) string {
   241  		return ctx.JSON(app.Config)
   242  	})
   243  
   244  	app.Get("/jsonld", func(ctx *aero.Context) string {
   245  		return ctx.JSONLinkedData(app.Config)
   246  	})
   247  
   248  	app.Get("/html", func(ctx *aero.Context) string {
   249  		return ctx.HTML("<html></html>")
   250  	})
   251  
   252  	app.Get("/css", func(ctx *aero.Context) string {
   253  		return ctx.CSS("body{}")
   254  	})
   255  
   256  	app.Get("/js", func(ctx *aero.Context) string {
   257  		return ctx.JavaScript("console.log(42)")
   258  	})
   259  
   260  	app.Get("/files/*file", func(ctx *aero.Context) string {
   261  		return ctx.File(ctx.Get("file"))
   262  	})
   263  
   264  	// Get responses
   265  	responseJSON := getResponse(app, "/json")
   266  	responseJSONLD := getResponse(app, "/jsonld")
   267  	responseHTML := getResponse(app, "/html")
   268  	responseCSS := getResponse(app, "/css")
   269  	responseJS := getResponse(app, "/js")
   270  	responseFile := getResponse(app, "/files/Application.go")
   271  	responseMediaFile := getResponse(app, "/files/docs/usage.gif")
   272  
   273  	// Verify JSON response
   274  	json, _ := jsoniter.Marshal(app.Config)
   275  	assert.Equal(t, http.StatusOK, responseJSON.Code)
   276  	assert.Equal(t, json, responseJSON.Body.Bytes())
   277  	assert.Contains(t, responseJSON.Header().Get("Content-Type"), "application/json")
   278  
   279  	// Verify JSON+LD response
   280  	assert.Equal(t, http.StatusOK, responseJSONLD.Code)
   281  	assert.Equal(t, json, responseJSONLD.Body.Bytes())
   282  	assert.Contains(t, responseJSONLD.Header().Get("Content-Type"), "application/ld+json")
   283  
   284  	// Verify HTML response
   285  	assert.Equal(t, http.StatusOK, responseHTML.Code)
   286  	assert.Equal(t, "<html></html>", responseHTML.Body.String())
   287  	assert.Contains(t, responseHTML.Header().Get("Content-Type"), "text/html")
   288  
   289  	// Verify CSS response
   290  	assert.Equal(t, http.StatusOK, responseCSS.Code)
   291  	assert.Equal(t, "body{}", responseCSS.Body.String())
   292  	assert.Contains(t, responseCSS.Header().Get("Content-Type"), "text/css")
   293  
   294  	// Verify JS response
   295  	assert.Equal(t, http.StatusOK, responseJS.Code)
   296  	assert.Equal(t, "console.log(42)", responseJS.Body.String())
   297  	assert.Contains(t, responseJS.Header().Get("Content-Type"), "application/javascript")
   298  
   299  	// Verify file response
   300  	appSourceCode, _ := ioutil.ReadFile("Application.go")
   301  	assert.Equal(t, http.StatusOK, responseFile.Code)
   302  	assert.Equal(t, appSourceCode, responseFile.Body.Bytes())
   303  	assert.Contains(t, responseFile.Header().Get("Content-Type"), "text/plain")
   304  
   305  	// Verify media file response
   306  	imageData, _ := ioutil.ReadFile("docs/usage.gif")
   307  	assert.Equal(t, http.StatusOK, responseMediaFile.Code)
   308  	assert.Equal(t, imageData, responseMediaFile.Body.Bytes())
   309  	assert.Contains(t, responseMediaFile.Header().Get("Content-Type"), "image/gif")
   310  }
   311  
   312  func TestContextReader(t *testing.T) {
   313  	app := aero.New()
   314  	config, _ := jsoniter.MarshalToString(app.Config)
   315  
   316  	// ReadAll
   317  	app.Get("/readall", func(ctx *aero.Context) string {
   318  		reader, writer := io.Pipe()
   319  
   320  		go func() {
   321  			defer writer.Close()
   322  			encoder := jsoniter.NewEncoder(writer)
   323  			encoder.Encode(app.Config)
   324  		}()
   325  
   326  		return ctx.ReadAll(reader)
   327  	})
   328  
   329  	// Reader
   330  	app.Get("/reader", func(ctx *aero.Context) string {
   331  		reader, writer := io.Pipe()
   332  
   333  		go func() {
   334  			defer writer.Close()
   335  			encoder := jsoniter.NewEncoder(writer)
   336  			encoder.Encode(app.Config)
   337  		}()
   338  
   339  		return ctx.Reader(reader)
   340  	})
   341  
   342  	// ReadSeeker
   343  	app.Get("/readseeker", func(ctx *aero.Context) string {
   344  		return ctx.ReadSeeker(strings.NewReader(config))
   345  	})
   346  
   347  	routes := []string{
   348  		"/readall",
   349  		"/reader",
   350  		"/readseeker",
   351  	}
   352  
   353  	for _, route := range routes {
   354  		// Verify response
   355  		response := getResponse(app, route)
   356  		assert.Equal(t, http.StatusOK, response.Code)
   357  		assert.Equal(t, config, strings.TrimSpace(response.Body.String()))
   358  	}
   359  }
   360  
   361  func TestContextHTTP2Push(t *testing.T) {
   362  	app := aero.New()
   363  	app.Config.Push = append(app.Config.Push, "/pushed.css")
   364  
   365  	// Register routes
   366  	app.Get("/", func(ctx *aero.Context) string {
   367  		return ctx.HTML("<html></html>")
   368  	})
   369  
   370  	app.Get("/pushed.css", func(ctx *aero.Context) string {
   371  		return ctx.CSS("body{}")
   372  	})
   373  
   374  	// Add no-op push condition
   375  	app.AddPushCondition(func(ctx *aero.Context) bool {
   376  		return true
   377  	})
   378  
   379  	// Get response
   380  	response := getResponse(app, "/")
   381  
   382  	// Verify response
   383  	assert.Equal(t, http.StatusOK, response.Code)
   384  	assert.Equal(t, "<html></html>", response.Body.String())
   385  }
   386  
   387  func TestContextGetInt(t *testing.T) {
   388  	app := aero.New()
   389  
   390  	// Register route
   391  	app.Get("/:number", func(ctx *aero.Context) string {
   392  		number, err := ctx.GetInt("number")
   393  		assert.NoError(t, err)
   394  		assert.NotZero(t, number)
   395  
   396  		return ctx.Text(strconv.Itoa(number * 2))
   397  	})
   398  
   399  	// Get response
   400  	response := getResponse(app, "/21")
   401  
   402  	// Verify response
   403  	assert.Equal(t, http.StatusOK, response.Code)
   404  	assert.Equal(t, "42", response.Body.String())
   405  }
   406  
   407  func TestContextUserAgent(t *testing.T) {
   408  	app := aero.New()
   409  	agent := "Luke Skywalker"
   410  
   411  	// Register route
   412  	app.Get("/", func(ctx *aero.Context) string {
   413  		userAgent := ctx.UserAgent()
   414  		return ctx.Text(userAgent)
   415  	})
   416  
   417  	// Create request
   418  	request, _ := http.NewRequest("GET", "/", nil)
   419  	request.Header.Set("User-Agent", agent)
   420  
   421  	// Get response
   422  	response := httptest.NewRecorder()
   423  	app.Handler().ServeHTTP(response, request)
   424  
   425  	// Verify response
   426  	assert.Equal(t, http.StatusOK, response.Code)
   427  	assert.Equal(t, agent, response.Body.String())
   428  }
   429  
   430  func TestContextRedirect(t *testing.T) {
   431  	app := aero.New()
   432  
   433  	// Register routes
   434  	app.Get("/permanent", func(ctx *aero.Context) string {
   435  		return ctx.RedirectPermanently("/target")
   436  	})
   437  
   438  	app.Get("/temporary", func(ctx *aero.Context) string {
   439  		return ctx.Redirect("/target")
   440  	})
   441  
   442  	// Get temporary response
   443  	response := getResponse(app, "/temporary")
   444  
   445  	// Verify response
   446  	assert.Equal(t, http.StatusFound, response.Code)
   447  	assert.Equal(t, "", response.Body.String())
   448  
   449  	// Get permanent response
   450  	response = getResponse(app, "/permanent")
   451  
   452  	// Verify response
   453  	assert.Equal(t, http.StatusMovedPermanently, response.Code)
   454  	assert.Equal(t, "", response.Body.String())
   455  }
   456  
   457  func TestContextQuery(t *testing.T) {
   458  	app := aero.New()
   459  	search := "Luke Skywalker"
   460  
   461  	// Register route
   462  	app.Get("/", func(ctx *aero.Context) string {
   463  		search := ctx.Query("search")
   464  		return ctx.Text(search)
   465  	})
   466  
   467  	// Create request
   468  	request, _ := http.NewRequest("GET", "/?search="+search, nil)
   469  
   470  	// Get response
   471  	response := httptest.NewRecorder()
   472  	app.Handler().ServeHTTP(response, request)
   473  
   474  	// Verify response
   475  	assert.Equal(t, http.StatusOK, response.Code)
   476  	assert.Equal(t, search, response.Body.String())
   477  }
   478  
   479  func TestContextEventStream(t *testing.T) {
   480  	app := aero.New()
   481  
   482  	// Register route
   483  	app.Get("/", func(ctx *aero.Context) string {
   484  		stream := aero.NewEventStream()
   485  
   486  		go func() {
   487  			for {
   488  				select {
   489  				case <-stream.Closed:
   490  					close(stream.Events)
   491  					return
   492  
   493  				case <-time.After(10 * time.Millisecond):
   494  					stream.Events <- &aero.Event{
   495  						Name: "ping",
   496  						Data: "{}",
   497  					}
   498  
   499  					stream.Events <- &aero.Event{
   500  						Name: "ping",
   501  						Data: []byte("{}"),
   502  					}
   503  
   504  					stream.Events <- &aero.Event{
   505  						Name: "ping",
   506  						Data: struct {
   507  							Message string `json:"message"`
   508  						}{
   509  							Message: "Hello",
   510  						},
   511  					}
   512  
   513  					stream.Events <- &aero.Event{
   514  						Name: "ping",
   515  						Data: nil,
   516  					}
   517  				}
   518  			}
   519  		}()
   520  
   521  		return ctx.EventStream(stream)
   522  	})
   523  
   524  	// Create request
   525  	request, _ := http.NewRequest("GET", "/", nil)
   526  	ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
   527  	defer cancel()
   528  	request = request.WithContext(ctx)
   529  
   530  	// Get response
   531  	response := httptest.NewRecorder()
   532  	app.Handler().ServeHTTP(response, request)
   533  
   534  	// Verify response
   535  	assert.Equal(t, http.StatusOK, response.Code)
   536  }
   537  
   538  func TestBigResponse(t *testing.T) {
   539  	text := strings.Repeat("Hello World", 1000000)
   540  	app := aero.New()
   541  
   542  	// Make sure GZip is enabled
   543  	assert.Equal(t, true, app.Config.GZip)
   544  
   545  	// Register route
   546  	app.Get("/", func(ctx *aero.Context) string {
   547  		return ctx.Text(text)
   548  	})
   549  
   550  	// Get response
   551  	response := getResponse(app, "/")
   552  
   553  	// Verify the response
   554  	assert.Equal(t, http.StatusOK, response.Code)
   555  	assert.Equal(t, "gzip", response.Header().Get("Content-Encoding"))
   556  }
   557  
   558  func TestBigResponseNoGzip(t *testing.T) {
   559  	text := strings.Repeat("Hello World", 1000000)
   560  	app := aero.New()
   561  
   562  	// Register route
   563  	app.Get("/", func(ctx *aero.Context) string {
   564  		return ctx.Text(text)
   565  	})
   566  
   567  	// Create request and record response
   568  	request, _ := http.NewRequest("GET", "/", nil)
   569  	response := httptest.NewRecorder()
   570  	app.Handler().ServeHTTP(response, request)
   571  
   572  	// Verify the response
   573  	assert.Equal(t, http.StatusOK, response.Code)
   574  	assert.Equal(t, "", response.Header().Get("Content-Encoding"))
   575  }
   576  
   577  func TestBigResponse304(t *testing.T) {
   578  	text := strings.Repeat("Hello World", 1000000)
   579  	app := aero.New()
   580  
   581  	// Register route
   582  	app.Get("/", func(ctx *aero.Context) string {
   583  		return ctx.Text(text)
   584  	})
   585  
   586  	// Create request and record response
   587  	request, _ := http.NewRequest("GET", "/", nil)
   588  	response := httptest.NewRecorder()
   589  	app.Handler().ServeHTTP(response, request)
   590  	etag := response.Header().Get("ETag")
   591  
   592  	// Verify the response
   593  	assert.Equal(t, http.StatusOK, response.Code)
   594  	assert.NotEmpty(t, response.Body.String())
   595  
   596  	// Set if-none-match to the etag we just received
   597  	request, _ = http.NewRequest("GET", "/", nil)
   598  	request.Header.Set("If-None-Match", etag)
   599  	response = httptest.NewRecorder()
   600  	app.Handler().ServeHTTP(response, request)
   601  
   602  	// Verify the response
   603  	assert.Equal(t, 304, response.Code)
   604  	assert.Empty(t, response.Body.String())
   605  }