github.com/masterhung0112/hk_server/v5@v5.0.0-20220302090640-ec71aef15e1c/api4/cors_test.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See LICENSE.txt for license information.
     3  
     4  package api4
     5  
     6  import (
     7  	"fmt"
     8  	"net/http"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/stretchr/testify/require"
    13  
    14  	"github.com/masterhung0112/hk_server/v5/model"
    15  	"github.com/masterhung0112/hk_server/v5/store/storetest/mocks"
    16  )
    17  
    18  const (
    19  	acAllowOrigin      = "Access-Control-Allow-Origin"
    20  	acExposeHeaders    = "Access-Control-Expose-Headers"
    21  	acMaxAge           = "Access-Control-Max-Age"
    22  	acAllowCredentials = "Access-Control-Allow-Credentials"
    23  	acAllowMethods     = "Access-Control-Allow-Methods"
    24  	acAllowHeaders     = "Access-Control-Allow-Headers"
    25  )
    26  
    27  func TestCORSRequestHandling(t *testing.T) {
    28  	for name, testcase := range map[string]struct {
    29  		AllowCorsFrom            string
    30  		CorsExposedHeaders       string
    31  		CorsAllowCredentials     bool
    32  		ModifyRequest            func(req *http.Request)
    33  		ExpectedAllowOrigin      string
    34  		ExpectedExposeHeaders    string
    35  		ExpectedAllowCredentials string
    36  	}{
    37  		"NoCORS": {
    38  			"",
    39  			"",
    40  			false,
    41  			func(req *http.Request) {
    42  			},
    43  			"",
    44  			"",
    45  			"",
    46  		},
    47  		"CORSEnabled": {
    48  			"http://somewhere.com",
    49  			"",
    50  			false,
    51  			func(req *http.Request) {
    52  			},
    53  			"",
    54  			"",
    55  			"",
    56  		},
    57  		"CORSEnabledStarOrigin": {
    58  			"*",
    59  			"",
    60  			false,
    61  			func(req *http.Request) {
    62  				req.Header.Set("Origin", "http://pre-release.mattermost.com")
    63  			},
    64  			"*",
    65  			"",
    66  			"",
    67  		},
    68  		"CORSEnabledStarNoOrigin": { // CORS spec requires this, not a bug.
    69  			"*",
    70  			"",
    71  			false,
    72  			func(req *http.Request) {
    73  			},
    74  			"",
    75  			"",
    76  			"",
    77  		},
    78  		"CORSEnabledMatching": {
    79  			"http://mattermost.com",
    80  			"",
    81  			false,
    82  			func(req *http.Request) {
    83  				req.Header.Set("Origin", "http://mattermost.com")
    84  			},
    85  			"http://mattermost.com",
    86  			"",
    87  			"",
    88  		},
    89  		"CORSEnabledMultiple": {
    90  			"http://spinmint.com http://mattermost.com",
    91  			"",
    92  			false,
    93  			func(req *http.Request) {
    94  				req.Header.Set("Origin", "http://mattermost.com")
    95  			},
    96  			"http://mattermost.com",
    97  			"",
    98  			"",
    99  		},
   100  		"CORSEnabledWithCredentials": {
   101  			"http://mattermost.com",
   102  			"",
   103  			true,
   104  			func(req *http.Request) {
   105  				req.Header.Set("Origin", "http://mattermost.com")
   106  			},
   107  			"http://mattermost.com",
   108  			"",
   109  			"true",
   110  		},
   111  		"CORSEnabledWithHeaders": {
   112  			"http://mattermost.com",
   113  			"x-my-special-header x-blueberry",
   114  			true,
   115  			func(req *http.Request) {
   116  				req.Header.Set("Origin", "http://mattermost.com")
   117  			},
   118  			"http://mattermost.com",
   119  			"X-My-Special-Header, X-Blueberry",
   120  			"true",
   121  		},
   122  	} {
   123  		t.Run(name, func(t *testing.T) {
   124  			th := SetupConfigWithStoreMock(t, func(cfg *model.Config) {
   125  				*cfg.ServiceSettings.AllowCorsFrom = testcase.AllowCorsFrom
   126  				*cfg.ServiceSettings.CorsExposedHeaders = testcase.CorsExposedHeaders
   127  				*cfg.ServiceSettings.CorsAllowCredentials = testcase.CorsAllowCredentials
   128  			})
   129  			defer th.TearDown()
   130  			systemStore := mocks.SystemStore{}
   131  			systemStore.On("Get").Return(make(model.StringMap), nil)
   132  			licenseStore := mocks.LicenseStore{}
   133  			licenseStore.On("Get", "").Return(&model.LicenseRecord{}, nil)
   134  			th.App.Srv().Store.(*mocks.Store).On("System").Return(&systemStore)
   135  			th.App.Srv().Store.(*mocks.Store).On("License").Return(&licenseStore)
   136  
   137  			port := th.App.Srv().ListenAddr.Port
   138  			host := fmt.Sprintf("http://localhost:%v", port)
   139  			url := fmt.Sprintf("%v/api/v4/system/ping", host)
   140  
   141  			req, err := http.NewRequest("GET", url, nil)
   142  			require.NoError(t, err)
   143  			testcase.ModifyRequest(req)
   144  
   145  			client := &http.Client{}
   146  			resp, err := client.Do(req)
   147  			require.NoError(t, err)
   148  			assert.Equal(t, http.StatusOK, resp.StatusCode)
   149  			assert.Equal(t, testcase.ExpectedAllowOrigin, resp.Header.Get(acAllowOrigin))
   150  			assert.Equal(t, testcase.ExpectedExposeHeaders, resp.Header.Get(acExposeHeaders))
   151  			assert.Equal(t, "", resp.Header.Get(acMaxAge))
   152  			assert.Equal(t, testcase.ExpectedAllowCredentials, resp.Header.Get(acAllowCredentials))
   153  			assert.Equal(t, "", resp.Header.Get(acAllowMethods))
   154  			assert.Equal(t, "", resp.Header.Get(acAllowHeaders))
   155  		})
   156  	}
   157  }