github.com/thomasobenaus/nomad@v0.11.1/nomad/client_agent_endpoint.go (about)

     1  package nomad
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"time"
    11  
    12  	log "github.com/hashicorp/go-hclog"
    13  	sframer "github.com/hashicorp/nomad/client/lib/streamframer"
    14  	cstructs "github.com/hashicorp/nomad/client/structs"
    15  	"github.com/hashicorp/nomad/command/agent/monitor"
    16  	"github.com/hashicorp/nomad/command/agent/pprof"
    17  	"github.com/hashicorp/nomad/helper"
    18  	"github.com/hashicorp/nomad/nomad/structs"
    19  
    20  	"github.com/hashicorp/go-msgpack/codec"
    21  )
    22  
    23  type Agent struct {
    24  	srv *Server
    25  }
    26  
    27  func (a *Agent) register() {
    28  	a.srv.streamingRpcs.Register("Agent.Monitor", a.monitor)
    29  }
    30  
    31  func (a *Agent) Profile(args *structs.AgentPprofRequest, reply *structs.AgentPprofResponse) error {
    32  	// Check ACL for agent write
    33  	aclObj, err := a.srv.ResolveToken(args.AuthToken)
    34  	if err != nil {
    35  		return err
    36  	} else if aclObj != nil && !aclObj.AllowAgentWrite() {
    37  		return structs.ErrPermissionDenied
    38  	}
    39  
    40  	// Forward to different region if necessary
    41  	// this would typically be done in a.srv.forward() but since
    42  	// we are targeting a specific server, not just the leader
    43  	// we must manually handle region forwarding here.
    44  	region := args.RequestRegion()
    45  	if region == "" {
    46  		return fmt.Errorf("missing target RPC")
    47  	}
    48  
    49  	if region != a.srv.config.Region {
    50  		// Mark that we are forwarding
    51  		args.SetForwarded()
    52  		return a.srv.forwardRegion(region, "Agent.Profile", args, reply)
    53  	}
    54  
    55  	// Targeting a node, forward request to node
    56  	if args.NodeID != "" {
    57  		return a.forwardProfileClient(args, reply)
    58  	}
    59  
    60  	// Handle serverID not equal to ours
    61  	if args.ServerID != "" {
    62  		serverToFwd, err := a.forwardFor(args.ServerID, region)
    63  		if err != nil {
    64  			return err
    65  		}
    66  		if serverToFwd != nil {
    67  			return a.srv.forwardServer(serverToFwd, "Agent.Profile", args, reply)
    68  		}
    69  	}
    70  
    71  	// If ACLs are disabled, EnableDebug must be enabled
    72  	if aclObj == nil && !a.srv.config.EnableDebug {
    73  		return structs.ErrPermissionDenied
    74  	}
    75  
    76  	// Process the request on this server
    77  	var resp []byte
    78  	var headers map[string]string
    79  
    80  	// Determine which profile to run and generate profile.
    81  	// Blocks for args.Seconds
    82  	// Our RPC endpoints currently don't support context
    83  	// or request cancellation so using server shutdownCtx as a
    84  	// best effort.
    85  	switch args.ReqType {
    86  	case pprof.CPUReq:
    87  		resp, headers, err = pprof.CPUProfile(a.srv.shutdownCtx, args.Seconds)
    88  	case pprof.CmdReq:
    89  		resp, headers, err = pprof.Cmdline()
    90  	case pprof.LookupReq:
    91  		resp, headers, err = pprof.Profile(args.Profile, args.Debug, args.GC)
    92  	case pprof.TraceReq:
    93  		resp, headers, err = pprof.Trace(a.srv.shutdownCtx, args.Seconds)
    94  	default:
    95  		err = structs.NewErrRPCCoded(404, "Unknown profile request type")
    96  	}
    97  
    98  	if err != nil {
    99  		if pprof.IsErrProfileNotFound(err) {
   100  			return structs.NewErrRPCCoded(404, err.Error())
   101  		}
   102  		return structs.NewErrRPCCoded(500, err.Error())
   103  	}
   104  
   105  	// Copy profile response to reply
   106  	reply.Payload = resp
   107  	reply.HTTPHeaders = headers
   108  	reply.AgentID = a.srv.serf.LocalMember().Name
   109  
   110  	return nil
   111  }
   112  
   113  func (a *Agent) monitor(conn io.ReadWriteCloser) {
   114  	defer conn.Close()
   115  
   116  	// Decode args
   117  	var args cstructs.MonitorRequest
   118  	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
   119  	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
   120  
   121  	if err := decoder.Decode(&args); err != nil {
   122  		handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
   123  		return
   124  	}
   125  
   126  	// Check agent read permissions
   127  	if aclObj, err := a.srv.ResolveToken(args.AuthToken); err != nil {
   128  		handleStreamResultError(err, nil, encoder)
   129  		return
   130  	} else if aclObj != nil && !aclObj.AllowAgentRead() {
   131  		handleStreamResultError(structs.ErrPermissionDenied, helper.Int64ToPtr(403), encoder)
   132  		return
   133  	}
   134  
   135  	logLevel := log.LevelFromString(args.LogLevel)
   136  	if args.LogLevel == "" {
   137  		logLevel = log.LevelFromString("INFO")
   138  	}
   139  
   140  	if logLevel == log.NoLevel {
   141  		handleStreamResultError(errors.New("Unknown log level"), helper.Int64ToPtr(400), encoder)
   142  		return
   143  	}
   144  
   145  	// Targeting a node, forward request to node
   146  	if args.NodeID != "" {
   147  		a.forwardMonitorClient(conn, args, encoder, decoder)
   148  		// forwarded request has ended, return
   149  		return
   150  	}
   151  
   152  	region := args.RequestRegion()
   153  	if region == "" {
   154  		handleStreamResultError(fmt.Errorf("missing target RPC"), helper.Int64ToPtr(400), encoder)
   155  		return
   156  	}
   157  	if region != a.srv.config.Region {
   158  		// Mark that we are forwarding
   159  		args.SetForwarded()
   160  	}
   161  
   162  	// Try to forward request to remote region/server
   163  	if args.ServerID != "" {
   164  		serverToFwd, err := a.forwardFor(args.ServerID, region)
   165  		if err != nil {
   166  			handleStreamResultError(err, helper.Int64ToPtr(400), encoder)
   167  			return
   168  		}
   169  		if serverToFwd != nil {
   170  			a.forwardMonitorServer(conn, serverToFwd, args, encoder, decoder)
   171  			return
   172  		}
   173  	}
   174  
   175  	// NodeID was empty, ServerID was equal to this server,  monitor this server
   176  	ctx, cancel := context.WithCancel(context.Background())
   177  	defer cancel()
   178  
   179  	monitor := monitor.New(512, a.srv.logger, &log.LoggerOptions{
   180  		Level:      logLevel,
   181  		JSONFormat: args.LogJSON,
   182  	})
   183  
   184  	frames := make(chan *sframer.StreamFrame, 32)
   185  	errCh := make(chan error)
   186  	var buf bytes.Buffer
   187  	frameCodec := codec.NewEncoder(&buf, structs.JsonHandle)
   188  
   189  	framer := sframer.NewStreamFramer(frames, 1*time.Second, 200*time.Millisecond, 1024)
   190  	framer.Run()
   191  	defer framer.Destroy()
   192  
   193  	// goroutine to detect remote side closing
   194  	go func() {
   195  		if _, err := conn.Read(nil); err != nil {
   196  			// One end of the pipe explicitly closed, exit
   197  			cancel()
   198  			return
   199  		}
   200  		select {
   201  		case <-ctx.Done():
   202  			return
   203  		}
   204  	}()
   205  
   206  	logCh := monitor.Start()
   207  	defer monitor.Stop()
   208  	initialOffset := int64(0)
   209  
   210  	// receive logs and build frames
   211  	go func() {
   212  		defer framer.Destroy()
   213  	LOOP:
   214  		for {
   215  			select {
   216  			case log := <-logCh:
   217  				if err := framer.Send("", "log", log, initialOffset); err != nil {
   218  					select {
   219  					case errCh <- err:
   220  					case <-ctx.Done():
   221  					}
   222  					break LOOP
   223  				}
   224  			case <-ctx.Done():
   225  				break LOOP
   226  			}
   227  		}
   228  	}()
   229  
   230  	var streamErr error
   231  OUTER:
   232  	for {
   233  		select {
   234  		case frame, ok := <-frames:
   235  			if !ok {
   236  				// frame may have been closed when an error
   237  				// occurred. Check once more for an error.
   238  				select {
   239  				case streamErr = <-errCh:
   240  					// There was a pending error!
   241  				default:
   242  					// No error, continue on
   243  				}
   244  
   245  				break OUTER
   246  			}
   247  
   248  			var resp cstructs.StreamErrWrapper
   249  			if args.PlainText {
   250  				resp.Payload = frame.Data
   251  			} else {
   252  				if err := frameCodec.Encode(frame); err != nil {
   253  					streamErr = err
   254  					break OUTER
   255  				}
   256  
   257  				resp.Payload = buf.Bytes()
   258  				buf.Reset()
   259  			}
   260  
   261  			if err := encoder.Encode(resp); err != nil {
   262  				streamErr = err
   263  				break OUTER
   264  			}
   265  			encoder.Reset(conn)
   266  		case <-ctx.Done():
   267  			break OUTER
   268  		}
   269  	}
   270  
   271  	if streamErr != nil {
   272  		handleStreamResultError(streamErr, helper.Int64ToPtr(500), encoder)
   273  		return
   274  	}
   275  }
   276  
   277  // forwardFor returns a serverParts for a request to be forwarded to.
   278  // A response of nil, nil indicates that the current server is equal to the
   279  // serverID and region so the request should not be forwarded.
   280  func (a *Agent) forwardFor(serverID, region string) (*serverParts, error) {
   281  	var target *serverParts
   282  	var err error
   283  
   284  	if serverID == "leader" {
   285  		isLeader, remoteLeader := a.srv.getLeader()
   286  		if !isLeader && remoteLeader != nil {
   287  			target = remoteLeader
   288  		} else if !isLeader && remoteLeader == nil {
   289  			return nil, structs.ErrNoLeader
   290  		} else if isLeader {
   291  			// This server is current leader do not forward
   292  			return nil, nil
   293  		}
   294  	} else {
   295  		target, err = a.srv.getServer(region, serverID)
   296  		if err != nil {
   297  			return nil, err
   298  		}
   299  	}
   300  
   301  	// Unable to find a server
   302  	if target == nil {
   303  		return nil, fmt.Errorf("unknown nomad server %s", serverID)
   304  	}
   305  
   306  	// ServerID is this current server,
   307  	// No need to forward request
   308  	if target.Name == a.srv.LocalMember().Name {
   309  		return nil, nil
   310  	}
   311  
   312  	return target, nil
   313  }
   314  
   315  func (a *Agent) forwardMonitorClient(conn io.ReadWriteCloser, args cstructs.MonitorRequest, encoder *codec.Encoder, decoder *codec.Decoder) {
   316  	nodeID := args.NodeID
   317  
   318  	snap, err := a.srv.State().Snapshot()
   319  	if err != nil {
   320  		handleStreamResultError(err, nil, encoder)
   321  		return
   322  	}
   323  
   324  	node, err := snap.NodeByID(nil, nodeID)
   325  	if err != nil {
   326  		handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
   327  		return
   328  	}
   329  
   330  	if node == nil {
   331  		err := fmt.Errorf("Unknown node %q", nodeID)
   332  		handleStreamResultError(err, helper.Int64ToPtr(400), encoder)
   333  		return
   334  	}
   335  
   336  	if err := nodeSupportsRpc(node); err != nil {
   337  		handleStreamResultError(err, helper.Int64ToPtr(400), encoder)
   338  		return
   339  	}
   340  
   341  	// Get the Connection to the client either by fowarding to another server
   342  	// or creating direct stream
   343  	var clientConn net.Conn
   344  	state, ok := a.srv.getNodeConn(nodeID)
   345  	if !ok {
   346  		// Determine the server that has a connection to the node
   347  		srv, err := a.srv.serverWithNodeConn(nodeID, a.srv.Region())
   348  		if err != nil {
   349  			var code *int64
   350  			if structs.IsErrNoNodeConn(err) {
   351  				code = helper.Int64ToPtr(404)
   352  			}
   353  			handleStreamResultError(err, code, encoder)
   354  			return
   355  		}
   356  		conn, err := a.srv.streamingRpc(srv, "Agent.Monitor")
   357  		if err != nil {
   358  			handleStreamResultError(err, nil, encoder)
   359  			return
   360  		}
   361  
   362  		clientConn = conn
   363  	} else {
   364  		stream, err := NodeStreamingRpc(state.Session, "Agent.Monitor")
   365  		if err != nil {
   366  			handleStreamResultError(err, nil, encoder)
   367  			return
   368  		}
   369  		clientConn = stream
   370  	}
   371  	defer clientConn.Close()
   372  
   373  	// Send the Request
   374  	outEncoder := codec.NewEncoder(clientConn, structs.MsgpackHandle)
   375  	if err := outEncoder.Encode(args); err != nil {
   376  		handleStreamResultError(err, nil, encoder)
   377  		return
   378  	}
   379  
   380  	structs.Bridge(conn, clientConn)
   381  	return
   382  }
   383  
   384  func (a *Agent) forwardMonitorServer(conn io.ReadWriteCloser, server *serverParts, args cstructs.MonitorRequest, encoder *codec.Encoder, decoder *codec.Decoder) {
   385  	// empty ServerID to prevent forwarding loop
   386  	args.ServerID = ""
   387  
   388  	serverConn, err := a.srv.streamingRpc(server, "Agent.Monitor")
   389  	if err != nil {
   390  		handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
   391  		return
   392  	}
   393  	defer serverConn.Close()
   394  
   395  	// Send the Request
   396  	outEncoder := codec.NewEncoder(serverConn, structs.MsgpackHandle)
   397  	if err := outEncoder.Encode(args); err != nil {
   398  		handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
   399  		return
   400  	}
   401  
   402  	structs.Bridge(conn, serverConn)
   403  	return
   404  }
   405  
   406  func (a *Agent) forwardProfileClient(args *structs.AgentPprofRequest, reply *structs.AgentPprofResponse) error {
   407  	nodeID := args.NodeID
   408  
   409  	snap, err := a.srv.State().Snapshot()
   410  	if err != nil {
   411  		return structs.NewErrRPCCoded(500, err.Error())
   412  	}
   413  
   414  	node, err := snap.NodeByID(nil, nodeID)
   415  	if err != nil {
   416  		return structs.NewErrRPCCoded(500, err.Error())
   417  	}
   418  
   419  	if node == nil {
   420  		err := fmt.Errorf("Unknown node %q", nodeID)
   421  		return structs.NewErrRPCCoded(404, err.Error())
   422  	}
   423  
   424  	if err := nodeSupportsRpc(node); err != nil {
   425  		return structs.NewErrRPCCoded(400, err.Error())
   426  	}
   427  
   428  	// Get the Connection to the client either by fowarding to another server
   429  	// or creating direct stream
   430  	state, ok := a.srv.getNodeConn(nodeID)
   431  	if !ok {
   432  		// Determine the server that has a connection to the node
   433  		srv, err := a.srv.serverWithNodeConn(nodeID, a.srv.Region())
   434  		if err != nil {
   435  			code := 500
   436  			if structs.IsErrNoNodeConn(err) {
   437  				code = 404
   438  			}
   439  			return structs.NewErrRPCCoded(code, err.Error())
   440  		}
   441  
   442  		rpcErr := a.srv.forwardServer(srv, "Agent.Profile", args, reply)
   443  		if rpcErr != nil {
   444  			return rpcErr
   445  		}
   446  	} else {
   447  		// NodeRpc
   448  		rpcErr := NodeRpc(state.Session, "Agent.Profile", args, reply)
   449  		if rpcErr != nil {
   450  			return rpcErr
   451  		}
   452  	}
   453  
   454  	return nil
   455  }