github.com/jwhonce/docker@v0.6.7-0.20190327063223-da823cf3a5a3/integration/plugin/authz/main_test.go (about)

     1  // +build !windows
     2  
     3  package authz // import "github.com/docker/docker/integration/plugin/authz"
     4  
     5  import (
     6  	"encoding/json"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"os"
    12  	"strings"
    13  	"testing"
    14  
    15  	"github.com/docker/docker/internal/test/daemon"
    16  	"github.com/docker/docker/internal/test/environment"
    17  	"github.com/docker/docker/pkg/authorization"
    18  	"github.com/docker/docker/pkg/plugins"
    19  	"gotest.tools/skip"
    20  )
    21  
    22  var (
    23  	testEnv *environment.Execution
    24  	d       *daemon.Daemon
    25  	server  *httptest.Server
    26  )
    27  
    28  func TestMain(m *testing.M) {
    29  	var err error
    30  	testEnv, err = environment.New()
    31  	if err != nil {
    32  		fmt.Println(err)
    33  		os.Exit(1)
    34  	}
    35  	err = environment.EnsureFrozenImagesLinux(testEnv)
    36  	if err != nil {
    37  		fmt.Println(err)
    38  		os.Exit(1)
    39  	}
    40  
    41  	testEnv.Print()
    42  	setupSuite()
    43  	exitCode := m.Run()
    44  	teardownSuite()
    45  
    46  	os.Exit(exitCode)
    47  }
    48  
    49  func setupTest(t *testing.T) func() {
    50  	skip.If(t, testEnv.IsRemoteDaemon, "cannot run daemon when remote daemon")
    51  	skip.If(t, testEnv.DaemonInfo.OSType == "windows")
    52  	environment.ProtectAll(t, testEnv)
    53  
    54  	d = daemon.New(t, daemon.WithExperimental)
    55  
    56  	return func() {
    57  		if d != nil {
    58  			d.Stop(t)
    59  		}
    60  		testEnv.Clean(t)
    61  	}
    62  }
    63  
    64  func setupSuite() {
    65  	mux := http.NewServeMux()
    66  	server = httptest.NewServer(mux)
    67  
    68  	mux.HandleFunc("/Plugin.Activate", func(w http.ResponseWriter, r *http.Request) {
    69  		b, err := json.Marshal(plugins.Manifest{Implements: []string{authorization.AuthZApiImplements}})
    70  		if err != nil {
    71  			panic("could not marshal json for /Plugin.Activate: " + err.Error())
    72  		}
    73  		w.Write(b)
    74  	})
    75  
    76  	mux.HandleFunc("/AuthZPlugin.AuthZReq", func(w http.ResponseWriter, r *http.Request) {
    77  		defer r.Body.Close()
    78  		body, err := ioutil.ReadAll(r.Body)
    79  		if err != nil {
    80  			panic("could not read body for /AuthZPlugin.AuthZReq: " + err.Error())
    81  		}
    82  		authReq := authorization.Request{}
    83  		err = json.Unmarshal(body, &authReq)
    84  		if err != nil {
    85  			panic("could not unmarshal json for /AuthZPlugin.AuthZReq: " + err.Error())
    86  		}
    87  
    88  		assertBody(authReq.RequestURI, authReq.RequestHeaders, authReq.RequestBody)
    89  		assertAuthHeaders(authReq.RequestHeaders)
    90  
    91  		// Count only server version api
    92  		if strings.HasSuffix(authReq.RequestURI, serverVersionAPI) {
    93  			ctrl.versionReqCount++
    94  		}
    95  
    96  		ctrl.requestsURIs = append(ctrl.requestsURIs, authReq.RequestURI)
    97  
    98  		reqRes := ctrl.reqRes
    99  		if isAllowed(authReq.RequestURI) {
   100  			reqRes = authorization.Response{Allow: true}
   101  		}
   102  		if reqRes.Err != "" {
   103  			w.WriteHeader(http.StatusInternalServerError)
   104  		}
   105  		b, err := json.Marshal(reqRes)
   106  		if err != nil {
   107  			panic("could not marshal json for /AuthZPlugin.AuthZReq: " + err.Error())
   108  		}
   109  
   110  		ctrl.reqUser = authReq.User
   111  		w.Write(b)
   112  	})
   113  
   114  	mux.HandleFunc("/AuthZPlugin.AuthZRes", func(w http.ResponseWriter, r *http.Request) {
   115  		defer r.Body.Close()
   116  		body, err := ioutil.ReadAll(r.Body)
   117  		if err != nil {
   118  			panic("could not read body for /AuthZPlugin.AuthZRes: " + err.Error())
   119  		}
   120  		authReq := authorization.Request{}
   121  		err = json.Unmarshal(body, &authReq)
   122  		if err != nil {
   123  			panic("could not unmarshal json for /AuthZPlugin.AuthZRes: " + err.Error())
   124  		}
   125  
   126  		assertBody(authReq.RequestURI, authReq.ResponseHeaders, authReq.ResponseBody)
   127  		assertAuthHeaders(authReq.ResponseHeaders)
   128  
   129  		// Count only server version api
   130  		if strings.HasSuffix(authReq.RequestURI, serverVersionAPI) {
   131  			ctrl.versionResCount++
   132  		}
   133  		resRes := ctrl.resRes
   134  		if isAllowed(authReq.RequestURI) {
   135  			resRes = authorization.Response{Allow: true}
   136  		}
   137  		if resRes.Err != "" {
   138  			w.WriteHeader(http.StatusInternalServerError)
   139  		}
   140  		b, err := json.Marshal(resRes)
   141  		if err != nil {
   142  			panic("could not marshal json for /AuthZPlugin.AuthZRes: " + err.Error())
   143  		}
   144  		ctrl.resUser = authReq.User
   145  		w.Write(b)
   146  	})
   147  }
   148  
   149  func teardownSuite() {
   150  	if server == nil {
   151  		return
   152  	}
   153  
   154  	server.Close()
   155  }
   156  
   157  // assertAuthHeaders validates authentication headers are removed
   158  func assertAuthHeaders(headers map[string]string) error {
   159  	for k := range headers {
   160  		if strings.Contains(strings.ToLower(k), "auth") || strings.Contains(strings.ToLower(k), "x-registry") {
   161  			panic(fmt.Sprintf("Found authentication headers in request '%v'", headers))
   162  		}
   163  	}
   164  	return nil
   165  }
   166  
   167  // assertBody asserts that body is removed for non text/json requests
   168  func assertBody(requestURI string, headers map[string]string, body []byte) {
   169  	if strings.Contains(strings.ToLower(requestURI), "auth") && len(body) > 0 {
   170  		panic("Body included for authentication endpoint " + string(body))
   171  	}
   172  
   173  	for k, v := range headers {
   174  		if strings.EqualFold(k, "Content-Type") && strings.HasPrefix(v, "text/") || v == "application/json" {
   175  			return
   176  		}
   177  	}
   178  	if len(body) > 0 {
   179  		panic(fmt.Sprintf("Body included while it should not (Headers: '%v')", headers))
   180  	}
   181  }