github.com/janma/nomad@v0.11.3/command/agent/agent_endpoint_test.go (about)

     1  package agent
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"io/ioutil"
     9  	"net"
    10  	"net/http"
    11  	"net/http/httptest"
    12  	"net/url"
    13  	"os"
    14  	"strings"
    15  	"sync"
    16  	"syscall"
    17  	"testing"
    18  	"time"
    19  
    20  	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
    21  	"github.com/hashicorp/nomad/acl"
    22  	"github.com/hashicorp/nomad/helper"
    23  	"github.com/hashicorp/nomad/helper/pool"
    24  	"github.com/hashicorp/nomad/nomad/mock"
    25  	"github.com/hashicorp/nomad/nomad/structs"
    26  	"github.com/hashicorp/nomad/testutil"
    27  	"github.com/stretchr/testify/assert"
    28  	"github.com/stretchr/testify/require"
    29  )
    30  
    31  func TestHTTP_AgentSelf(t *testing.T) {
    32  	t.Parallel()
    33  	require := require.New(t)
    34  
    35  	httpTest(t, nil, func(s *TestAgent) {
    36  		// Make the HTTP request
    37  		req, err := http.NewRequest("GET", "/v1/agent/self", nil)
    38  		require.NoError(err)
    39  		respW := httptest.NewRecorder()
    40  
    41  		// Make the request
    42  		obj, err := s.Server.AgentSelfRequest(respW, req)
    43  		require.NoError(err)
    44  
    45  		// Check the job
    46  		self := obj.(agentSelf)
    47  		require.NotNil(self.Config)
    48  		require.NotNil(self.Config.ACL)
    49  		require.NotEmpty(self.Stats)
    50  
    51  		// Check the Vault config
    52  		require.Empty(self.Config.Vault.Token)
    53  
    54  		// Assign a Vault token and require it is redacted.
    55  		s.Config.Vault.Token = "badc0deb-adc0-deba-dc0d-ebadc0debadc"
    56  		respW = httptest.NewRecorder()
    57  		obj, err = s.Server.AgentSelfRequest(respW, req)
    58  		require.NoError(err)
    59  		self = obj.(agentSelf)
    60  		require.Equal("<redacted>", self.Config.Vault.Token)
    61  
    62  		// Assign a ReplicationToken token and require it is redacted.
    63  		s.Config.ACL.ReplicationToken = "badc0deb-adc0-deba-dc0d-ebadc0debadc"
    64  		respW = httptest.NewRecorder()
    65  		obj, err = s.Server.AgentSelfRequest(respW, req)
    66  		require.NoError(err)
    67  		self = obj.(agentSelf)
    68  		require.Equal("<redacted>", self.Config.ACL.ReplicationToken)
    69  
    70  		// Check the Consul config
    71  		require.Empty(self.Config.Consul.Token)
    72  
    73  		// Assign a Consul token and require it is redacted.
    74  		s.Config.Consul.Token = "badc0deb-adc0-deba-dc0d-ebadc0debadc"
    75  		respW = httptest.NewRecorder()
    76  		obj, err = s.Server.AgentSelfRequest(respW, req)
    77  		require.NoError(err)
    78  		self = obj.(agentSelf)
    79  		require.Equal("<redacted>", self.Config.Consul.Token)
    80  
    81  		// Check the Circonus config
    82  		require.Empty(self.Config.Telemetry.CirconusAPIToken)
    83  
    84  		// Assign a Consul token and require it is redacted.
    85  		s.Config.Telemetry.CirconusAPIToken = "badc0deb-adc0-deba-dc0d-ebadc0debadc"
    86  		respW = httptest.NewRecorder()
    87  		obj, err = s.Server.AgentSelfRequest(respW, req)
    88  		require.NoError(err)
    89  		self = obj.(agentSelf)
    90  		require.Equal("<redacted>", self.Config.Telemetry.CirconusAPIToken)
    91  	})
    92  }
    93  
    94  func TestHTTP_AgentSelf_ACL(t *testing.T) {
    95  	t.Parallel()
    96  	require := require.New(t)
    97  
    98  	httpACLTest(t, nil, func(s *TestAgent) {
    99  		state := s.Agent.server.State()
   100  
   101  		// Make the HTTP request
   102  		req, err := http.NewRequest("GET", "/v1/agent/self", nil)
   103  		require.Nil(err)
   104  
   105  		// Try request without a token and expect failure
   106  		{
   107  			respW := httptest.NewRecorder()
   108  			_, err := s.Server.AgentSelfRequest(respW, req)
   109  			require.NotNil(err)
   110  			require.Equal(err.Error(), structs.ErrPermissionDenied.Error())
   111  		}
   112  
   113  		// Try request with an invalid token and expect failure
   114  		{
   115  			respW := httptest.NewRecorder()
   116  			token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", mock.NodePolicy(acl.PolicyWrite))
   117  			setToken(req, token)
   118  			_, err := s.Server.AgentSelfRequest(respW, req)
   119  			require.NotNil(err)
   120  			require.Equal(err.Error(), structs.ErrPermissionDenied.Error())
   121  		}
   122  
   123  		// Try request with a valid token
   124  		{
   125  			respW := httptest.NewRecorder()
   126  			token := mock.CreatePolicyAndToken(t, state, 1007, "valid", mock.AgentPolicy(acl.PolicyWrite))
   127  			setToken(req, token)
   128  			obj, err := s.Server.AgentSelfRequest(respW, req)
   129  			require.Nil(err)
   130  
   131  			self := obj.(agentSelf)
   132  			require.NotNil(self.Config)
   133  			require.NotNil(self.Stats)
   134  		}
   135  
   136  		// Try request with a root token
   137  		{
   138  			respW := httptest.NewRecorder()
   139  			setToken(req, s.RootToken)
   140  			obj, err := s.Server.AgentSelfRequest(respW, req)
   141  			require.Nil(err)
   142  
   143  			self := obj.(agentSelf)
   144  			require.NotNil(self.Config)
   145  			require.NotNil(self.Stats)
   146  		}
   147  	})
   148  }
   149  
   150  func TestHTTP_AgentJoin(t *testing.T) {
   151  	t.Parallel()
   152  	httpTest(t, nil, func(s *TestAgent) {
   153  		// Determine the join address
   154  		member := s.Agent.Server().LocalMember()
   155  		addr := fmt.Sprintf("%s:%d", member.Addr, member.Port)
   156  
   157  		// Make the HTTP request
   158  		req, err := http.NewRequest("PUT",
   159  			fmt.Sprintf("/v1/agent/join?address=%s&address=%s", addr, addr), nil)
   160  		if err != nil {
   161  			t.Fatalf("err: %v", err)
   162  		}
   163  		respW := httptest.NewRecorder()
   164  
   165  		// Make the request
   166  		obj, err := s.Server.AgentJoinRequest(respW, req)
   167  		if err != nil {
   168  			t.Fatalf("err: %v", err)
   169  		}
   170  
   171  		// Check the job
   172  		join := obj.(joinResult)
   173  		if join.NumJoined != 2 {
   174  			t.Fatalf("bad: %#v", join)
   175  		}
   176  		if join.Error != "" {
   177  			t.Fatalf("bad: %#v", join)
   178  		}
   179  	})
   180  }
   181  
   182  func TestHTTP_AgentMembers(t *testing.T) {
   183  	t.Parallel()
   184  	httpTest(t, nil, func(s *TestAgent) {
   185  		// Make the HTTP request
   186  		req, err := http.NewRequest("GET", "/v1/agent/members", nil)
   187  		if err != nil {
   188  			t.Fatalf("err: %v", err)
   189  		}
   190  		respW := httptest.NewRecorder()
   191  
   192  		// Make the request
   193  		obj, err := s.Server.AgentMembersRequest(respW, req)
   194  		if err != nil {
   195  			t.Fatalf("err: %v", err)
   196  		}
   197  
   198  		// Check the job
   199  		members := obj.(structs.ServerMembersResponse)
   200  		if len(members.Members) != 1 {
   201  			t.Fatalf("bad: %#v", members.Members)
   202  		}
   203  	})
   204  }
   205  
   206  func TestHTTP_AgentMembers_ACL(t *testing.T) {
   207  	t.Parallel()
   208  	require := require.New(t)
   209  
   210  	httpACLTest(t, nil, func(s *TestAgent) {
   211  		state := s.Agent.server.State()
   212  
   213  		// Make the HTTP request
   214  		req, err := http.NewRequest("GET", "/v1/agent/members", nil)
   215  		require.Nil(err)
   216  
   217  		// Try request without a token and expect failure
   218  		{
   219  			respW := httptest.NewRecorder()
   220  			_, err := s.Server.AgentMembersRequest(respW, req)
   221  			require.NotNil(err)
   222  			require.Equal(err.Error(), structs.ErrPermissionDenied.Error())
   223  		}
   224  
   225  		// Try request with an invalid token and expect failure
   226  		{
   227  			respW := httptest.NewRecorder()
   228  			token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", mock.AgentPolicy(acl.PolicyWrite))
   229  			setToken(req, token)
   230  			_, err := s.Server.AgentMembersRequest(respW, req)
   231  			require.NotNil(err)
   232  			require.Equal(err.Error(), structs.ErrPermissionDenied.Error())
   233  		}
   234  
   235  		// Try request with a valid token
   236  		{
   237  			respW := httptest.NewRecorder()
   238  			token := mock.CreatePolicyAndToken(t, state, 1007, "valid", mock.NodePolicy(acl.PolicyRead))
   239  			setToken(req, token)
   240  			obj, err := s.Server.AgentMembersRequest(respW, req)
   241  			require.Nil(err)
   242  
   243  			members := obj.(structs.ServerMembersResponse)
   244  			require.Len(members.Members, 1)
   245  		}
   246  
   247  		// Try request with a root token
   248  		{
   249  			respW := httptest.NewRecorder()
   250  			setToken(req, s.RootToken)
   251  			obj, err := s.Server.AgentMembersRequest(respW, req)
   252  			require.Nil(err)
   253  
   254  			members := obj.(structs.ServerMembersResponse)
   255  			require.Len(members.Members, 1)
   256  		}
   257  	})
   258  }
   259  
   260  func TestHTTP_AgentMonitor(t *testing.T) {
   261  	t.Parallel()
   262  
   263  	t.Run("invalid log_json parameter", func(t *testing.T) {
   264  		httpTest(t, nil, func(s *TestAgent) {
   265  			req, err := http.NewRequest("GET", "/v1/agent/monitor?log_json=no", nil)
   266  			require.Nil(t, err)
   267  			resp := newClosableRecorder()
   268  
   269  			// Make the request
   270  			_, err = s.Server.AgentMonitor(resp, req)
   271  			httpErr := err.(HTTPCodedError).Code()
   272  			require.Equal(t, 400, httpErr)
   273  		})
   274  	})
   275  
   276  	t.Run("unknown log_level", func(t *testing.T) {
   277  		httpTest(t, nil, func(s *TestAgent) {
   278  			req, err := http.NewRequest("GET", "/v1/agent/monitor?log_level=unknown", nil)
   279  			require.Nil(t, err)
   280  			resp := newClosableRecorder()
   281  
   282  			// Make the request
   283  			_, err = s.Server.AgentMonitor(resp, req)
   284  			httpErr := err.(HTTPCodedError).Code()
   285  			require.Equal(t, 400, httpErr)
   286  		})
   287  	})
   288  
   289  	t.Run("check for specific log level", func(t *testing.T) {
   290  		httpTest(t, nil, func(s *TestAgent) {
   291  			req, err := http.NewRequest("GET", "/v1/agent/monitor?log_level=warn", nil)
   292  			require.Nil(t, err)
   293  			resp := newClosableRecorder()
   294  			defer resp.Close()
   295  
   296  			go func() {
   297  				_, err = s.Server.AgentMonitor(resp, req)
   298  				assert.NoError(t, err)
   299  			}()
   300  
   301  			// send the same log until monitor sink is set up
   302  			maxLogAttempts := 10
   303  			tried := 0
   304  			testutil.WaitForResult(func() (bool, error) {
   305  				if tried < maxLogAttempts {
   306  					s.Server.logger.Warn("log that should be sent")
   307  					tried++
   308  				}
   309  
   310  				got := resp.Body.String()
   311  				want := `{"Data":"`
   312  				if strings.Contains(got, want) {
   313  					return true, nil
   314  				}
   315  
   316  				return false, fmt.Errorf("missing expected log, got: %v, want: %v", got, want)
   317  			}, func(err error) {
   318  				require.Fail(t, err.Error())
   319  			})
   320  		})
   321  	})
   322  
   323  	t.Run("plain output", func(t *testing.T) {
   324  		httpTest(t, nil, func(s *TestAgent) {
   325  			req, err := http.NewRequest("GET", "/v1/agent/monitor?log_level=debug&plain=true", nil)
   326  			require.Nil(t, err)
   327  			resp := newClosableRecorder()
   328  			defer resp.Close()
   329  
   330  			go func() {
   331  				_, err = s.Server.AgentMonitor(resp, req)
   332  				assert.NoError(t, err)
   333  			}()
   334  
   335  			// send the same log until monitor sink is set up
   336  			maxLogAttempts := 10
   337  			tried := 0
   338  			testutil.WaitForResult(func() (bool, error) {
   339  				if tried < maxLogAttempts {
   340  					s.Server.logger.Debug("log that should be sent")
   341  					tried++
   342  				}
   343  
   344  				got := resp.Body.String()
   345  				want := `[DEBUG] http: log that should be sent`
   346  				if strings.Contains(got, want) {
   347  					return true, nil
   348  				}
   349  
   350  				return false, fmt.Errorf("missing expected log, got: %v, want: %v", got, want)
   351  			}, func(err error) {
   352  				require.Fail(t, err.Error())
   353  			})
   354  		})
   355  	})
   356  
   357  	t.Run("logs for a specific node", func(t *testing.T) {
   358  		httpTest(t, nil, func(s *TestAgent) {
   359  			req, err := http.NewRequest("GET", "/v1/agent/monitor?log_level=warn&node_id="+s.client.NodeID(), nil)
   360  			require.Nil(t, err)
   361  			resp := newClosableRecorder()
   362  			defer resp.Close()
   363  
   364  			go func() {
   365  				_, err = s.Server.AgentMonitor(resp, req)
   366  				assert.NoError(t, err)
   367  			}()
   368  
   369  			// send the same log until monitor sink is set up
   370  			maxLogAttempts := 10
   371  			tried := 0
   372  			out := ""
   373  			testutil.WaitForResult(func() (bool, error) {
   374  				if tried < maxLogAttempts {
   375  					s.Server.logger.Debug("log that should not be sent")
   376  					s.Server.logger.Warn("log that should be sent")
   377  					tried++
   378  				}
   379  				output, err := ioutil.ReadAll(resp.Body)
   380  				if err != nil {
   381  					return false, err
   382  				}
   383  
   384  				out += string(output)
   385  				want := `{"Data":"`
   386  				if strings.Contains(out, want) {
   387  					return true, nil
   388  				}
   389  
   390  				return false, fmt.Errorf("missing expected log, got: %v, want: %v", out, want)
   391  			}, func(err error) {
   392  				require.Fail(t, err.Error())
   393  			})
   394  		})
   395  	})
   396  
   397  	t.Run("logs for a local client with no server running on agent", func(t *testing.T) {
   398  		httpTest(t, nil, func(s *TestAgent) {
   399  			req, err := http.NewRequest("GET", "/v1/agent/monitor?log_level=warn", nil)
   400  			require.Nil(t, err)
   401  			resp := newClosableRecorder()
   402  			defer resp.Close()
   403  
   404  			go func() {
   405  				// set server to nil to monitor as client
   406  				s.Agent.server = nil
   407  				_, err = s.Server.AgentMonitor(resp, req)
   408  				assert.NoError(t, err)
   409  			}()
   410  
   411  			// send the same log until monitor sink is set up
   412  			maxLogAttempts := 10
   413  			tried := 0
   414  			out := ""
   415  			testutil.WaitForResult(func() (bool, error) {
   416  				if tried < maxLogAttempts {
   417  					s.Agent.logger.Warn("log that should be sent")
   418  					tried++
   419  				}
   420  				output, err := ioutil.ReadAll(resp.Body)
   421  				if err != nil {
   422  					return false, err
   423  				}
   424  
   425  				out += string(output)
   426  				want := `{"Data":"`
   427  				if strings.Contains(out, want) {
   428  					return true, nil
   429  				}
   430  
   431  				return false, fmt.Errorf("missing expected log, got: %v, want: %v", out, want)
   432  			}, func(err error) {
   433  				require.Fail(t, err.Error())
   434  			})
   435  		})
   436  	})
   437  }
   438  
   439  // Scenarios when Pprof requests should be available
   440  // see https://github.com/hashicorp/nomad/issues/6496
   441  // +---------------+------------------+--------+------------------+
   442  // |   Endpoint    |  `enable_debug`  |  ACLs  |  **Available?**  |
   443  // +---------------+------------------+--------+------------------+
   444  // | /debug/pprof  |  unset           |  n/a   |  no              |
   445  // | /debug/pprof  |  `true`          |  n/a   |  yes             |
   446  // | /debug/pprof  |  `false`         |  n/a   |  no              |
   447  // | /agent/pprof  |  unset           |  off   |  no              |
   448  // | /agent/pprof  |  unset           |  on    |  **yes**         |
   449  // | /agent/pprof  |  `true`          |  off   |  yes             |
   450  // | /agent/pprof  |  `false`         |  on    |  **yes**         |
   451  // +---------------+------------------+--------+------------------+
   452  func TestAgent_PprofRequest_Permissions(t *testing.T) {
   453  	trueP, falseP := helper.BoolToPtr(true), helper.BoolToPtr(false)
   454  	cases := []struct {
   455  		acl   *bool
   456  		debug *bool
   457  		ok    bool
   458  	}{
   459  		// manually set to false because test helpers
   460  		// enable to true by default
   461  		// enableDebug:       helper.BoolToPtr(false),
   462  		{debug: nil, ok: false},
   463  		{debug: trueP, ok: true},
   464  		{debug: falseP, ok: false},
   465  		{debug: falseP, acl: falseP, ok: false},
   466  		{acl: trueP, ok: true},
   467  		{acl: falseP, debug: trueP, ok: true},
   468  		{debug: falseP, acl: trueP, ok: true},
   469  	}
   470  
   471  	for _, tc := range cases {
   472  		ptrToStr := func(val *bool) string {
   473  			if val == nil {
   474  				return "unset"
   475  			} else if *val == true {
   476  				return "true"
   477  			} else {
   478  				return "false"
   479  			}
   480  		}
   481  
   482  		t.Run(
   483  			fmt.Sprintf("debug %s, acl %s",
   484  				ptrToStr(tc.debug),
   485  				ptrToStr(tc.acl)),
   486  			func(t *testing.T) {
   487  				cb := func(c *Config) {
   488  					if tc.acl != nil {
   489  						c.ACL.Enabled = *tc.acl
   490  					}
   491  					if tc.debug == nil {
   492  						var nodebug bool
   493  						c.EnableDebug = nodebug
   494  					} else {
   495  						c.EnableDebug = *tc.debug
   496  					}
   497  				}
   498  
   499  				httpTest(t, cb, func(s *TestAgent) {
   500  					state := s.Agent.server.State()
   501  					url := "/v1/agent/pprof/cmdline"
   502  					req, err := http.NewRequest("GET", url, nil)
   503  					require.NoError(t, err)
   504  					respW := httptest.NewRecorder()
   505  
   506  					if tc.acl != nil && *tc.acl {
   507  						token := mock.CreatePolicyAndToken(t, state, 1007, "valid", mock.AgentPolicy(acl.PolicyWrite))
   508  						setToken(req, token)
   509  					}
   510  
   511  					resp, err := s.Server.AgentPprofRequest(respW, req)
   512  					if tc.ok {
   513  						require.NoError(t, err)
   514  						require.NotNil(t, resp)
   515  					} else {
   516  						require.Error(t, err)
   517  						require.Equal(t, structs.ErrPermissionDenied.Error(), err.Error())
   518  					}
   519  				})
   520  			})
   521  	}
   522  }
   523  
   524  func TestAgent_PprofRequest(t *testing.T) {
   525  	cases := []struct {
   526  		desc        string
   527  		url         string
   528  		addNodeID   bool
   529  		addServerID bool
   530  		expectedErr string
   531  		clientOnly  bool
   532  	}{
   533  		{
   534  			desc: "cmdline local server request",
   535  			url:  "/v1/agent/pprof/cmdline",
   536  		},
   537  		{
   538  			desc:       "cmdline local node request",
   539  			url:        "/v1/agent/pprof/cmdline",
   540  			clientOnly: true,
   541  		},
   542  		{
   543  			desc:      "cmdline node request",
   544  			url:       "/v1/agent/pprof/cmdline",
   545  			addNodeID: true,
   546  		},
   547  		{
   548  			desc:        "cmdline server request",
   549  			url:         "/v1/agent/pprof/cmdline",
   550  			addServerID: true,
   551  		},
   552  		{
   553  			desc:        "invalid server request",
   554  			url:         "/v1/agent/pprof/unknown",
   555  			addServerID: true,
   556  			expectedErr: "RPC Error:: 404,Pprof profile not found profile: unknown",
   557  		},
   558  		{
   559  			desc:      "cpu profile request",
   560  			url:       "/v1/agent/pprof/profile",
   561  			addNodeID: true,
   562  		},
   563  		{
   564  			desc:      "trace request",
   565  			url:       "/v1/agent/pprof/trace",
   566  			addNodeID: true,
   567  		},
   568  		{
   569  			desc:      "pprof lookup request",
   570  			url:       "/v1/agent/pprof/goroutine",
   571  			addNodeID: true,
   572  		},
   573  		{
   574  			desc:        "unknown pprof lookup request",
   575  			url:         "/v1/agent/pprof/latency",
   576  			addNodeID:   true,
   577  			expectedErr: "RPC Error:: 404,Pprof profile not found profile: latency",
   578  		},
   579  	}
   580  
   581  	for _, tc := range cases {
   582  		t.Run(tc.desc, func(t *testing.T) {
   583  			httpTest(t, nil, func(s *TestAgent) {
   584  
   585  				// add node or server id query param
   586  				url := tc.url
   587  				if tc.addNodeID {
   588  					url = url + "?node_id=" + s.client.NodeID()
   589  				} else if tc.addServerID {
   590  					url = url + "?server_id=" + s.server.LocalMember().Name
   591  				}
   592  
   593  				if tc.clientOnly {
   594  					s.Agent.server = nil
   595  				}
   596  
   597  				req, err := http.NewRequest("GET", url, nil)
   598  				require.Nil(t, err)
   599  				respW := httptest.NewRecorder()
   600  
   601  				resp, err := s.Server.AgentPprofRequest(respW, req)
   602  
   603  				if tc.expectedErr != "" {
   604  					require.Error(t, err)
   605  					require.EqualError(t, err, tc.expectedErr)
   606  				} else {
   607  					require.NoError(t, err)
   608  					require.NotNil(t, resp)
   609  				}
   610  			})
   611  		})
   612  	}
   613  }
   614  
   615  type closableRecorder struct {
   616  	*httptest.ResponseRecorder
   617  	closer chan bool
   618  }
   619  
   620  func newClosableRecorder() *closableRecorder {
   621  	r := httptest.NewRecorder()
   622  	closer := make(chan bool)
   623  	return &closableRecorder{r, closer}
   624  }
   625  
   626  func (r *closableRecorder) Close() {
   627  	close(r.closer)
   628  }
   629  
   630  func (r *closableRecorder) CloseNotify() <-chan bool {
   631  	return r.closer
   632  }
   633  
   634  func TestHTTP_AgentForceLeave(t *testing.T) {
   635  	t.Parallel()
   636  	httpTest(t, nil, func(s *TestAgent) {
   637  		// Make the HTTP request
   638  		req, err := http.NewRequest("PUT", "/v1/agent/force-leave?node=foo", nil)
   639  		if err != nil {
   640  			t.Fatalf("err: %v", err)
   641  		}
   642  		respW := httptest.NewRecorder()
   643  
   644  		// Make the request
   645  		_, err = s.Server.AgentForceLeaveRequest(respW, req)
   646  		if err != nil {
   647  			t.Fatalf("err: %v", err)
   648  		}
   649  	})
   650  }
   651  
   652  func TestHTTP_AgentForceLeave_ACL(t *testing.T) {
   653  	t.Parallel()
   654  	require := require.New(t)
   655  
   656  	httpACLTest(t, nil, func(s *TestAgent) {
   657  		state := s.Agent.server.State()
   658  
   659  		// Make the HTTP request
   660  		req, err := http.NewRequest("PUT", "/v1/agent/force-leave?node=foo", nil)
   661  		require.Nil(err)
   662  
   663  		// Try request without a token and expect failure
   664  		{
   665  			respW := httptest.NewRecorder()
   666  			_, err := s.Server.AgentForceLeaveRequest(respW, req)
   667  			require.NotNil(err)
   668  			require.Equal(err.Error(), structs.ErrPermissionDenied.Error())
   669  		}
   670  
   671  		// Try request with an invalid token and expect failure
   672  		{
   673  			respW := httptest.NewRecorder()
   674  			token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", mock.NodePolicy(acl.PolicyRead))
   675  			setToken(req, token)
   676  			_, err := s.Server.AgentForceLeaveRequest(respW, req)
   677  			require.NotNil(err)
   678  			require.Equal(err.Error(), structs.ErrPermissionDenied.Error())
   679  		}
   680  
   681  		// Try request with a valid token
   682  		{
   683  			respW := httptest.NewRecorder()
   684  			token := mock.CreatePolicyAndToken(t, state, 1007, "valid", mock.AgentPolicy(acl.PolicyWrite))
   685  			setToken(req, token)
   686  			_, err := s.Server.AgentForceLeaveRequest(respW, req)
   687  			require.Nil(err)
   688  			require.Equal(http.StatusOK, respW.Code)
   689  		}
   690  
   691  		// Try request with a root token
   692  		{
   693  			respW := httptest.NewRecorder()
   694  			setToken(req, s.RootToken)
   695  			_, err := s.Server.AgentForceLeaveRequest(respW, req)
   696  			require.Nil(err)
   697  			require.Equal(http.StatusOK, respW.Code)
   698  		}
   699  	})
   700  }
   701  
   702  func TestHTTP_AgentSetServers(t *testing.T) {
   703  	t.Parallel()
   704  	require := require.New(t)
   705  	httpTest(t, nil, func(s *TestAgent) {
   706  		addr := s.Config.AdvertiseAddrs.RPC
   707  		testutil.WaitForResult(func() (bool, error) {
   708  			conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
   709  			if err != nil {
   710  				return false, err
   711  			}
   712  			defer conn.Close()
   713  
   714  			// Write the Nomad RPC byte to set the mode
   715  			if _, err := conn.Write([]byte{byte(pool.RpcNomad)}); err != nil {
   716  				return false, err
   717  			}
   718  
   719  			codec := pool.NewClientCodec(conn)
   720  			args := &structs.GenericRequest{}
   721  			var leader string
   722  			err = msgpackrpc.CallWithCodec(codec, "Status.Leader", args, &leader)
   723  			return leader != "", err
   724  		}, func(err error) {
   725  			t.Fatalf("failed to find leader: %v", err)
   726  		})
   727  
   728  		// Create the request
   729  		req, err := http.NewRequest("PUT", "/v1/agent/servers", nil)
   730  		require.Nil(err)
   731  
   732  		// Send the request
   733  		respW := httptest.NewRecorder()
   734  		_, err = s.Server.AgentServersRequest(respW, req)
   735  		require.NotNil(err)
   736  		require.Contains(err.Error(), "missing server address")
   737  
   738  		// Create a valid request
   739  		req, err = http.NewRequest("PUT", "/v1/agent/servers?address=127.0.0.1%3A4647&address=127.0.0.2%3A4647&address=127.0.0.3%3A4647", nil)
   740  		require.Nil(err)
   741  
   742  		// Send the request which should fail
   743  		respW = httptest.NewRecorder()
   744  		_, err = s.Server.AgentServersRequest(respW, req)
   745  		require.NotNil(err)
   746  
   747  		// Retrieve the servers again
   748  		req, err = http.NewRequest("GET", "/v1/agent/servers", nil)
   749  		require.Nil(err)
   750  		respW = httptest.NewRecorder()
   751  
   752  		// Make the request and check the result
   753  		expected := []string{
   754  			s.GetConfig().AdvertiseAddrs.RPC,
   755  		}
   756  		out, err := s.Server.AgentServersRequest(respW, req)
   757  		require.Nil(err)
   758  		servers := out.([]string)
   759  		require.Len(servers, len(expected))
   760  		require.Equal(expected, servers)
   761  	})
   762  }
   763  
   764  func TestHTTP_AgentSetServers_ACL(t *testing.T) {
   765  	t.Parallel()
   766  	require := require.New(t)
   767  
   768  	httpACLTest(t, nil, func(s *TestAgent) {
   769  		state := s.Agent.server.State()
   770  		addr := s.Config.AdvertiseAddrs.RPC
   771  		testutil.WaitForResult(func() (bool, error) {
   772  			conn, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
   773  			if err != nil {
   774  				return false, err
   775  			}
   776  			defer conn.Close()
   777  
   778  			// Write the Consul RPC byte to set the mode
   779  			if _, err := conn.Write([]byte{byte(pool.RpcNomad)}); err != nil {
   780  				return false, err
   781  			}
   782  
   783  			codec := pool.NewClientCodec(conn)
   784  			args := &structs.GenericRequest{}
   785  			var leader string
   786  			err = msgpackrpc.CallWithCodec(codec, "Status.Leader", args, &leader)
   787  			return leader != "", err
   788  		}, func(err error) {
   789  			t.Fatalf("failed to find leader: %v", err)
   790  		})
   791  
   792  		// Make the HTTP request
   793  		path := fmt.Sprintf("/v1/agent/servers?address=%s", url.QueryEscape(s.GetConfig().AdvertiseAddrs.RPC))
   794  		req, err := http.NewRequest("PUT", path, nil)
   795  		require.Nil(err)
   796  
   797  		// Try request without a token and expect failure
   798  		{
   799  			respW := httptest.NewRecorder()
   800  			_, err := s.Server.AgentServersRequest(respW, req)
   801  			require.NotNil(err)
   802  			require.Equal(err.Error(), structs.ErrPermissionDenied.Error())
   803  		}
   804  
   805  		// Try request with an invalid token and expect failure
   806  		{
   807  			respW := httptest.NewRecorder()
   808  			token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", mock.NodePolicy(acl.PolicyRead))
   809  			setToken(req, token)
   810  			_, err := s.Server.AgentServersRequest(respW, req)
   811  			require.NotNil(err)
   812  			require.Equal(err.Error(), structs.ErrPermissionDenied.Error())
   813  		}
   814  
   815  		// Try request with a valid token
   816  		{
   817  			respW := httptest.NewRecorder()
   818  			token := mock.CreatePolicyAndToken(t, state, 1007, "valid", mock.AgentPolicy(acl.PolicyWrite))
   819  			setToken(req, token)
   820  			_, err := s.Server.AgentServersRequest(respW, req)
   821  			require.Nil(err)
   822  			require.Equal(http.StatusOK, respW.Code)
   823  		}
   824  
   825  		// Try request with a root token
   826  		{
   827  			respW := httptest.NewRecorder()
   828  			setToken(req, s.RootToken)
   829  			_, err := s.Server.AgentServersRequest(respW, req)
   830  			require.Nil(err)
   831  			require.Equal(http.StatusOK, respW.Code)
   832  		}
   833  	})
   834  }
   835  
   836  func TestHTTP_AgentListServers_ACL(t *testing.T) {
   837  	t.Parallel()
   838  	require := require.New(t)
   839  
   840  	httpACLTest(t, nil, func(s *TestAgent) {
   841  		state := s.Agent.server.State()
   842  
   843  		// Create list request
   844  		req, err := http.NewRequest("GET", "/v1/agent/servers", nil)
   845  		require.Nil(err)
   846  
   847  		expected := []string{
   848  			s.GetConfig().AdvertiseAddrs.RPC,
   849  		}
   850  
   851  		// Try request without a token and expect failure
   852  		{
   853  			respW := httptest.NewRecorder()
   854  			_, err := s.Server.AgentServersRequest(respW, req)
   855  			require.NotNil(err)
   856  			require.Equal(err.Error(), structs.ErrPermissionDenied.Error())
   857  		}
   858  
   859  		// Try request with an invalid token and expect failure
   860  		{
   861  			respW := httptest.NewRecorder()
   862  			token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", mock.NodePolicy(acl.PolicyRead))
   863  			setToken(req, token)
   864  			_, err := s.Server.AgentServersRequest(respW, req)
   865  			require.NotNil(err)
   866  			require.Equal(err.Error(), structs.ErrPermissionDenied.Error())
   867  		}
   868  
   869  		// Wait for client to have a server
   870  		testutil.WaitForResult(func() (bool, error) {
   871  			return len(s.client.GetServers()) != 0, fmt.Errorf("no servers")
   872  		}, func(err error) {
   873  			t.Fatal(err)
   874  		})
   875  
   876  		// Try request with a valid token
   877  		{
   878  			respW := httptest.NewRecorder()
   879  			token := mock.CreatePolicyAndToken(t, state, 1007, "valid", mock.AgentPolicy(acl.PolicyRead))
   880  			setToken(req, token)
   881  			out, err := s.Server.AgentServersRequest(respW, req)
   882  			require.Nil(err)
   883  			servers := out.([]string)
   884  			require.Len(servers, len(expected))
   885  			require.Equal(expected, servers)
   886  		}
   887  
   888  		// Try request with a root token
   889  		{
   890  			respW := httptest.NewRecorder()
   891  			setToken(req, s.RootToken)
   892  			out, err := s.Server.AgentServersRequest(respW, req)
   893  			require.Nil(err)
   894  			servers := out.([]string)
   895  			require.Len(servers, len(expected))
   896  			require.Equal(expected, servers)
   897  		}
   898  	})
   899  }
   900  
   901  func TestHTTP_AgentListKeys(t *testing.T) {
   902  	t.Parallel()
   903  
   904  	key1 := "HS5lJ+XuTlYKWaeGYyG+/A=="
   905  
   906  	httpTest(t, func(c *Config) {
   907  		c.Server.EncryptKey = key1
   908  	}, func(s *TestAgent) {
   909  		req, err := http.NewRequest("GET", "/v1/agent/keyring/list", nil)
   910  		if err != nil {
   911  			t.Fatalf("err: %s", err)
   912  		}
   913  		respW := httptest.NewRecorder()
   914  
   915  		out, err := s.Server.KeyringOperationRequest(respW, req)
   916  		require.Nil(t, err)
   917  		kresp := out.(structs.KeyringResponse)
   918  		require.Len(t, kresp.Keys, 1)
   919  	})
   920  }
   921  
   922  func TestHTTP_AgentListKeys_ACL(t *testing.T) {
   923  	t.Parallel()
   924  	require := require.New(t)
   925  
   926  	key1 := "HS5lJ+XuTlYKWaeGYyG+/A=="
   927  
   928  	cb := func(c *Config) {
   929  		c.Server.EncryptKey = key1
   930  	}
   931  
   932  	httpACLTest(t, cb, func(s *TestAgent) {
   933  		state := s.Agent.server.State()
   934  
   935  		// Make the HTTP request
   936  		req, err := http.NewRequest("GET", "/v1/agent/keyring/list", nil)
   937  		require.Nil(err)
   938  
   939  		// Try request without a token and expect failure
   940  		{
   941  			respW := httptest.NewRecorder()
   942  			_, err := s.Server.KeyringOperationRequest(respW, req)
   943  			require.NotNil(err)
   944  			require.Equal(err.Error(), structs.ErrPermissionDenied.Error())
   945  		}
   946  
   947  		// Try request with an invalid token and expect failure
   948  		{
   949  			respW := httptest.NewRecorder()
   950  			token := mock.CreatePolicyAndToken(t, state, 1005, "invalid", mock.AgentPolicy(acl.PolicyRead))
   951  			setToken(req, token)
   952  			_, err := s.Server.KeyringOperationRequest(respW, req)
   953  			require.NotNil(err)
   954  			require.Equal(err.Error(), structs.ErrPermissionDenied.Error())
   955  		}
   956  
   957  		// Try request with a valid token
   958  		{
   959  			respW := httptest.NewRecorder()
   960  			token := mock.CreatePolicyAndToken(t, state, 1007, "valid", mock.AgentPolicy(acl.PolicyWrite))
   961  			setToken(req, token)
   962  			out, err := s.Server.KeyringOperationRequest(respW, req)
   963  			require.Nil(err)
   964  			kresp := out.(structs.KeyringResponse)
   965  			require.Len(kresp.Keys, 1)
   966  			require.Contains(kresp.Keys, key1)
   967  		}
   968  
   969  		// Try request with a root token
   970  		{
   971  			respW := httptest.NewRecorder()
   972  			setToken(req, s.RootToken)
   973  			out, err := s.Server.KeyringOperationRequest(respW, req)
   974  			require.Nil(err)
   975  			kresp := out.(structs.KeyringResponse)
   976  			require.Len(kresp.Keys, 1)
   977  			require.Contains(kresp.Keys, key1)
   978  		}
   979  	})
   980  }
   981  
   982  func TestHTTP_AgentInstallKey(t *testing.T) {
   983  	t.Parallel()
   984  
   985  	key1 := "HS5lJ+XuTlYKWaeGYyG+/A=="
   986  	key2 := "wH1Bn9hlJ0emgWB1JttVRA=="
   987  
   988  	httpTest(t, func(c *Config) {
   989  		c.Server.EncryptKey = key1
   990  	}, func(s *TestAgent) {
   991  		b, err := json.Marshal(&structs.KeyringRequest{Key: key2})
   992  		if err != nil {
   993  			t.Fatalf("err: %v", err)
   994  		}
   995  		req, err := http.NewRequest("GET", "/v1/agent/keyring/install", bytes.NewReader(b))
   996  		if err != nil {
   997  			t.Fatalf("err: %s", err)
   998  		}
   999  		respW := httptest.NewRecorder()
  1000  
  1001  		_, err = s.Server.KeyringOperationRequest(respW, req)
  1002  		if err != nil {
  1003  			t.Fatalf("err: %s", err)
  1004  		}
  1005  		req, err = http.NewRequest("GET", "/v1/agent/keyring/list", bytes.NewReader(b))
  1006  		if err != nil {
  1007  			t.Fatalf("err: %s", err)
  1008  		}
  1009  		respW = httptest.NewRecorder()
  1010  
  1011  		out, err := s.Server.KeyringOperationRequest(respW, req)
  1012  		if err != nil {
  1013  			t.Fatalf("err: %s", err)
  1014  		}
  1015  		kresp := out.(structs.KeyringResponse)
  1016  		if len(kresp.Keys) != 2 {
  1017  			t.Fatalf("bad: %v", kresp)
  1018  		}
  1019  	})
  1020  }
  1021  
  1022  func TestHTTP_AgentRemoveKey(t *testing.T) {
  1023  	t.Parallel()
  1024  
  1025  	key1 := "HS5lJ+XuTlYKWaeGYyG+/A=="
  1026  	key2 := "wH1Bn9hlJ0emgWB1JttVRA=="
  1027  
  1028  	httpTest(t, func(c *Config) {
  1029  		c.Server.EncryptKey = key1
  1030  	}, func(s *TestAgent) {
  1031  		b, err := json.Marshal(&structs.KeyringRequest{Key: key2})
  1032  		if err != nil {
  1033  			t.Fatalf("err: %v", err)
  1034  		}
  1035  
  1036  		req, err := http.NewRequest("GET", "/v1/agent/keyring/install", bytes.NewReader(b))
  1037  		if err != nil {
  1038  			t.Fatalf("err: %s", err)
  1039  		}
  1040  		respW := httptest.NewRecorder()
  1041  		_, err = s.Server.KeyringOperationRequest(respW, req)
  1042  		if err != nil {
  1043  			t.Fatalf("err: %s", err)
  1044  		}
  1045  
  1046  		req, err = http.NewRequest("GET", "/v1/agent/keyring/remove", bytes.NewReader(b))
  1047  		if err != nil {
  1048  			t.Fatalf("err: %s", err)
  1049  		}
  1050  		respW = httptest.NewRecorder()
  1051  		if _, err = s.Server.KeyringOperationRequest(respW, req); err != nil {
  1052  			t.Fatalf("err: %s", err)
  1053  		}
  1054  
  1055  		req, err = http.NewRequest("GET", "/v1/agent/keyring/list", nil)
  1056  		if err != nil {
  1057  			t.Fatalf("err: %s", err)
  1058  		}
  1059  		respW = httptest.NewRecorder()
  1060  		out, err := s.Server.KeyringOperationRequest(respW, req)
  1061  		if err != nil {
  1062  			t.Fatalf("err: %s", err)
  1063  		}
  1064  		kresp := out.(structs.KeyringResponse)
  1065  		if len(kresp.Keys) != 1 {
  1066  			t.Fatalf("bad: %v", kresp)
  1067  		}
  1068  	})
  1069  }
  1070  
  1071  func TestHTTP_AgentHealth_Ok(t *testing.T) {
  1072  	t.Parallel()
  1073  	require := require.New(t)
  1074  
  1075  	// Enable ACLs to ensure they're not enforced
  1076  	httpACLTest(t, nil, func(s *TestAgent) {
  1077  		// No ?type=
  1078  		{
  1079  			req, err := http.NewRequest("GET", "/v1/agent/health", nil)
  1080  			require.Nil(err)
  1081  
  1082  			respW := httptest.NewRecorder()
  1083  			healthI, err := s.Server.HealthRequest(respW, req)
  1084  			require.Nil(err)
  1085  			require.Equal(http.StatusOK, respW.Code)
  1086  			require.NotNil(healthI)
  1087  			health := healthI.(*healthResponse)
  1088  			require.NotNil(health.Client)
  1089  			require.True(health.Client.Ok)
  1090  			require.Equal("ok", health.Client.Message)
  1091  			require.NotNil(health.Server)
  1092  			require.True(health.Server.Ok)
  1093  			require.Equal("ok", health.Server.Message)
  1094  		}
  1095  
  1096  		// type=client
  1097  		{
  1098  			req, err := http.NewRequest("GET", "/v1/agent/health?type=client", nil)
  1099  			require.Nil(err)
  1100  
  1101  			respW := httptest.NewRecorder()
  1102  			healthI, err := s.Server.HealthRequest(respW, req)
  1103  			require.Nil(err)
  1104  			require.Equal(http.StatusOK, respW.Code)
  1105  			require.NotNil(healthI)
  1106  			health := healthI.(*healthResponse)
  1107  			require.NotNil(health.Client)
  1108  			require.True(health.Client.Ok)
  1109  			require.Equal("ok", health.Client.Message)
  1110  			require.Nil(health.Server)
  1111  		}
  1112  
  1113  		// type=server
  1114  		{
  1115  			req, err := http.NewRequest("GET", "/v1/agent/health?type=server", nil)
  1116  			require.Nil(err)
  1117  
  1118  			respW := httptest.NewRecorder()
  1119  			healthI, err := s.Server.HealthRequest(respW, req)
  1120  			require.Nil(err)
  1121  			require.Equal(http.StatusOK, respW.Code)
  1122  			require.NotNil(healthI)
  1123  			health := healthI.(*healthResponse)
  1124  			require.NotNil(health.Server)
  1125  			require.True(health.Server.Ok)
  1126  			require.Equal("ok", health.Server.Message)
  1127  			require.Nil(health.Client)
  1128  		}
  1129  
  1130  		// type=client&type=server
  1131  		{
  1132  			req, err := http.NewRequest("GET", "/v1/agent/health?type=client&type=server", nil)
  1133  			require.Nil(err)
  1134  
  1135  			respW := httptest.NewRecorder()
  1136  			healthI, err := s.Server.HealthRequest(respW, req)
  1137  			require.Nil(err)
  1138  			require.Equal(http.StatusOK, respW.Code)
  1139  			require.NotNil(healthI)
  1140  			health := healthI.(*healthResponse)
  1141  			require.NotNil(health.Client)
  1142  			require.True(health.Client.Ok)
  1143  			require.Equal("ok", health.Client.Message)
  1144  			require.NotNil(health.Server)
  1145  			require.True(health.Server.Ok)
  1146  			require.Equal("ok", health.Server.Message)
  1147  		}
  1148  	})
  1149  }
  1150  
  1151  func TestHTTP_AgentHealth_BadServer(t *testing.T) {
  1152  	t.Parallel()
  1153  	require := require.New(t)
  1154  
  1155  	serverAgent := NewTestAgent(t, "server", nil)
  1156  	defer serverAgent.Shutdown()
  1157  
  1158  	s := makeHTTPServer(t, func(c *Config) {
  1159  		// Disable server to make server health unhealthy if requested
  1160  		c.Server.Enabled = false
  1161  		c.Client.Servers = []string{fmt.Sprintf("localhost:%d", serverAgent.Config.Ports.RPC)}
  1162  	})
  1163  	defer s.Shutdown()
  1164  
  1165  	// No ?type= means server is just skipped
  1166  	{
  1167  		req, err := http.NewRequest("GET", "/v1/agent/health", nil)
  1168  		require.Nil(err)
  1169  
  1170  		respW := httptest.NewRecorder()
  1171  		healthI, err := s.Server.HealthRequest(respW, req)
  1172  		require.Nil(err)
  1173  		require.Equal(http.StatusOK, respW.Code)
  1174  		require.NotNil(healthI)
  1175  		health := healthI.(*healthResponse)
  1176  		require.NotNil(health.Client)
  1177  		require.True(health.Client.Ok)
  1178  		require.Equal("ok", health.Client.Message)
  1179  		require.Nil(health.Server)
  1180  	}
  1181  
  1182  	// type=server means server is considered unhealthy
  1183  	{
  1184  		req, err := http.NewRequest("GET", "/v1/agent/health?type=server", nil)
  1185  		require.Nil(err)
  1186  
  1187  		respW := httptest.NewRecorder()
  1188  		_, err = s.Server.HealthRequest(respW, req)
  1189  		require.NotNil(err)
  1190  		httpErr, ok := err.(HTTPCodedError)
  1191  		require.True(ok)
  1192  		require.Equal(500, httpErr.Code())
  1193  		require.Equal(`{"server":{"ok":false,"message":"server not enabled"}}`, err.Error())
  1194  	}
  1195  }
  1196  
  1197  func TestHTTP_AgentHealth_BadClient(t *testing.T) {
  1198  	t.Parallel()
  1199  	require := require.New(t)
  1200  
  1201  	// Disable client to make server unhealthy if requested
  1202  	cb := func(c *Config) {
  1203  		c.Client.Enabled = false
  1204  	}
  1205  
  1206  	// Enable ACLs to ensure they're not enforced
  1207  	httpACLTest(t, cb, func(s *TestAgent) {
  1208  		// No ?type= means client is just skipped
  1209  		{
  1210  			req, err := http.NewRequest("GET", "/v1/agent/health", nil)
  1211  			require.Nil(err)
  1212  
  1213  			respW := httptest.NewRecorder()
  1214  			healthI, err := s.Server.HealthRequest(respW, req)
  1215  			require.Nil(err)
  1216  			require.Equal(http.StatusOK, respW.Code)
  1217  			require.NotNil(healthI)
  1218  			health := healthI.(*healthResponse)
  1219  			require.NotNil(health.Server)
  1220  			require.True(health.Server.Ok)
  1221  			require.Equal("ok", health.Server.Message)
  1222  			require.Nil(health.Client)
  1223  		}
  1224  
  1225  		// type=client means client is considered unhealthy
  1226  		{
  1227  			req, err := http.NewRequest("GET", "/v1/agent/health?type=client", nil)
  1228  			require.Nil(err)
  1229  
  1230  			respW := httptest.NewRecorder()
  1231  			_, err = s.Server.HealthRequest(respW, req)
  1232  			require.NotNil(err)
  1233  			httpErr, ok := err.(HTTPCodedError)
  1234  			require.True(ok)
  1235  			require.Equal(500, httpErr.Code())
  1236  			require.Equal(`{"client":{"ok":false,"message":"client not enabled"}}`, err.Error())
  1237  		}
  1238  	})
  1239  }
  1240  
  1241  var (
  1242  	errorPipe = &net.OpError{
  1243  		Op:     "write",
  1244  		Net:    "tcp",
  1245  		Source: &net.TCPAddr{},
  1246  		Addr:   &net.TCPAddr{},
  1247  		Err: &os.SyscallError{
  1248  			Syscall: "write",
  1249  			Err:     syscall.EPIPE,
  1250  		},
  1251  	}
  1252  )
  1253  
  1254  // fakeRW is a fake response writer to ease polling streaming responses in a
  1255  // data-race-free way.
  1256  type fakeRW struct {
  1257  	Code      int
  1258  	HeaderMap http.Header
  1259  	buf       *bytes.Buffer
  1260  	closed    bool
  1261  	mu        sync.Mutex
  1262  
  1263  	// Written is ticked whenever a Write occurs and on WriteHeaders if it
  1264  	// is explicitly called
  1265  	Written chan int
  1266  
  1267  	// ClosedErr is the error Write will return once the writer is closed.
  1268  	// Defaults to EPIPE. Must not be mutated concurrently with writes.
  1269  	ClosedErr error
  1270  }
  1271  
  1272  // Header is for setting headers before writing a response. Tests should check
  1273  // the HeaderMap field directly.
  1274  func (f *fakeRW) Header() http.Header {
  1275  	f.mu.Lock()
  1276  	defer f.mu.Unlock()
  1277  
  1278  	if f.Code != 0 {
  1279  		panic("cannot set headers after WriteHeader has been called")
  1280  	}
  1281  
  1282  	return f.HeaderMap
  1283  }
  1284  
  1285  func (f *fakeRW) Write(p []byte) (int, error) {
  1286  	f.mu.Lock()
  1287  	defer f.mu.Unlock()
  1288  
  1289  	if f.closed {
  1290  		// Mimic an EPIPE error
  1291  		return 0, f.ClosedErr
  1292  	}
  1293  
  1294  	if f.Code == 0 {
  1295  		f.Code = 200
  1296  	}
  1297  
  1298  	n, err := f.buf.Write(p)
  1299  	select {
  1300  	case f.Written <- 1:
  1301  	default:
  1302  	}
  1303  	return n, err
  1304  }
  1305  
  1306  // WriteHeader sets Code and FinalHeaders
  1307  func (f *fakeRW) WriteHeader(statusCode int) {
  1308  	f.mu.Lock()
  1309  	defer f.mu.Unlock()
  1310  
  1311  	if f.Code != 0 {
  1312  		panic("cannot call WriteHeader more than once")
  1313  	}
  1314  
  1315  	f.Code = statusCode
  1316  	select {
  1317  	case f.Written <- 1:
  1318  	default:
  1319  	}
  1320  }
  1321  
  1322  // Bytes returns the body bytes written to the buffer. Safe for calling
  1323  // concurrent with writes.
  1324  func (f *fakeRW) Bytes() []byte {
  1325  	f.mu.Lock()
  1326  	defer f.mu.Unlock()
  1327  
  1328  	return f.buf.Bytes()
  1329  }
  1330  
  1331  // Close the writer causing an EPIPE error on future writes. Safe to call
  1332  // concurrently with other methods. Safe to call more than once.
  1333  func (f *fakeRW) Close() {
  1334  	f.mu.Lock()
  1335  	defer f.mu.Unlock()
  1336  	f.closed = true
  1337  }
  1338  
  1339  func NewFakeRW() *fakeRW {
  1340  	return &fakeRW{
  1341  		HeaderMap: make(map[string][]string),
  1342  		buf:       &bytes.Buffer{},
  1343  		Written:   make(chan int, 1),
  1344  		ClosedErr: errorPipe,
  1345  	}
  1346  }
  1347  
  1348  // TestHTTP_XSS_Monitor asserts /v1/agent/monitor is safe against XSS attacks
  1349  // even when log output contains HTML+Javascript.
  1350  func TestHTTP_XSS_Monitor(t *testing.T) {
  1351  	t.Parallel()
  1352  
  1353  	cases := []struct {
  1354  		Name    string
  1355  		Logline string
  1356  		JSON    bool
  1357  	}{
  1358  		{
  1359  			Name:    "Plain",
  1360  			Logline: "--TEST 123--",
  1361  			JSON:    false,
  1362  		},
  1363  		{
  1364  			Name:    "JSON",
  1365  			Logline: "--TEST 123--",
  1366  			JSON:    true,
  1367  		},
  1368  		{
  1369  			Name:    "XSSPlain",
  1370  			Logline: "<script>alert(document.domain);</script>",
  1371  			JSON:    false,
  1372  		},
  1373  		{
  1374  			Name:    "XSSJson",
  1375  			Logline: "<script>alert(document.domain);</script>",
  1376  			JSON:    true,
  1377  		},
  1378  	}
  1379  
  1380  	for i := range cases {
  1381  		tc := cases[i]
  1382  		t.Run(tc.Name, func(t *testing.T) {
  1383  			t.Parallel()
  1384  			s := makeHTTPServer(t, nil)
  1385  			defer s.Shutdown()
  1386  
  1387  			path := fmt.Sprintf("%s/v1/agent/monitor?error_level=error&plain=%t", s.HTTPAddr(), !tc.JSON)
  1388  			req, err := http.NewRequest("GET", path, nil)
  1389  			require.NoError(t, err)
  1390  			resp := NewFakeRW()
  1391  			closedErr := errors.New("sentinel error")
  1392  			resp.ClosedErr = closedErr
  1393  			defer resp.Close()
  1394  
  1395  			errCh := make(chan error, 1)
  1396  			go func() {
  1397  				_, err := s.Server.AgentMonitor(resp, req)
  1398  				errCh <- err
  1399  			}()
  1400  
  1401  			deadline := time.After(3 * time.Second)
  1402  
  1403  		OUTER:
  1404  			for {
  1405  				// Log a needle and look for it in the response haystack
  1406  				s.Server.logger.Error(tc.Logline)
  1407  
  1408  				select {
  1409  				case <-time.After(30 * time.Millisecond):
  1410  					// Give AgentMonitor handler goroutine time to start
  1411  				case <-resp.Written:
  1412  					// Something was written, check it
  1413  				case <-deadline:
  1414  					t.Fatalf("timed out waiting for expected log line; body:\n%s", string(resp.Bytes()))
  1415  				case err := <-errCh:
  1416  					t.Fatalf("AgentMonitor exited unexpectedly: err=%v", err)
  1417  				}
  1418  
  1419  				if !tc.JSON {
  1420  					if bytes.Contains(resp.Bytes(), []byte(tc.Logline)) {
  1421  						// Found needle!
  1422  						break
  1423  					} else {
  1424  						// Try again
  1425  						continue
  1426  					}
  1427  				}
  1428  
  1429  				// Decode JSON
  1430  				r := bytes.NewReader(resp.Bytes())
  1431  				dec := json.NewDecoder(r)
  1432  				for {
  1433  					data := struct{ Data []byte }{}
  1434  					if err := dec.Decode(&data); err != nil {
  1435  						// Probably a partial write, continue
  1436  						continue OUTER
  1437  					}
  1438  
  1439  					if bytes.Contains(data.Data, []byte(tc.Logline)) {
  1440  						// Found needle!
  1441  						break OUTER
  1442  					}
  1443  				}
  1444  
  1445  			}
  1446  
  1447  			// Assert default logs are application/json
  1448  			ct := "text/plain"
  1449  			if tc.JSON {
  1450  				ct = "application/json"
  1451  			}
  1452  			require.Equal(t, []string{ct}, resp.HeaderMap.Values("Content-Type"))
  1453  
  1454  			// Close response writer and log to make AgentMonitor exit
  1455  			resp.Close()
  1456  			s.Server.logger.Error("log again to force a write that detects the closed connection")
  1457  			select {
  1458  			case err := <-errCh:
  1459  				require.EqualError(t, closedErr, err.Error())
  1460  			case <-deadline:
  1461  				t.Fatalf("timed out waiting for closing error from handler")
  1462  			}
  1463  		})
  1464  	}
  1465  }