github.com/rumpl/bof@v23.0.0-rc.2+incompatible/integration/plugin/authz/main_test.go (about)

     1  //go:build !windows
     2  // +build !windows
     3  
     4  package authz // import "github.com/docker/docker/integration/plugin/authz"
     5  
import (
	"encoding/json"
	"fmt"
	"io"
	"mime"
	"net/http"
	"net/http/httptest"
	"os"
	"strings"
	"testing"

	"github.com/docker/docker/pkg/authorization"
	"github.com/docker/docker/pkg/plugins"
	"github.com/docker/docker/testutil/daemon"
	"github.com/docker/docker/testutil/environment"
	"gotest.tools/v3/skip"
)
    22  
    23  var (
    24  	testEnv *environment.Execution
    25  	d       *daemon.Daemon
    26  	server  *httptest.Server
    27  )
    28  
    29  func TestMain(m *testing.M) {
    30  	var err error
    31  	testEnv, err = environment.New()
    32  	if err != nil {
    33  		fmt.Println(err)
    34  		os.Exit(1)
    35  	}
    36  	err = environment.EnsureFrozenImagesLinux(testEnv)
    37  	if err != nil {
    38  		fmt.Println(err)
    39  		os.Exit(1)
    40  	}
    41  
    42  	testEnv.Print()
    43  	setupSuite()
    44  	exitCode := m.Run()
    45  	teardownSuite()
    46  
    47  	os.Exit(exitCode)
    48  }
    49  
    50  func setupTest(t *testing.T) func() {
    51  	skip.If(t, testEnv.IsRemoteDaemon, "cannot run daemon when remote daemon")
    52  	skip.If(t, testEnv.DaemonInfo.OSType == "windows")
    53  	skip.If(t, testEnv.IsRootless, "rootless mode has different view of localhost")
    54  	environment.ProtectAll(t, testEnv)
    55  
    56  	d = daemon.New(t, daemon.WithExperimental())
    57  
    58  	return func() {
    59  		if d != nil {
    60  			d.Stop(t)
    61  		}
    62  		testEnv.Clean(t)
    63  	}
    64  }
    65  
    66  func setupSuite() {
    67  	mux := http.NewServeMux()
    68  	server = httptest.NewServer(mux)
    69  
    70  	mux.HandleFunc("/Plugin.Activate", func(w http.ResponseWriter, r *http.Request) {
    71  		b, err := json.Marshal(plugins.Manifest{Implements: []string{authorization.AuthZApiImplements}})
    72  		if err != nil {
    73  			panic("could not marshal json for /Plugin.Activate: " + err.Error())
    74  		}
    75  		w.Write(b)
    76  	})
    77  
    78  	mux.HandleFunc("/AuthZPlugin.AuthZReq", func(w http.ResponseWriter, r *http.Request) {
    79  		defer r.Body.Close()
    80  		body, err := io.ReadAll(r.Body)
    81  		if err != nil {
    82  			panic("could not read body for /AuthZPlugin.AuthZReq: " + err.Error())
    83  		}
    84  		authReq := authorization.Request{}
    85  		err = json.Unmarshal(body, &authReq)
    86  		if err != nil {
    87  			panic("could not unmarshal json for /AuthZPlugin.AuthZReq: " + err.Error())
    88  		}
    89  
    90  		assertBody(authReq.RequestURI, authReq.RequestHeaders, authReq.RequestBody)
    91  		assertAuthHeaders(authReq.RequestHeaders)
    92  
    93  		// Count only server version api
    94  		if strings.HasSuffix(authReq.RequestURI, serverVersionAPI) {
    95  			ctrl.versionReqCount++
    96  		}
    97  
    98  		ctrl.requestsURIs = append(ctrl.requestsURIs, authReq.RequestURI)
    99  
   100  		reqRes := ctrl.reqRes
   101  		if isAllowed(authReq.RequestURI) {
   102  			reqRes = authorization.Response{Allow: true}
   103  		}
   104  		if reqRes.Err != "" {
   105  			w.WriteHeader(http.StatusInternalServerError)
   106  		}
   107  		b, err := json.Marshal(reqRes)
   108  		if err != nil {
   109  			panic("could not marshal json for /AuthZPlugin.AuthZReq: " + err.Error())
   110  		}
   111  
   112  		ctrl.reqUser = authReq.User
   113  		w.Write(b)
   114  	})
   115  
   116  	mux.HandleFunc("/AuthZPlugin.AuthZRes", func(w http.ResponseWriter, r *http.Request) {
   117  		defer r.Body.Close()
   118  		body, err := io.ReadAll(r.Body)
   119  		if err != nil {
   120  			panic("could not read body for /AuthZPlugin.AuthZRes: " + err.Error())
   121  		}
   122  		authReq := authorization.Request{}
   123  		err = json.Unmarshal(body, &authReq)
   124  		if err != nil {
   125  			panic("could not unmarshal json for /AuthZPlugin.AuthZRes: " + err.Error())
   126  		}
   127  
   128  		assertBody(authReq.RequestURI, authReq.ResponseHeaders, authReq.ResponseBody)
   129  		assertAuthHeaders(authReq.ResponseHeaders)
   130  
   131  		// Count only server version api
   132  		if strings.HasSuffix(authReq.RequestURI, serverVersionAPI) {
   133  			ctrl.versionResCount++
   134  		}
   135  		resRes := ctrl.resRes
   136  		if isAllowed(authReq.RequestURI) {
   137  			resRes = authorization.Response{Allow: true}
   138  		}
   139  		if resRes.Err != "" {
   140  			w.WriteHeader(http.StatusInternalServerError)
   141  		}
   142  		b, err := json.Marshal(resRes)
   143  		if err != nil {
   144  			panic("could not marshal json for /AuthZPlugin.AuthZRes: " + err.Error())
   145  		}
   146  		ctrl.resUser = authReq.User
   147  		w.Write(b)
   148  	})
   149  }
   150  
   151  func teardownSuite() {
   152  	if server == nil {
   153  		return
   154  	}
   155  
   156  	server.Close()
   157  }
   158  
   159  // assertAuthHeaders validates authentication headers are removed
   160  func assertAuthHeaders(headers map[string]string) error {
   161  	for k := range headers {
   162  		if strings.Contains(strings.ToLower(k), "auth") || strings.Contains(strings.ToLower(k), "x-registry") {
   163  			panic(fmt.Sprintf("Found authentication headers in request '%v'", headers))
   164  		}
   165  	}
   166  	return nil
   167  }
   168  
   169  // assertBody asserts that body is removed for non text/json requests
   170  func assertBody(requestURI string, headers map[string]string, body []byte) {
   171  	if strings.Contains(strings.ToLower(requestURI), "auth") && len(body) > 0 {
   172  		panic("Body included for authentication endpoint " + string(body))
   173  	}
   174  
   175  	for k, v := range headers {
   176  		if strings.EqualFold(k, "Content-Type") && strings.HasPrefix(v, "text/") || v == "application/json" {
   177  			return
   178  		}
   179  	}
   180  	if len(body) > 0 {
   181  		panic(fmt.Sprintf("Body included while it should not (Headers: '%v')", headers))
   182  	}
   183  }