github.com/masterhung0112/hk_server/v5@v5.0.0-20220302090640-ec71aef15e1c/api4/cors_test.go (about) 1 // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved. 2 // See LICENSE.txt for license information. 3 4 package api4 5 6 import ( 7 "fmt" 8 "net/http" 9 "testing" 10 11 "github.com/stretchr/testify/assert" 12 "github.com/stretchr/testify/require" 13 14 "github.com/masterhung0112/hk_server/v5/model" 15 "github.com/masterhung0112/hk_server/v5/store/storetest/mocks" 16 ) 17 18 const ( 19 acAllowOrigin = "Access-Control-Allow-Origin" 20 acExposeHeaders = "Access-Control-Expose-Headers" 21 acMaxAge = "Access-Control-Max-Age" 22 acAllowCredentials = "Access-Control-Allow-Credentials" 23 acAllowMethods = "Access-Control-Allow-Methods" 24 acAllowHeaders = "Access-Control-Allow-Headers" 25 ) 26 27 func TestCORSRequestHandling(t *testing.T) { 28 for name, testcase := range map[string]struct { 29 AllowCorsFrom string 30 CorsExposedHeaders string 31 CorsAllowCredentials bool 32 ModifyRequest func(req *http.Request) 33 ExpectedAllowOrigin string 34 ExpectedExposeHeaders string 35 ExpectedAllowCredentials string 36 }{ 37 "NoCORS": { 38 "", 39 "", 40 false, 41 func(req *http.Request) { 42 }, 43 "", 44 "", 45 "", 46 }, 47 "CORSEnabled": { 48 "http://somewhere.com", 49 "", 50 false, 51 func(req *http.Request) { 52 }, 53 "", 54 "", 55 "", 56 }, 57 "CORSEnabledStarOrigin": { 58 "*", 59 "", 60 false, 61 func(req *http.Request) { 62 req.Header.Set("Origin", "http://pre-release.mattermost.com") 63 }, 64 "*", 65 "", 66 "", 67 }, 68 "CORSEnabledStarNoOrigin": { // CORS spec requires this, not a bug. 69 "*", 70 "", 71 false, 72 func(req *http.Request) { 73 }, 74 "", 75 "", 76 "", 77 }, 78 "CORSEnabledMatching": { 79 "http://mattermost.com", 80 "", 81 false, 82 func(req *http.Request) { 83 req.Header.Set("Origin", "http://mattermost.com") 84 }, 85 "http://mattermost.com", 86 "", 87 "", 88 }, 89 "CORSEnabledMultiple": { 90 "http://spinmint.com http://mattermost.com", 91 "", 92 false, 93 func(req *http.Request) { 94 req.Header.Set("Origin", "http://mattermost.com") 95 }, 96 "http://mattermost.com", 97 "", 98 "", 99 }, 100 "CORSEnabledWithCredentials": { 101 "http://mattermost.com", 102 "", 103 true, 104 func(req *http.Request) { 105 req.Header.Set("Origin", "http://mattermost.com") 106 }, 107 "http://mattermost.com", 108 "", 109 "true", 110 }, 111 "CORSEnabledWithHeaders": { 112 "http://mattermost.com", 113 "x-my-special-header x-blueberry", 114 true, 115 func(req *http.Request) { 116 req.Header.Set("Origin", "http://mattermost.com") 117 }, 118 "http://mattermost.com", 119 "X-My-Special-Header, X-Blueberry", 120 "true", 121 }, 122 } { 123 t.Run(name, func(t *testing.T) { 124 th := SetupConfigWithStoreMock(t, func(cfg *model.Config) { 125 *cfg.ServiceSettings.AllowCorsFrom = testcase.AllowCorsFrom 126 *cfg.ServiceSettings.CorsExposedHeaders = testcase.CorsExposedHeaders 127 *cfg.ServiceSettings.CorsAllowCredentials = testcase.CorsAllowCredentials 128 }) 129 defer th.TearDown() 130 systemStore := mocks.SystemStore{} 131 systemStore.On("Get").Return(make(model.StringMap), nil) 132 licenseStore := mocks.LicenseStore{} 133 licenseStore.On("Get", "").Return(&model.LicenseRecord{}, nil) 134 th.App.Srv().Store.(*mocks.Store).On("System").Return(&systemStore) 135 th.App.Srv().Store.(*mocks.Store).On("License").Return(&licenseStore) 136 137 port := th.App.Srv().ListenAddr.Port 138 host := fmt.Sprintf("http://localhost:%v", port) 139 url := fmt.Sprintf("%v/api/v4/system/ping", host) 140 141 req, err := http.NewRequest("GET", url, nil) 142 require.NoError(t, err) 143 testcase.ModifyRequest(req) 144 145 client := &http.Client{} 146 resp, err := client.Do(req) 147 require.NoError(t, err) 148 assert.Equal(t, http.StatusOK, resp.StatusCode) 149 assert.Equal(t, testcase.ExpectedAllowOrigin, resp.Header.Get(acAllowOrigin)) 150 assert.Equal(t, testcase.ExpectedExposeHeaders, resp.Header.Get(acExposeHeaders)) 151 assert.Equal(t, "", resp.Header.Get(acMaxAge)) 152 assert.Equal(t, testcase.ExpectedAllowCredentials, resp.Header.Get(acAllowCredentials)) 153 assert.Equal(t, "", resp.Header.Get(acAllowMethods)) 154 assert.Equal(t, "", resp.Header.Get(acAllowHeaders)) 155 }) 156 } 157 }