github.com/sijibomii/docker@v0.0.0-20231230191044-5cf6ca554647/pkg/authorization/authz_unix_test.go (about)

     1  // +build !windows
     2  
     3  // TODO Windows: This uses a Unix socket for testing. This might be possible
     4  // to port to Windows using a named pipe instead.
     5  
     6  package authorization
     7  
     8  import (
     9  	"encoding/json"
    10  	"io/ioutil"
    11  	"log"
    12  	"net"
    13  	"net/http"
    14  	"net/http/httptest"
    15  	"os"
    16  	"path"
    17  	"reflect"
    18  	"testing"
    19  
    20  	"bytes"
    21  	"strings"
    22  
    23  	"github.com/docker/docker/pkg/plugins"
    24  	"github.com/docker/go-connections/tlsconfig"
    25  	"github.com/gorilla/mux"
    26  )
    27  
    28  const pluginAddress = "authzplugin.sock"
    29  
    30  func TestAuthZRequestPluginError(t *testing.T) {
    31  	server := authZPluginTestServer{t: t}
    32  	go server.start()
    33  	defer server.stop()
    34  
    35  	authZPlugin := createTestPlugin(t)
    36  
    37  	request := Request{
    38  		User:           "user",
    39  		RequestBody:    []byte("sample body"),
    40  		RequestURI:     "www.authz.com",
    41  		RequestMethod:  "GET",
    42  		RequestHeaders: map[string]string{"header": "value"},
    43  	}
    44  	server.replayResponse = Response{
    45  		Err: "an error",
    46  	}
    47  
    48  	actualResponse, err := authZPlugin.AuthZRequest(&request)
    49  	if err != nil {
    50  		t.Fatalf("Failed to authorize request %v", err)
    51  	}
    52  
    53  	if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
    54  		t.Fatalf("Response must be equal")
    55  	}
    56  	if !reflect.DeepEqual(request, server.recordedRequest) {
    57  		t.Fatalf("Requests must be equal")
    58  	}
    59  }
    60  
    61  func TestAuthZRequestPlugin(t *testing.T) {
    62  	server := authZPluginTestServer{t: t}
    63  	go server.start()
    64  	defer server.stop()
    65  
    66  	authZPlugin := createTestPlugin(t)
    67  
    68  	request := Request{
    69  		User:           "user",
    70  		RequestBody:    []byte("sample body"),
    71  		RequestURI:     "www.authz.com",
    72  		RequestMethod:  "GET",
    73  		RequestHeaders: map[string]string{"header": "value"},
    74  	}
    75  	server.replayResponse = Response{
    76  		Allow: true,
    77  		Msg:   "Sample message",
    78  	}
    79  
    80  	actualResponse, err := authZPlugin.AuthZRequest(&request)
    81  	if err != nil {
    82  		t.Fatalf("Failed to authorize request %v", err)
    83  	}
    84  
    85  	if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
    86  		t.Fatalf("Response must be equal")
    87  	}
    88  	if !reflect.DeepEqual(request, server.recordedRequest) {
    89  		t.Fatalf("Requests must be equal")
    90  	}
    91  }
    92  
    93  func TestAuthZResponsePlugin(t *testing.T) {
    94  	server := authZPluginTestServer{t: t}
    95  	go server.start()
    96  	defer server.stop()
    97  
    98  	authZPlugin := createTestPlugin(t)
    99  
   100  	request := Request{
   101  		User:        "user",
   102  		RequestBody: []byte("sample body"),
   103  	}
   104  	server.replayResponse = Response{
   105  		Allow: true,
   106  		Msg:   "Sample message",
   107  	}
   108  
   109  	actualResponse, err := authZPlugin.AuthZResponse(&request)
   110  	if err != nil {
   111  		t.Fatalf("Failed to authorize request %v", err)
   112  	}
   113  
   114  	if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
   115  		t.Fatalf("Response must be equal")
   116  	}
   117  	if !reflect.DeepEqual(request, server.recordedRequest) {
   118  		t.Fatalf("Requests must be equal")
   119  	}
   120  }
   121  
   122  func TestResponseModifier(t *testing.T) {
   123  	r := httptest.NewRecorder()
   124  	m := NewResponseModifier(r)
   125  	m.Header().Set("h1", "v1")
   126  	m.Write([]byte("body"))
   127  	m.WriteHeader(500)
   128  
   129  	m.FlushAll()
   130  	if r.Header().Get("h1") != "v1" {
   131  		t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
   132  	}
   133  	if !reflect.DeepEqual(r.Body.Bytes(), []byte("body")) {
   134  		t.Fatalf("Body value must exists %s", r.Body.Bytes())
   135  	}
   136  	if r.Code != 500 {
   137  		t.Fatalf("Status code must be correct %d", r.Code)
   138  	}
   139  }
   140  
   141  func TestDrainBody(t *testing.T) {
   142  
   143  	tests := []struct {
   144  		length             int // length is the message length send to drainBody
   145  		expectedBodyLength int // expectedBodyLength is the expected body length after drainBody is called
   146  	}{
   147  		{10, 10}, // Small message size
   148  		{maxBodySize - 1, maxBodySize - 1}, // Max message size
   149  		{maxBodySize * 2, 0},               // Large message size (skip copying body)
   150  
   151  	}
   152  
   153  	for _, test := range tests {
   154  
   155  		msg := strings.Repeat("a", test.length)
   156  		body, closer, err := drainBody(ioutil.NopCloser(bytes.NewReader([]byte(msg))))
   157  		if len(body) != test.expectedBodyLength {
   158  			t.Fatalf("Body must be copied, actual length: '%d'", len(body))
   159  		}
   160  		if closer == nil {
   161  			t.Fatalf("Closer must not be nil")
   162  		}
   163  		if err != nil {
   164  			t.Fatalf("Error must not be nil: '%v'", err)
   165  		}
   166  		modified, err := ioutil.ReadAll(closer)
   167  		if err != nil {
   168  			t.Fatalf("Error must not be nil: '%v'", err)
   169  		}
   170  		if len(modified) != len(msg) {
   171  			t.Fatalf("Result should not be truncated. Original length: '%d', new length: '%d'", len(msg), len(modified))
   172  		}
   173  	}
   174  }
   175  
   176  func TestResponseModifierOverride(t *testing.T) {
   177  	r := httptest.NewRecorder()
   178  	m := NewResponseModifier(r)
   179  	m.Header().Set("h1", "v1")
   180  	m.Write([]byte("body"))
   181  	m.WriteHeader(500)
   182  
   183  	overrideHeader := make(http.Header)
   184  	overrideHeader.Add("h1", "v2")
   185  	overrideHeaderBytes, err := json.Marshal(overrideHeader)
   186  	if err != nil {
   187  		t.Fatalf("override header failed %v", err)
   188  	}
   189  
   190  	m.OverrideHeader(overrideHeaderBytes)
   191  	m.OverrideBody([]byte("override body"))
   192  	m.OverrideStatusCode(404)
   193  	m.FlushAll()
   194  	if r.Header().Get("h1") != "v2" {
   195  		t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
   196  	}
   197  	if !reflect.DeepEqual(r.Body.Bytes(), []byte("override body")) {
   198  		t.Fatalf("Body value must exists %s", r.Body.Bytes())
   199  	}
   200  	if r.Code != 404 {
   201  		t.Fatalf("Status code must be correct %d", r.Code)
   202  	}
   203  }
   204  
   205  // createTestPlugin creates a new sample authorization plugin
   206  func createTestPlugin(t *testing.T) *authorizationPlugin {
   207  	plugin := &plugins.Plugin{Name: "authz"}
   208  	pwd, err := os.Getwd()
   209  	if err != nil {
   210  		log.Fatal(err)
   211  	}
   212  
   213  	plugin.Client, err = plugins.NewClient("unix:///"+path.Join(pwd, pluginAddress), tlsconfig.Options{InsecureSkipVerify: true})
   214  	if err != nil {
   215  		t.Fatalf("Failed to create client %v", err)
   216  	}
   217  
   218  	return &authorizationPlugin{name: "plugin", plugin: plugin}
   219  }
   220  
   221  // AuthZPluginTestServer is a simple server that implements the authZ plugin interface
   222  type authZPluginTestServer struct {
   223  	listener net.Listener
   224  	t        *testing.T
   225  	// request stores the request sent from the daemon to the plugin
   226  	recordedRequest Request
   227  	// response stores the response sent from the plugin to the daemon
   228  	replayResponse Response
   229  }
   230  
   231  // start starts the test server that implements the plugin
   232  func (t *authZPluginTestServer) start() {
   233  	r := mux.NewRouter()
   234  	os.Remove(pluginAddress)
   235  	l, err := net.ListenUnix("unix", &net.UnixAddr{Name: pluginAddress, Net: "unix"})
   236  	if err != nil {
   237  		t.t.Fatalf("Failed to listen %v", err)
   238  	}
   239  	t.listener = l
   240  
   241  	r.HandleFunc("/Plugin.Activate", t.activate)
   242  	r.HandleFunc("/"+AuthZApiRequest, t.auth)
   243  	r.HandleFunc("/"+AuthZApiResponse, t.auth)
   244  	t.listener, _ = net.Listen("tcp", pluginAddress)
   245  	server := http.Server{Handler: r, Addr: pluginAddress}
   246  	server.Serve(l)
   247  }
   248  
   249  // stop stops the test server that implements the plugin
   250  func (t *authZPluginTestServer) stop() {
   251  	os.Remove(pluginAddress)
   252  	if t.listener != nil {
   253  		t.listener.Close()
   254  	}
   255  }
   256  
   257  // auth is a used to record/replay the authentication api messages
   258  func (t *authZPluginTestServer) auth(w http.ResponseWriter, r *http.Request) {
   259  	t.recordedRequest = Request{}
   260  	defer r.Body.Close()
   261  	body, err := ioutil.ReadAll(r.Body)
   262  	json.Unmarshal(body, &t.recordedRequest)
   263  	b, err := json.Marshal(t.replayResponse)
   264  	if err != nil {
   265  		log.Fatal(err)
   266  	}
   267  	w.Write(b)
   268  }
   269  
   270  func (t *authZPluginTestServer) activate(w http.ResponseWriter, r *http.Request) {
   271  	b, err := json.Marshal(plugins.Manifest{Implements: []string{AuthZApiImplements}})
   272  	if err != nil {
   273  		log.Fatal(err)
   274  	}
   275  	w.Write(b)
   276  }