github.com/mattermost/mattermost-server/v5@v5.39.3/api4/cors_test.go (about)

     1  // Copyright (c) 2015-present Mattermost, Inc. All Rights Reserved.
     2  // See LICENSE.txt for license information.
     3  
     4  package api4
     5  
     6  import (
     7  	"fmt"
     8  	"net/http"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  	"github.com/stretchr/testify/require"
    13  
    14  	"github.com/mattermost/mattermost-server/v5/model"
    15  	"github.com/mattermost/mattermost-server/v5/store/storetest/mocks"
    16  )
    17  
    18  const (
    19  	acAllowOrigin      = "Access-Control-Allow-Origin"
    20  	acExposeHeaders    = "Access-Control-Expose-Headers"
    21  	acMaxAge           = "Access-Control-Max-Age"
    22  	acAllowCredentials = "Access-Control-Allow-Credentials"
    23  	acAllowMethods     = "Access-Control-Allow-Methods"
    24  	acAllowHeaders     = "Access-Control-Allow-Headers"
    25  )
    26  
    27  func TestCORSRequestHandling(t *testing.T) {
    28  	for name, testcase := range map[string]struct {
    29  		AllowCorsFrom            string
    30  		CorsExposedHeaders       string
    31  		CorsAllowCredentials     bool
    32  		ModifyRequest            func(req *http.Request)
    33  		ExpectedAllowOrigin      string
    34  		ExpectedExposeHeaders    string
    35  		ExpectedAllowCredentials string
    36  	}{
    37  		"NoCORS": {
    38  			"",
    39  			"",
    40  			false,
    41  			func(req *http.Request) {
    42  			},
    43  			"",
    44  			"",
    45  			"",
    46  		},
    47  		"CORSEnabled": {
    48  			"http://somewhere.com",
    49  			"",
    50  			false,
    51  			func(req *http.Request) {
    52  			},
    53  			"",
    54  			"",
    55  			"",
    56  		},
    57  		"CORSEnabledStarOrigin": {
    58  			"*",
    59  			"",
    60  			false,
    61  			func(req *http.Request) {
    62  				req.Header.Set("Origin", "http://pre-release.mattermost.com")
    63  			},
    64  			"*",
    65  			"",
    66  			"",
    67  		},
    68  		"CORSEnabledStarNoOrigin": { // CORS spec requires this, not a bug.
    69  			"*",
    70  			"",
    71  			false,
    72  			func(req *http.Request) {
    73  			},
    74  			"",
    75  			"",
    76  			"",
    77  		},
    78  		"CORSEnabledMatching": {
    79  			"http://mattermost.com",
    80  			"",
    81  			false,
    82  			func(req *http.Request) {
    83  				req.Header.Set("Origin", "http://mattermost.com")
    84  			},
    85  			"http://mattermost.com",
    86  			"",
    87  			"",
    88  		},
    89  		"CORSEnabledMultiple": {
    90  			"http://spinmint.com http://mattermost.com",
    91  			"",
    92  			false,
    93  			func(req *http.Request) {
    94  				req.Header.Set("Origin", "http://mattermost.com")
    95  			},
    96  			"http://mattermost.com",
    97  			"",
    98  			"",
    99  		},
   100  		"CORSEnabledWithCredentials": {
   101  			"http://mattermost.com",
   102  			"",
   103  			true,
   104  			func(req *http.Request) {
   105  				req.Header.Set("Origin", "http://mattermost.com")
   106  			},
   107  			"http://mattermost.com",
   108  			"",
   109  			"true",
   110  		},
   111  		"CORSEnabledWithHeaders": {
   112  			"http://mattermost.com",
   113  			"x-my-special-header x-blueberry",
   114  			true,
   115  			func(req *http.Request) {
   116  				req.Header.Set("Origin", "http://mattermost.com")
   117  			},
   118  			"http://mattermost.com",
   119  			"X-My-Special-Header, X-Blueberry",
   120  			"true",
   121  		},
   122  	} {
   123  		t.Run(name, func(t *testing.T) {
   124  			th := SetupConfigWithStoreMock(t, func(cfg *model.Config) {
   125  				*cfg.ServiceSettings.AllowCorsFrom = testcase.AllowCorsFrom
   126  				*cfg.ServiceSettings.CorsExposedHeaders = testcase.CorsExposedHeaders
   127  				*cfg.ServiceSettings.CorsAllowCredentials = testcase.CorsAllowCredentials
   128  			})
   129  			defer th.TearDown()
   130  			licenseStore := mocks.LicenseStore{}
   131  			licenseStore.On("Get", "").Return(&model.LicenseRecord{}, nil)
   132  			th.App.Srv().Store.(*mocks.Store).On("License").Return(&licenseStore)
   133  
   134  			port := th.App.Srv().ListenAddr.Port
   135  			host := fmt.Sprintf("http://localhost:%v", port)
   136  			url := fmt.Sprintf("%v/api/v4/system/ping", host)
   137  
   138  			req, err := http.NewRequest("GET", url, nil)
   139  			require.NoError(t, err)
   140  			testcase.ModifyRequest(req)
   141  
   142  			client := &http.Client{}
   143  			resp, err := client.Do(req)
   144  			require.NoError(t, err)
   145  			assert.Equal(t, http.StatusOK, resp.StatusCode)
   146  			assert.Equal(t, testcase.ExpectedAllowOrigin, resp.Header.Get(acAllowOrigin))
   147  			assert.Equal(t, testcase.ExpectedExposeHeaders, resp.Header.Get(acExposeHeaders))
   148  			assert.Equal(t, "", resp.Header.Get(acMaxAge))
   149  			assert.Equal(t, testcase.ExpectedAllowCredentials, resp.Header.Get(acAllowCredentials))
   150  			assert.Equal(t, "", resp.Header.Get(acAllowMethods))
   151  			assert.Equal(t, "", resp.Header.Get(acAllowHeaders))
   152  		})
   153  	}
   154  }