github.com/demonoid81/moby@v0.0.0-20200517203328-62dd8e17c460/pkg/authorization/middleware_test.go (about)

     1  package authorization // import "github.com/demonoid81/moby/pkg/authorization"
     2  
     3  import (
     4  	"net/http"
     5  	"net/http/httptest"
     6  	"strings"
     7  	"testing"
     8  
     9  	"github.com/demonoid81/moby/pkg/plugingetter"
    10  	"gotest.tools/v3/assert"
    11  )
    12  
    13  func TestMiddleware(t *testing.T) {
    14  	pluginNames := []string{"testPlugin1", "testPlugin2"}
    15  	var pluginGetter plugingetter.PluginGetter
    16  	m := NewMiddleware(pluginNames, pluginGetter)
    17  	authPlugins := m.getAuthzPlugins()
    18  	assert.Equal(t, 2, len(authPlugins))
    19  	assert.Equal(t, pluginNames[0], authPlugins[0].Name())
    20  	assert.Equal(t, pluginNames[1], authPlugins[1].Name())
    21  }
    22  
    23  func TestNewResponseModifier(t *testing.T) {
    24  	recorder := httptest.NewRecorder()
    25  	modifier := NewResponseModifier(recorder)
    26  	modifier.Header().Set("H1", "V1")
    27  	modifier.Write([]byte("body"))
    28  	assert.Assert(t, !modifier.Hijacked())
    29  	modifier.WriteHeader(http.StatusInternalServerError)
    30  	assert.Assert(t, modifier.RawBody() != nil)
    31  
    32  	raw, err := modifier.RawHeaders()
    33  	assert.Assert(t, raw != nil)
    34  	assert.NilError(t, err)
    35  
    36  	headerData := strings.Split(strings.TrimSpace(string(raw)), ":")
    37  	assert.Equal(t, "H1", strings.TrimSpace(headerData[0]))
    38  	assert.Equal(t, "V1", strings.TrimSpace(headerData[1]))
    39  
    40  	modifier.Flush()
    41  	modifier.FlushAll()
    42  
    43  	if recorder.Header().Get("H1") != "V1" {
    44  		t.Fatalf("Header value must exists %s", recorder.Header().Get("H1"))
    45  	}
    46  
    47  }
    48  
    49  func setAuthzPlugins(m *Middleware, plugins []Plugin) {
    50  	m.mu.Lock()
    51  	m.plugins = plugins
    52  	m.mu.Unlock()
    53  }