github.com/demonoid81/moby@v0.0.0-20200517203328-62dd8e17c460/integration/plugin/authz/main_test.go (about)

     1  // +build !windows
     2  
     3  package authz // import "github.com/demonoid81/moby/integration/plugin/authz"
     4  
     5  import (
     6  	"encoding/json"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"net/http"
    10  	"net/http/httptest"
    11  	"os"
    12  	"strings"
    13  	"testing"
    14  
    15  	"github.com/demonoid81/moby/pkg/authorization"
    16  	"github.com/demonoid81/moby/pkg/plugins"
    17  	"github.com/demonoid81/moby/testutil/daemon"
    18  	"github.com/demonoid81/moby/testutil/environment"
    19  	"gotest.tools/v3/skip"
    20  )
    21  
var (
	// testEnv is the execution environment shared by every test in this
	// package; TestMain initializes it with environment.New.
	testEnv *environment.Execution
	// d is the daemon under test. setupTest assigns a new instance for each
	// test and the cleanup function it returns stops it.
	d       *daemon.Daemon
	// server is the HTTP test server created by setupSuite and closed by
	// teardownSuite.
	server  *httptest.Server
)
    27  
    28  func TestMain(m *testing.M) {
    29  	var err error
    30  	testEnv, err = environment.New()
    31  	if err != nil {
    32  		fmt.Println(err)
    33  		os.Exit(1)
    34  	}
    35  	err = environment.EnsureFrozenImagesLinux(testEnv)
    36  	if err != nil {
    37  		fmt.Println(err)
    38  		os.Exit(1)
    39  	}
    40  
    41  	testEnv.Print()
    42  	setupSuite()
    43  	exitCode := m.Run()
    44  	teardownSuite()
    45  
    46  	os.Exit(exitCode)
    47  }
    48  
    49  func setupTest(t *testing.T) func() {
    50  	skip.If(t, testEnv.IsRemoteDaemon, "cannot run daemon when remote daemon")
    51  	skip.If(t, testEnv.DaemonInfo.OSType == "windows")
    52  	skip.If(t, testEnv.IsRootless, "rootless mode has different view of localhost")
    53  	environment.ProtectAll(t, testEnv)
    54  
    55  	d = daemon.New(t, daemon.WithExperimental())
    56  
    57  	return func() {
    58  		if d != nil {
    59  			d.Stop(t)
    60  		}
    61  		testEnv.Clean(t)
    62  	}
    63  }
    64  
    65  func setupSuite() {
    66  	mux := http.NewServeMux()
    67  	server = httptest.NewServer(mux)
    68  
    69  	mux.HandleFunc("/Plugin.Activate", func(w http.ResponseWriter, r *http.Request) {
    70  		b, err := json.Marshal(plugins.Manifest{Implements: []string{authorization.AuthZApiImplements}})
    71  		if err != nil {
    72  			panic("could not marshal json for /Plugin.Activate: " + err.Error())
    73  		}
    74  		w.Write(b)
    75  	})
    76  
    77  	mux.HandleFunc("/AuthZPlugin.AuthZReq", func(w http.ResponseWriter, r *http.Request) {
    78  		defer r.Body.Close()
    79  		body, err := ioutil.ReadAll(r.Body)
    80  		if err != nil {
    81  			panic("could not read body for /AuthZPlugin.AuthZReq: " + err.Error())
    82  		}
    83  		authReq := authorization.Request{}
    84  		err = json.Unmarshal(body, &authReq)
    85  		if err != nil {
    86  			panic("could not unmarshal json for /AuthZPlugin.AuthZReq: " + err.Error())
    87  		}
    88  
    89  		assertBody(authReq.RequestURI, authReq.RequestHeaders, authReq.RequestBody)
    90  		assertAuthHeaders(authReq.RequestHeaders)
    91  
    92  		// Count only server version api
    93  		if strings.HasSuffix(authReq.RequestURI, serverVersionAPI) {
    94  			ctrl.versionReqCount++
    95  		}
    96  
    97  		ctrl.requestsURIs = append(ctrl.requestsURIs, authReq.RequestURI)
    98  
    99  		reqRes := ctrl.reqRes
   100  		if isAllowed(authReq.RequestURI) {
   101  			reqRes = authorization.Response{Allow: true}
   102  		}
   103  		if reqRes.Err != "" {
   104  			w.WriteHeader(http.StatusInternalServerError)
   105  		}
   106  		b, err := json.Marshal(reqRes)
   107  		if err != nil {
   108  			panic("could not marshal json for /AuthZPlugin.AuthZReq: " + err.Error())
   109  		}
   110  
   111  		ctrl.reqUser = authReq.User
   112  		w.Write(b)
   113  	})
   114  
   115  	mux.HandleFunc("/AuthZPlugin.AuthZRes", func(w http.ResponseWriter, r *http.Request) {
   116  		defer r.Body.Close()
   117  		body, err := ioutil.ReadAll(r.Body)
   118  		if err != nil {
   119  			panic("could not read body for /AuthZPlugin.AuthZRes: " + err.Error())
   120  		}
   121  		authReq := authorization.Request{}
   122  		err = json.Unmarshal(body, &authReq)
   123  		if err != nil {
   124  			panic("could not unmarshal json for /AuthZPlugin.AuthZRes: " + err.Error())
   125  		}
   126  
   127  		assertBody(authReq.RequestURI, authReq.ResponseHeaders, authReq.ResponseBody)
   128  		assertAuthHeaders(authReq.ResponseHeaders)
   129  
   130  		// Count only server version api
   131  		if strings.HasSuffix(authReq.RequestURI, serverVersionAPI) {
   132  			ctrl.versionResCount++
   133  		}
   134  		resRes := ctrl.resRes
   135  		if isAllowed(authReq.RequestURI) {
   136  			resRes = authorization.Response{Allow: true}
   137  		}
   138  		if resRes.Err != "" {
   139  			w.WriteHeader(http.StatusInternalServerError)
   140  		}
   141  		b, err := json.Marshal(resRes)
   142  		if err != nil {
   143  			panic("could not marshal json for /AuthZPlugin.AuthZRes: " + err.Error())
   144  		}
   145  		ctrl.resUser = authReq.User
   146  		w.Write(b)
   147  	})
   148  }
   149  
   150  func teardownSuite() {
   151  	if server == nil {
   152  		return
   153  	}
   154  
   155  	server.Close()
   156  }
   157  
   158  // assertAuthHeaders validates authentication headers are removed
   159  func assertAuthHeaders(headers map[string]string) error {
   160  	for k := range headers {
   161  		if strings.Contains(strings.ToLower(k), "auth") || strings.Contains(strings.ToLower(k), "x-registry") {
   162  			panic(fmt.Sprintf("Found authentication headers in request '%v'", headers))
   163  		}
   164  	}
   165  	return nil
   166  }
   167  
   168  // assertBody asserts that body is removed for non text/json requests
   169  func assertBody(requestURI string, headers map[string]string, body []byte) {
   170  	if strings.Contains(strings.ToLower(requestURI), "auth") && len(body) > 0 {
   171  		panic("Body included for authentication endpoint " + string(body))
   172  	}
   173  
   174  	for k, v := range headers {
   175  		if strings.EqualFold(k, "Content-Type") && strings.HasPrefix(v, "text/") || v == "application/json" {
   176  			return
   177  		}
   178  	}
   179  	if len(body) > 0 {
   180  		panic(fmt.Sprintf("Body included while it should not (Headers: '%v')", headers))
   181  	}
   182  }