github.com/kaisenlinux/docker.io@v0.0.0-20230510090727-ea55db55fac7/engine/pkg/authorization/authz_unix_test.go (about)

     1  //go:build !windows
     2  // +build !windows
     3  
     4  // TODO Windows: This uses a Unix socket for testing. This might be possible
     5  // to port to Windows using a named pipe instead.
     6  
     7  package authorization // import "github.com/docker/docker/pkg/authorization"
     8  
     9  import (
    10  	"bytes"
    11  	"encoding/json"
    12  	"io"
    13  	"net"
    14  	"net/http"
    15  	"net/http/httptest"
    16  	"os"
    17  	"path"
    18  	"reflect"
    19  	"strings"
    20  	"testing"
    21  
    22  	"github.com/docker/docker/pkg/plugins"
    23  	"github.com/docker/go-connections/tlsconfig"
    24  	"github.com/gorilla/mux"
    25  )
    26  
    27  const (
    28  	pluginAddress = "authz-test-plugin.sock"
    29  )
    30  
    31  func TestAuthZRequestPluginError(t *testing.T) {
    32  	server := authZPluginTestServer{t: t}
    33  	server.start()
    34  	defer server.stop()
    35  
    36  	authZPlugin := createTestPlugin(t, server.socketAddress())
    37  
    38  	request := Request{
    39  		User:           "user",
    40  		RequestBody:    []byte("sample body"),
    41  		RequestURI:     "www.authz.com/auth",
    42  		RequestMethod:  http.MethodGet,
    43  		RequestHeaders: map[string]string{"header": "value"},
    44  	}
    45  	server.replayResponse = Response{
    46  		Err: "an error",
    47  	}
    48  
    49  	actualResponse, err := authZPlugin.AuthZRequest(&request)
    50  	if err != nil {
    51  		t.Fatalf("Failed to authorize request %v", err)
    52  	}
    53  
    54  	if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
    55  		t.Fatal("Response must be equal")
    56  	}
    57  	if !reflect.DeepEqual(request, server.recordedRequest) {
    58  		t.Fatal("Requests must be equal")
    59  	}
    60  }
    61  
    62  func TestAuthZRequestPlugin(t *testing.T) {
    63  	server := authZPluginTestServer{t: t}
    64  	server.start()
    65  	defer server.stop()
    66  
    67  	authZPlugin := createTestPlugin(t, server.socketAddress())
    68  
    69  	request := Request{
    70  		User:           "user",
    71  		RequestBody:    []byte("sample body"),
    72  		RequestURI:     "www.authz.com/auth",
    73  		RequestMethod:  http.MethodGet,
    74  		RequestHeaders: map[string]string{"header": "value"},
    75  	}
    76  	server.replayResponse = Response{
    77  		Allow: true,
    78  		Msg:   "Sample message",
    79  	}
    80  
    81  	actualResponse, err := authZPlugin.AuthZRequest(&request)
    82  	if err != nil {
    83  		t.Fatalf("Failed to authorize request %v", err)
    84  	}
    85  
    86  	if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
    87  		t.Fatal("Response must be equal")
    88  	}
    89  	if !reflect.DeepEqual(request, server.recordedRequest) {
    90  		t.Fatal("Requests must be equal")
    91  	}
    92  }
    93  
    94  func TestAuthZResponsePlugin(t *testing.T) {
    95  	server := authZPluginTestServer{t: t}
    96  	server.start()
    97  	defer server.stop()
    98  
    99  	authZPlugin := createTestPlugin(t, server.socketAddress())
   100  
   101  	request := Request{
   102  		User:        "user",
   103  		RequestURI:  "something.com/auth",
   104  		RequestBody: []byte("sample body"),
   105  	}
   106  	server.replayResponse = Response{
   107  		Allow: true,
   108  		Msg:   "Sample message",
   109  	}
   110  
   111  	actualResponse, err := authZPlugin.AuthZResponse(&request)
   112  	if err != nil {
   113  		t.Fatalf("Failed to authorize request %v", err)
   114  	}
   115  
   116  	if !reflect.DeepEqual(server.replayResponse, *actualResponse) {
   117  		t.Fatal("Response must be equal")
   118  	}
   119  	if !reflect.DeepEqual(request, server.recordedRequest) {
   120  		t.Fatal("Requests must be equal")
   121  	}
   122  }
   123  
   124  func TestResponseModifier(t *testing.T) {
   125  	r := httptest.NewRecorder()
   126  	m := NewResponseModifier(r)
   127  	m.Header().Set("h1", "v1")
   128  	m.Write([]byte("body"))
   129  	m.WriteHeader(http.StatusInternalServerError)
   130  
   131  	m.FlushAll()
   132  	if r.Header().Get("h1") != "v1" {
   133  		t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
   134  	}
   135  	if !reflect.DeepEqual(r.Body.Bytes(), []byte("body")) {
   136  		t.Fatalf("Body value must exists %s", r.Body.Bytes())
   137  	}
   138  	if r.Code != http.StatusInternalServerError {
   139  		t.Fatalf("Status code must be correct %d", r.Code)
   140  	}
   141  }
   142  
   143  func TestDrainBody(t *testing.T) {
   144  	tests := []struct {
   145  		length             int // length is the message length send to drainBody
   146  		expectedBodyLength int // expectedBodyLength is the expected body length after drainBody is called
   147  	}{
   148  		{10, 10},                           // Small message size
   149  		{maxBodySize - 1, maxBodySize - 1}, // Max message size
   150  		{maxBodySize * 2, 0},               // Large message size (skip copying body)
   151  
   152  	}
   153  
   154  	for _, test := range tests {
   155  		msg := strings.Repeat("a", test.length)
   156  		body, closer, err := drainBody(io.NopCloser(bytes.NewReader([]byte(msg))))
   157  		if err != nil {
   158  			t.Fatal(err)
   159  		}
   160  		if len(body) != test.expectedBodyLength {
   161  			t.Fatalf("Body must be copied, actual length: '%d'", len(body))
   162  		}
   163  		if closer == nil {
   164  			t.Fatal("Closer must not be nil")
   165  		}
   166  		modified, err := io.ReadAll(closer)
   167  		if err != nil {
   168  			t.Fatalf("Error must not be nil: '%v'", err)
   169  		}
   170  		if len(modified) != len(msg) {
   171  			t.Fatalf("Result should not be truncated. Original length: '%d', new length: '%d'", len(msg), len(modified))
   172  		}
   173  	}
   174  }
   175  
   176  func TestSendBody(t *testing.T) {
   177  	var (
   178  		url       = "nothing.com"
   179  		testcases = []struct {
   180  			contentType string
   181  			expected    bool
   182  		}{
   183  			{
   184  				contentType: "application/json",
   185  				expected:    true,
   186  			},
   187  			{
   188  				contentType: "Application/json",
   189  				expected:    true,
   190  			},
   191  			{
   192  				contentType: "application/JSON",
   193  				expected:    true,
   194  			},
   195  			{
   196  				contentType: "APPLICATION/JSON",
   197  				expected:    true,
   198  			},
   199  			{
   200  				contentType: "application/json; charset=utf-8",
   201  				expected:    true,
   202  			},
   203  			{
   204  				contentType: "application/json;charset=utf-8",
   205  				expected:    true,
   206  			},
   207  			{
   208  				contentType: "application/json; charset=UTF8",
   209  				expected:    true,
   210  			},
   211  			{
   212  				contentType: "application/json;charset=UTF8",
   213  				expected:    true,
   214  			},
   215  			{
   216  				contentType: "text/html",
   217  				expected:    false,
   218  			},
   219  			{
   220  				contentType: "",
   221  				expected:    false,
   222  			},
   223  		}
   224  	)
   225  
   226  	for _, testcase := range testcases {
   227  		header := http.Header{}
   228  		header.Set("Content-Type", testcase.contentType)
   229  
   230  		if b := sendBody(url, header); b != testcase.expected {
   231  			t.Fatalf("Unexpected Content-Type; Expected: %t, Actual: %t", testcase.expected, b)
   232  		}
   233  	}
   234  }
   235  
   236  func TestResponseModifierOverride(t *testing.T) {
   237  	r := httptest.NewRecorder()
   238  	m := NewResponseModifier(r)
   239  	m.Header().Set("h1", "v1")
   240  	m.Write([]byte("body"))
   241  	m.WriteHeader(http.StatusInternalServerError)
   242  
   243  	overrideHeader := make(http.Header)
   244  	overrideHeader.Add("h1", "v2")
   245  	overrideHeaderBytes, err := json.Marshal(overrideHeader)
   246  	if err != nil {
   247  		t.Fatalf("override header failed %v", err)
   248  	}
   249  
   250  	m.OverrideHeader(overrideHeaderBytes)
   251  	m.OverrideBody([]byte("override body"))
   252  	m.OverrideStatusCode(http.StatusNotFound)
   253  	m.FlushAll()
   254  	if r.Header().Get("h1") != "v2" {
   255  		t.Fatalf("Header value must exists %s", r.Header().Get("h1"))
   256  	}
   257  	if !reflect.DeepEqual(r.Body.Bytes(), []byte("override body")) {
   258  		t.Fatalf("Body value must exists %s", r.Body.Bytes())
   259  	}
   260  	if r.Code != http.StatusNotFound {
   261  		t.Fatalf("Status code must be correct %d", r.Code)
   262  	}
   263  }
   264  
   265  // createTestPlugin creates a new sample authorization plugin
   266  func createTestPlugin(t *testing.T, socketAddress string) *authorizationPlugin {
   267  	client, err := plugins.NewClient("unix:///"+socketAddress, &tlsconfig.Options{InsecureSkipVerify: true})
   268  	if err != nil {
   269  		t.Fatalf("Failed to create client %v", err)
   270  	}
   271  
   272  	return &authorizationPlugin{name: "plugin", plugin: client}
   273  }
   274  
   275  // AuthZPluginTestServer is a simple server that implements the authZ plugin interface
   276  type authZPluginTestServer struct {
   277  	listener net.Listener
   278  	t        *testing.T
   279  	// request stores the request sent from the daemon to the plugin
   280  	recordedRequest Request
   281  	// response stores the response sent from the plugin to the daemon
   282  	replayResponse Response
   283  	server         *httptest.Server
   284  	tmpDir         string
   285  }
   286  
   287  func (t *authZPluginTestServer) socketAddress() string {
   288  	return path.Join(t.tmpDir, pluginAddress)
   289  }
   290  
   291  // start starts the test server that implements the plugin
   292  func (t *authZPluginTestServer) start() {
   293  	var err error
   294  	t.tmpDir, err = os.MkdirTemp("", "authz")
   295  	if err != nil {
   296  		t.t.Fatal(err)
   297  	}
   298  
   299  	r := mux.NewRouter()
   300  	l, err := net.Listen("unix", t.socketAddress())
   301  	if err != nil {
   302  		t.t.Fatal(err)
   303  	}
   304  	t.listener = l
   305  	r.HandleFunc("/Plugin.Activate", t.activate)
   306  	r.HandleFunc("/"+AuthZApiRequest, t.auth)
   307  	r.HandleFunc("/"+AuthZApiResponse, t.auth)
   308  	t.server = &httptest.Server{
   309  		Listener: l,
   310  		Config: &http.Server{
   311  			Handler: r,
   312  			Addr:    pluginAddress,
   313  		},
   314  	}
   315  	t.server.Start()
   316  }
   317  
   318  // stop stops the test server that implements the plugin
   319  func (t *authZPluginTestServer) stop() {
   320  	t.server.Close()
   321  	_ = os.RemoveAll(t.tmpDir)
   322  	if t.listener != nil {
   323  		t.listener.Close()
   324  	}
   325  }
   326  
   327  // auth is a used to record/replay the authentication api messages
   328  func (t *authZPluginTestServer) auth(w http.ResponseWriter, r *http.Request) {
   329  	t.recordedRequest = Request{}
   330  	body, err := io.ReadAll(r.Body)
   331  	if err != nil {
   332  		t.t.Fatal(err)
   333  	}
   334  	r.Body.Close()
   335  	json.Unmarshal(body, &t.recordedRequest)
   336  	b, err := json.Marshal(t.replayResponse)
   337  	if err != nil {
   338  		t.t.Fatal(err)
   339  	}
   340  	w.Write(b)
   341  }
   342  
   343  func (t *authZPluginTestServer) activate(w http.ResponseWriter, r *http.Request) {
   344  	b, err := json.Marshal(plugins.Manifest{Implements: []string{AuthZApiImplements}})
   345  	if err != nil {
   346  		t.t.Fatal(err)
   347  	}
   348  	w.Write(b)
   349  }