github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/nomad/client_agent_endpoint.go (about)

     1  package nomad
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"time"
    11  
    12  	log "github.com/hashicorp/go-hclog"
    13  
    14  	sframer "github.com/hashicorp/nomad/client/lib/streamframer"
    15  	cstructs "github.com/hashicorp/nomad/client/structs"
    16  	"github.com/hashicorp/nomad/command/agent/host"
    17  	"github.com/hashicorp/nomad/command/agent/monitor"
    18  	"github.com/hashicorp/nomad/command/agent/pprof"
    19  	"github.com/hashicorp/nomad/helper/pointer"
    20  	"github.com/hashicorp/nomad/nomad/structs"
    21  
    22  	"github.com/hashicorp/go-msgpack/codec"
    23  )
    24  
    25  type Agent struct {
    26  	srv *Server
    27  }
    28  
    29  func NewAgentEndpoint(srv *Server) *Agent {
    30  	return &Agent{srv: srv}
    31  }
    32  
    33  func (a *Agent) register() {
    34  	a.srv.streamingRpcs.Register("Agent.Monitor", a.monitor)
    35  }
    36  
    37  func (a *Agent) Profile(args *structs.AgentPprofRequest, reply *structs.AgentPprofResponse) error {
    38  	// Check ACL for agent write
    39  	aclObj, err := a.srv.ResolveToken(args.AuthToken)
    40  	if err != nil {
    41  		return err
    42  	} else if aclObj != nil && !aclObj.AllowAgentWrite() {
    43  		return structs.ErrPermissionDenied
    44  	}
    45  
    46  	// Forward to different region if necessary
    47  	// this would typically be done in a.srv.forward() but since
    48  	// we are targeting a specific server, not just the leader
    49  	// we must manually handle region forwarding here.
    50  	region := args.RequestRegion()
    51  	if region == "" {
    52  		return fmt.Errorf("missing target RPC")
    53  	}
    54  
    55  	if region != a.srv.config.Region {
    56  		// Mark that we are forwarding
    57  		args.SetForwarded()
    58  		return a.srv.forwardRegion(region, "Agent.Profile", args, reply)
    59  	}
    60  
    61  	// Targeting a node, forward request to node
    62  	if args.NodeID != "" {
    63  		return a.forwardProfileClient(args, reply)
    64  	}
    65  
    66  	// Handle serverID not equal to ours
    67  	if args.ServerID != "" {
    68  		serverToFwd, err := a.forwardFor(args.ServerID, region)
    69  		if err != nil {
    70  			return err
    71  		}
    72  		if serverToFwd != nil {
    73  			return a.srv.forwardServer(serverToFwd, "Agent.Profile", args, reply)
    74  		}
    75  	}
    76  
    77  	// If ACLs are disabled, EnableDebug must be enabled
    78  	if aclObj == nil && !a.srv.config.EnableDebug {
    79  		return structs.ErrPermissionDenied
    80  	}
    81  
    82  	// Process the request on this server
    83  	var resp []byte
    84  	var headers map[string]string
    85  
    86  	// Determine which profile to run and generate profile.
    87  	// Blocks for args.Seconds
    88  	// Our RPC endpoints currently don't support context
    89  	// or request cancellation so using server shutdownCtx as a
    90  	// best effort.
    91  	switch args.ReqType {
    92  	case pprof.CPUReq:
    93  		resp, headers, err = pprof.CPUProfile(a.srv.shutdownCtx, args.Seconds)
    94  	case pprof.CmdReq:
    95  		resp, headers, err = pprof.Cmdline()
    96  	case pprof.LookupReq:
    97  		resp, headers, err = pprof.Profile(args.Profile, args.Debug, args.GC)
    98  	case pprof.TraceReq:
    99  		resp, headers, err = pprof.Trace(a.srv.shutdownCtx, args.Seconds)
   100  	default:
   101  		err = structs.NewErrRPCCoded(404, "Unknown profile request type")
   102  	}
   103  
   104  	if err != nil {
   105  		if pprof.IsErrProfileNotFound(err) {
   106  			return structs.NewErrRPCCoded(404, err.Error())
   107  		}
   108  		return structs.NewErrRPCCoded(500, err.Error())
   109  	}
   110  
   111  	// Copy profile response to reply
   112  	reply.Payload = resp
   113  	reply.HTTPHeaders = headers
   114  	reply.AgentID = a.srv.serf.LocalMember().Name
   115  
   116  	return nil
   117  }
   118  
   119  func (a *Agent) monitor(conn io.ReadWriteCloser) {
   120  	defer conn.Close()
   121  
   122  	// Decode args
   123  	var args cstructs.MonitorRequest
   124  	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
   125  	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
   126  
   127  	if err := decoder.Decode(&args); err != nil {
   128  		handleStreamResultError(err, pointer.Of(int64(500)), encoder)
   129  		return
   130  	}
   131  
   132  	// Check agent read permissions
   133  	if aclObj, err := a.srv.ResolveToken(args.AuthToken); err != nil {
   134  		handleStreamResultError(err, nil, encoder)
   135  		return
   136  	} else if aclObj != nil && !aclObj.AllowAgentRead() {
   137  		handleStreamResultError(structs.ErrPermissionDenied, pointer.Of(int64(403)), encoder)
   138  		return
   139  	}
   140  
   141  	logLevel := log.LevelFromString(args.LogLevel)
   142  	if args.LogLevel == "" {
   143  		logLevel = log.LevelFromString("INFO")
   144  	}
   145  
   146  	if logLevel == log.NoLevel {
   147  		handleStreamResultError(errors.New("Unknown log level"), pointer.Of(int64(400)), encoder)
   148  		return
   149  	}
   150  
   151  	// Targeting a node, forward request to node
   152  	if args.NodeID != "" {
   153  		a.forwardMonitorClient(conn, args, encoder, decoder)
   154  		// forwarded request has ended, return
   155  		return
   156  	}
   157  
   158  	region := args.RequestRegion()
   159  	if region == "" {
   160  		handleStreamResultError(fmt.Errorf("missing target RPC"), pointer.Of(int64(400)), encoder)
   161  		return
   162  	}
   163  	if region != a.srv.config.Region {
   164  		// Mark that we are forwarding
   165  		args.SetForwarded()
   166  	}
   167  
   168  	// Try to forward request to remote region/server
   169  	if args.ServerID != "" {
   170  		serverToFwd, err := a.forwardFor(args.ServerID, region)
   171  		if err != nil {
   172  			handleStreamResultError(err, pointer.Of(int64(400)), encoder)
   173  			return
   174  		}
   175  		if serverToFwd != nil {
   176  			a.forwardMonitorServer(conn, serverToFwd, args, encoder, decoder)
   177  			return
   178  		}
   179  	}
   180  
   181  	// NodeID was empty, ServerID was equal to this server,  monitor this server
   182  	ctx, cancel := context.WithCancel(context.Background())
   183  	defer cancel()
   184  
   185  	monitor := monitor.New(512, a.srv.logger, &log.LoggerOptions{
   186  		Level:      logLevel,
   187  		JSONFormat: args.LogJSON,
   188  	})
   189  
   190  	frames := make(chan *sframer.StreamFrame, 32)
   191  	errCh := make(chan error)
   192  	var buf bytes.Buffer
   193  	frameCodec := codec.NewEncoder(&buf, structs.JsonHandle)
   194  
   195  	framer := sframer.NewStreamFramer(frames, 1*time.Second, 200*time.Millisecond, 1024)
   196  	framer.Run()
   197  	defer framer.Destroy()
   198  
   199  	// goroutine to detect remote side closing
   200  	go func() {
   201  		if _, err := conn.Read(nil); err != nil {
   202  			// One end of the pipe explicitly closed, exit
   203  			cancel()
   204  			return
   205  		}
   206  		<-ctx.Done()
   207  	}()
   208  
   209  	logCh := monitor.Start()
   210  	defer monitor.Stop()
   211  	initialOffset := int64(0)
   212  
   213  	// receive logs and build frames
   214  	go func() {
   215  		defer framer.Destroy()
   216  	LOOP:
   217  		for {
   218  			select {
   219  			case log := <-logCh:
   220  				if err := framer.Send("", "log", log, initialOffset); err != nil {
   221  					select {
   222  					case errCh <- err:
   223  					case <-ctx.Done():
   224  					}
   225  					break LOOP
   226  				}
   227  			case <-ctx.Done():
   228  				break LOOP
   229  			}
   230  		}
   231  	}()
   232  
   233  	var streamErr error
   234  OUTER:
   235  	for {
   236  		select {
   237  		case frame, ok := <-frames:
   238  			if !ok {
   239  				// frame may have been closed when an error
   240  				// occurred. Check once more for an error.
   241  				select {
   242  				case streamErr = <-errCh:
   243  					// There was a pending error!
   244  				default:
   245  					// No error, continue on
   246  				}
   247  
   248  				break OUTER
   249  			}
   250  
   251  			var resp cstructs.StreamErrWrapper
   252  			if args.PlainText {
   253  				resp.Payload = frame.Data
   254  			} else {
   255  				if err := frameCodec.Encode(frame); err != nil {
   256  					streamErr = err
   257  					break OUTER
   258  				}
   259  
   260  				resp.Payload = buf.Bytes()
   261  				buf.Reset()
   262  			}
   263  
   264  			if err := encoder.Encode(resp); err != nil {
   265  				streamErr = err
   266  				break OUTER
   267  			}
   268  			encoder.Reset(conn)
   269  		case <-ctx.Done():
   270  			break OUTER
   271  		}
   272  	}
   273  
   274  	if streamErr != nil {
   275  		handleStreamResultError(streamErr, pointer.Of(int64(500)), encoder)
   276  		return
   277  	}
   278  }
   279  
   280  // forwardFor returns a serverParts for a request to be forwarded to.
   281  // A response of nil, nil indicates that the current server is equal to the
   282  // serverID and region so the request should not be forwarded.
   283  func (a *Agent) forwardFor(serverID, region string) (*serverParts, error) {
   284  	var target *serverParts
   285  	var err error
   286  
   287  	if serverID == "leader" {
   288  		isLeader, remoteLeader := a.srv.getLeader()
   289  		if !isLeader && remoteLeader != nil {
   290  			target = remoteLeader
   291  		} else if !isLeader && remoteLeader == nil {
   292  			return nil, structs.ErrNoLeader
   293  		} else if isLeader {
   294  			// This server is current leader do not forward
   295  			return nil, nil
   296  		}
   297  	} else {
   298  		target, err = a.srv.getServer(region, serverID)
   299  		if err != nil {
   300  			return nil, err
   301  		}
   302  	}
   303  
   304  	// Unable to find a server
   305  	if target == nil {
   306  		return nil, fmt.Errorf("unknown nomad server %s", serverID)
   307  	}
   308  
   309  	// ServerID is this current server,
   310  	// No need to forward request
   311  	if target.Name == a.srv.LocalMember().Name {
   312  		return nil, nil
   313  	}
   314  
   315  	return target, nil
   316  }
   317  
   318  func (a *Agent) forwardMonitorClient(conn io.ReadWriteCloser, args cstructs.MonitorRequest, encoder *codec.Encoder, decoder *codec.Decoder) {
   319  	// Get the Connection to the client either by fowarding to another server
   320  	// or creating direct stream
   321  
   322  	state, srv, err := a.findClientConn(args.NodeID)
   323  	if err != nil {
   324  		handleStreamResultError(err, pointer.Of(int64(500)), encoder)
   325  		return
   326  	}
   327  
   328  	var clientConn net.Conn
   329  
   330  	if state == nil {
   331  		conn, err := a.srv.streamingRpc(srv, "Agent.Monitor")
   332  		if err != nil {
   333  			handleStreamResultError(err, nil, encoder)
   334  			return
   335  		}
   336  
   337  		clientConn = conn
   338  	} else {
   339  		stream, err := NodeStreamingRpc(state.Session, "Agent.Monitor")
   340  		if err != nil {
   341  			handleStreamResultError(err, nil, encoder)
   342  			return
   343  		}
   344  		clientConn = stream
   345  	}
   346  	defer clientConn.Close()
   347  
   348  	// Send the Request
   349  	outEncoder := codec.NewEncoder(clientConn, structs.MsgpackHandle)
   350  	if err := outEncoder.Encode(args); err != nil {
   351  		handleStreamResultError(err, nil, encoder)
   352  		return
   353  	}
   354  
   355  	structs.Bridge(conn, clientConn)
   356  }
   357  
   358  func (a *Agent) forwardMonitorServer(conn io.ReadWriteCloser, server *serverParts, args cstructs.MonitorRequest, encoder *codec.Encoder, decoder *codec.Decoder) {
   359  	// empty ServerID to prevent forwarding loop
   360  	args.ServerID = ""
   361  
   362  	serverConn, err := a.srv.streamingRpc(server, "Agent.Monitor")
   363  	if err != nil {
   364  		handleStreamResultError(err, pointer.Of(int64(500)), encoder)
   365  		return
   366  	}
   367  	defer serverConn.Close()
   368  
   369  	// Send the Request
   370  	outEncoder := codec.NewEncoder(serverConn, structs.MsgpackHandle)
   371  	if err := outEncoder.Encode(args); err != nil {
   372  		handleStreamResultError(err, pointer.Of(int64(500)), encoder)
   373  		return
   374  	}
   375  
   376  	structs.Bridge(conn, serverConn)
   377  }
   378  
   379  func (a *Agent) forwardProfileClient(args *structs.AgentPprofRequest, reply *structs.AgentPprofResponse) error {
   380  	state, srv, err := a.findClientConn(args.NodeID)
   381  
   382  	if err != nil {
   383  		return err
   384  	}
   385  
   386  	if srv != nil {
   387  		return a.srv.forwardServer(srv, "Agent.Profile", args, reply)
   388  	}
   389  
   390  	// NodeRpc
   391  	rpcErr := NodeRpc(state.Session, "Agent.Profile", args, reply)
   392  	if rpcErr != nil {
   393  		return rpcErr
   394  	}
   395  
   396  	return nil
   397  }
   398  
   399  // Host returns data about the agent's host system for the `debug` command.
   400  func (a *Agent) Host(args *structs.HostDataRequest, reply *structs.HostDataResponse) error {
   401  
   402  	aclObj, err := a.srv.ResolveToken(args.AuthToken)
   403  	if err != nil {
   404  		return err
   405  	}
   406  	if (aclObj != nil && !aclObj.AllowAgentRead()) ||
   407  		(aclObj == nil && !a.srv.config.EnableDebug) {
   408  		return structs.ErrPermissionDenied
   409  	}
   410  
   411  	// Forward to different region if necessary
   412  	// this would typically be done in a.srv.forward() but since
   413  	// we are targeting a specific server, not just the leader
   414  	// we must manually handle region forwarding here.
   415  	region := args.RequestRegion()
   416  	if region == "" {
   417  		return fmt.Errorf("missing target RPC")
   418  	}
   419  
   420  	if region != a.srv.config.Region {
   421  		// Mark that we are forwarding
   422  		args.SetForwarded()
   423  		return a.srv.forwardRegion(region, "Agent.Host", args, reply)
   424  	}
   425  
   426  	// Targeting a client node, forward request to node
   427  	if args.NodeID != "" {
   428  		client, srv, err := a.findClientConn(args.NodeID)
   429  
   430  		if err != nil {
   431  			return err
   432  		}
   433  
   434  		if srv != nil {
   435  			return a.srv.forwardServer(srv, "Agent.Host", args, reply)
   436  		}
   437  
   438  		return NodeRpc(client.Session, "Agent.Host", args, reply)
   439  	}
   440  
   441  	// Handle serverID not equal to ours
   442  	if args.ServerID != "" {
   443  		srv, err := a.forwardFor(args.ServerID, region)
   444  		if err != nil {
   445  			return err
   446  		}
   447  		if srv != nil {
   448  			return a.srv.forwardServer(srv, "Agent.Host", args, reply)
   449  		}
   450  	}
   451  
   452  	data, err := host.MakeHostData()
   453  	if err != nil {
   454  		return err
   455  	}
   456  
   457  	reply.AgentID = a.srv.serf.LocalMember().Name
   458  	reply.HostData = data
   459  	return nil
   460  }
   461  
   462  // findClientConn is a helper that returns a connection to the client node or, if the client
   463  // is connected to a different server, a serverParts describing the server to which the
   464  // client bound RPC should be forwarded.
   465  func (a *Agent) findClientConn(nodeID string) (*nodeConnState, *serverParts, error) {
   466  	snap, err := a.srv.State().Snapshot()
   467  	if err != nil {
   468  		return nil, nil, structs.NewErrRPCCoded(500, err.Error())
   469  	}
   470  
   471  	node, err := snap.NodeByID(nil, nodeID)
   472  	if err != nil {
   473  		return nil, nil, structs.NewErrRPCCoded(500, err.Error())
   474  	}
   475  
   476  	if node == nil {
   477  		err := fmt.Errorf("Unknown node %q", nodeID)
   478  		return nil, nil, structs.NewErrRPCCoded(404, err.Error())
   479  	}
   480  
   481  	if err := nodeSupportsRpc(node); err != nil {
   482  		return nil, nil, structs.NewErrRPCCoded(400, err.Error())
   483  	}
   484  
   485  	// Get the Connection to the client either by fowarding to another server
   486  	// or creating direct stream
   487  	state, ok := a.srv.getNodeConn(nodeID)
   488  	if ok {
   489  		return state, nil, nil
   490  	}
   491  
   492  	// Determine the server that has a connection to the node
   493  	srv, err := a.srv.serverWithNodeConn(nodeID, a.srv.Region())
   494  	if err != nil {
   495  		code := 500
   496  		if structs.IsErrNoNodeConn(err) {
   497  			code = 404
   498  		}
   499  		return nil, nil, structs.NewErrRPCCoded(code, err.Error())
   500  	}
   501  
   502  	return nil, srv, nil
   503  }