github.com/anycable/anycable-go@v1.5.1/ws/handler_test.go (about)

     1  package ws
     2  
     3  import (
     4  	"net/http/httptest"
     5  	"testing"
     6  
     7  	"github.com/stretchr/testify/assert"
     8  )
     9  
    10  func TestCheckOriginWithoutHeader(t *testing.T) {
    11  	req := httptest.NewRequest("GET", "/", nil)
    12  
    13  	allowedOrigins := ""
    14  	assert.Equal(t, CheckOrigin(allowedOrigins)(req), true)
    15  
    16  	allowedOrigins = "secure.origin"
    17  	assert.Equal(t, CheckOrigin(allowedOrigins)(req), false)
    18  }
    19  
    20  func TestCheckOrigin(t *testing.T) {
    21  	req := httptest.NewRequest("GET", "/", nil)
    22  	req.Header.Set("Origin", "http://my.localhost:8080")
    23  
    24  	allowedOrigins := ""
    25  	assert.Equal(t, CheckOrigin(allowedOrigins)(req), true)
    26  
    27  	allowedOrigins = "my.localhost:8080"
    28  	assert.Equal(t, CheckOrigin(allowedOrigins)(req), true)
    29  
    30  	allowedOrigins = "MY.localhost:8080"
    31  	assert.Equal(t, CheckOrigin(allowedOrigins)(req), true)
    32  
    33  	allowedOrigins = "localhost:8080"
    34  	assert.Equal(t, CheckOrigin(allowedOrigins)(req), false)
    35  
    36  	allowedOrigins = "*.localhost:8080"
    37  	assert.Equal(t, CheckOrigin(allowedOrigins)(req), true)
    38  
    39  	allowedOrigins = "secure.origin,my.localhost:8080"
    40  	assert.Equal(t, CheckOrigin(allowedOrigins)(req), true)
    41  
    42  	allowedOrigins = "secure.origin,*.localhost:8080"
    43  	assert.Equal(t, CheckOrigin(allowedOrigins)(req), true)
    44  
    45  	req.Header.Set("Origin", "http://MY.localhost:8080")
    46  	assert.Equal(t, CheckOrigin(allowedOrigins)(req), true)
    47  }