github.com/moby/docker@v26.1.3+incompatible/integration/plugin/authz/main_test.go (about)

     1  //go:build !windows
     2  
     3  package authz // import "github.com/docker/docker/integration/plugin/authz"
     4  
     5  import (
     6  	"context"
     7  	"encoding/json"
     8  	"fmt"
     9  	"io"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"os"
    13  	"strings"
    14  	"testing"
    15  
    16  	"github.com/docker/docker/pkg/authorization"
    17  	"github.com/docker/docker/pkg/plugins"
    18  	"github.com/docker/docker/testutil"
    19  	"github.com/docker/docker/testutil/daemon"
    20  	"github.com/docker/docker/testutil/environment"
    21  	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
    22  	"go.opentelemetry.io/otel"
    23  	"go.opentelemetry.io/otel/attribute"
    24  	"go.opentelemetry.io/otel/codes"
    25  	"gotest.tools/v3/skip"
    26  )
    27  
    28  var (
    29  	testEnv     *environment.Execution
    30  	d           *daemon.Daemon
    31  	server      *httptest.Server
    32  	baseContext context.Context
    33  )
    34  
    35  func TestMain(m *testing.M) {
    36  	shutdown := testutil.ConfigureTracing()
    37  
    38  	ctx, span := otel.Tracer("").Start(context.Background(), "integration/plugin/authz.TestMain")
    39  	baseContext = ctx
    40  
    41  	var err error
    42  	testEnv, err = environment.New(ctx)
    43  	if err != nil {
    44  		span.SetStatus(codes.Error, err.Error())
    45  		span.End()
    46  		shutdown(ctx)
    47  		panic(err)
    48  	}
    49  	err = environment.EnsureFrozenImagesLinux(ctx, testEnv)
    50  	if err != nil {
    51  		span.SetStatus(codes.Error, err.Error())
    52  		span.End()
    53  		shutdown(ctx)
    54  		panic(err)
    55  	}
    56  
    57  	testEnv.Print()
    58  	setupSuite()
    59  	exitCode := m.Run()
    60  	teardownSuite()
    61  
    62  	if exitCode != 0 {
    63  		span.SetAttributes(attribute.Int("exit", exitCode))
    64  		span.SetStatus(codes.Error, "m.Run() exited with non-zero exit code")
    65  	}
    66  	shutdown(ctx)
    67  
    68  	os.Exit(exitCode)
    69  }
    70  
    71  func setupTest(t *testing.T) context.Context {
    72  	skip.If(t, testEnv.IsRemoteDaemon, "cannot run daemon when remote daemon")
    73  	skip.If(t, testEnv.DaemonInfo.OSType == "windows")
    74  	skip.If(t, testEnv.IsRootless, "rootless mode has different view of localhost")
    75  
    76  	ctx := testutil.StartSpan(baseContext, t)
    77  	environment.ProtectAll(ctx, t, testEnv)
    78  
    79  	d = daemon.New(t, daemon.WithExperimental())
    80  
    81  	t.Cleanup(func() {
    82  		if d != nil {
    83  			d.Stop(t)
    84  		}
    85  		testEnv.Clean(ctx, t)
    86  	})
    87  	return ctx
    88  }
    89  
    90  func setupSuite() {
    91  	mux := http.NewServeMux()
    92  	server = httptest.NewServer(otelhttp.NewHandler(mux, ""))
    93  
    94  	mux.HandleFunc("/Plugin.Activate", func(w http.ResponseWriter, r *http.Request) {
    95  		b, err := json.Marshal(plugins.Manifest{Implements: []string{authorization.AuthZApiImplements}})
    96  		if err != nil {
    97  			panic("could not marshal json for /Plugin.Activate: " + err.Error())
    98  		}
    99  		w.Write(b)
   100  	})
   101  
   102  	mux.HandleFunc("/AuthZPlugin.AuthZReq", func(w http.ResponseWriter, r *http.Request) {
   103  		defer r.Body.Close()
   104  		body, err := io.ReadAll(r.Body)
   105  		if err != nil {
   106  			panic("could not read body for /AuthZPlugin.AuthZReq: " + err.Error())
   107  		}
   108  		authReq := authorization.Request{}
   109  		err = json.Unmarshal(body, &authReq)
   110  		if err != nil {
   111  			panic("could not unmarshal json for /AuthZPlugin.AuthZReq: " + err.Error())
   112  		}
   113  
   114  		assertBody(authReq.RequestURI, authReq.RequestHeaders, authReq.RequestBody)
   115  		assertAuthHeaders(authReq.RequestHeaders)
   116  
   117  		// Count only server version api
   118  		if strings.HasSuffix(authReq.RequestURI, serverVersionAPI) {
   119  			ctrl.versionReqCount++
   120  		}
   121  
   122  		ctrl.requestsURIs = append(ctrl.requestsURIs, authReq.RequestURI)
   123  
   124  		reqRes := ctrl.reqRes
   125  		if isAllowed(authReq.RequestURI) {
   126  			reqRes = authorization.Response{Allow: true}
   127  		}
   128  		if reqRes.Err != "" {
   129  			w.WriteHeader(http.StatusInternalServerError)
   130  		}
   131  		b, err := json.Marshal(reqRes)
   132  		if err != nil {
   133  			panic("could not marshal json for /AuthZPlugin.AuthZReq: " + err.Error())
   134  		}
   135  
   136  		ctrl.reqUser = authReq.User
   137  		w.Write(b)
   138  	})
   139  
   140  	mux.HandleFunc("/AuthZPlugin.AuthZRes", func(w http.ResponseWriter, r *http.Request) {
   141  		defer r.Body.Close()
   142  		body, err := io.ReadAll(r.Body)
   143  		if err != nil {
   144  			panic("could not read body for /AuthZPlugin.AuthZRes: " + err.Error())
   145  		}
   146  		authReq := authorization.Request{}
   147  		err = json.Unmarshal(body, &authReq)
   148  		if err != nil {
   149  			panic("could not unmarshal json for /AuthZPlugin.AuthZRes: " + err.Error())
   150  		}
   151  
   152  		assertBody(authReq.RequestURI, authReq.ResponseHeaders, authReq.ResponseBody)
   153  		assertAuthHeaders(authReq.ResponseHeaders)
   154  
   155  		// Count only server version api
   156  		if strings.HasSuffix(authReq.RequestURI, serverVersionAPI) {
   157  			ctrl.versionResCount++
   158  		}
   159  		resRes := ctrl.resRes
   160  		if isAllowed(authReq.RequestURI) {
   161  			resRes = authorization.Response{Allow: true}
   162  		}
   163  		if resRes.Err != "" {
   164  			w.WriteHeader(http.StatusInternalServerError)
   165  		}
   166  		b, err := json.Marshal(resRes)
   167  		if err != nil {
   168  			panic("could not marshal json for /AuthZPlugin.AuthZRes: " + err.Error())
   169  		}
   170  		ctrl.resUser = authReq.User
   171  		w.Write(b)
   172  	})
   173  }
   174  
   175  func teardownSuite() {
   176  	if server == nil {
   177  		return
   178  	}
   179  
   180  	server.Close()
   181  }
   182  
   183  // assertAuthHeaders validates authentication headers are removed
   184  func assertAuthHeaders(headers map[string]string) error {
   185  	for k := range headers {
   186  		if strings.Contains(strings.ToLower(k), "auth") || strings.Contains(strings.ToLower(k), "x-registry") {
   187  			panic(fmt.Sprintf("Found authentication headers in request '%v'", headers))
   188  		}
   189  	}
   190  	return nil
   191  }
   192  
   193  // assertBody asserts that body is removed for non text/json requests
   194  func assertBody(requestURI string, headers map[string]string, body []byte) {
   195  	if strings.Contains(strings.ToLower(requestURI), "auth") && len(body) > 0 {
   196  		panic("Body included for authentication endpoint " + string(body))
   197  	}
   198  
   199  	for k, v := range headers {
   200  		if strings.EqualFold(k, "Content-Type") && strings.HasPrefix(v, "text/") || v == "application/json" {
   201  			return
   202  		}
   203  	}
   204  	if len(body) > 0 {
   205  		panic(fmt.Sprintf("Body included while it should not (Headers: '%v')", headers))
   206  	}
   207  }