github.com/mwhudson/juju@v0.0.0-20160512215208-90ff01f3497f/apiserver/apiserver.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package apiserver
     5  
     6  import (
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"net"
    10  	"net/http"
    11  	"strings"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	"github.com/bmizerany/pat"
    17  	"github.com/juju/errors"
    18  	"github.com/juju/loggo"
    19  	"github.com/juju/names"
    20  	"github.com/juju/utils"
    21  	"golang.org/x/net/websocket"
    22  	"launchpad.net/tomb"
    23  
    24  	"github.com/juju/juju/apiserver/common"
    25  	"github.com/juju/juju/apiserver/common/apihttp"
    26  	"github.com/juju/juju/apiserver/params"
    27  	"github.com/juju/juju/rpc"
    28  	"github.com/juju/juju/rpc/jsoncodec"
    29  	"github.com/juju/juju/state"
    30  )
    31  
// logger is the logger used throughout this file; it is registered
// under the "juju.apiserver" module name.
var logger = loggo.GetLogger("juju.apiserver")

// loginRateLimit defines how many concurrent Login requests we will
// accept. It is the value passed to utils.NewLimiter when a Server is
// constructed.
const loginRateLimit = 10
    37  
// Server holds the server side of the API.
//
// A Server is built by newServer, which also starts its run loop in a
// separate goroutine. The fields are used as follows:
//
//   - tomb tracks the server's lifetime; Stop, Kill, Wait and Dead are
//     all expressed in terms of it.
//   - wg counts outstanding work (the mongo pinger goroutine and any
//     handler wrapped by trackRequests) that run waits for on shutdown.
//   - state is the State handed to NewServer; statePool is either the
//     pool from ServerConfig or one created from state.
//   - lis is the listener being served, wrapped in a changeCertListener.
//   - tag, dataDir, logDir and validator are copied from ServerConfig.
//   - limiter is created with loginRateLimit.
//   - adminApiFactories maps an admin API version number to the factory
//     used to build that API for each connection.
//   - authCtxt is created by newAuthContext.
//
// NOTE(review): modelUUID is not assigned anywhere in the code visible
// here - confirm whether it is still needed.
type Server struct {
	tomb              tomb.Tomb
	wg                sync.WaitGroup
	state             *state.State
	statePool         *state.StatePool
	lis               net.Listener
	tag               names.Tag
	dataDir           string
	logDir            string
	limiter           utils.Limiter
	validator         LoginValidator
	adminApiFactories map[int]adminApiFactory
	modelUUID         string
	authCtxt          *authContext
	connections       int32 // count of active websocket connections
}
    55  
// LoginValidator functions are used to decide whether login requests
// are to be allowed. The validator is called before credentials are
// checked. A Server takes its validator from ServerConfig.Validator.
type LoginValidator func(params.LoginRequest) error
    60  
// ServerConfig holds parameters required to set up an API server.
//
// Cert and Key are the certificate and private key (in PEM format)
// that are loaded with tls.X509KeyPair to serve TLS. Tag, DataDir,
// LogDir and Validator are copied onto the Server unchanged.
// CertChanged delivers replacement serving information; a received
// value with a non-empty Cert causes the listener's certificate to be
// replaced.
type ServerConfig struct {
	Cert        []byte
	Key         []byte
	Tag         names.Tag
	DataDir     string
	LogDir      string
	Validator   LoginValidator
	CertChanged chan params.StateServingInfo

	// This field only exists to support testing. When it is nil a
	// new pool is created from the State passed to NewServer.
	StatePool *state.StatePool
}
    74  
// changeCertListener wraps a TLS net.Listener.
// It allows connection handshakes to be
// blocked while the TLS certificate is updated.
type changeCertListener struct {
	net.Listener
	// tomb controls the goroutine started by newChangeCertListener,
	// which runs processCertChanges until the tomb is killed.
	tomb tomb.Tomb

	// A mutex used to block accept operations. It is held by Accept
	// while the connection is wrapped and by updateCertificate while
	// the certificate is replaced.
	m sync.Mutex

	// A channel used to pass in new certificate information.
	certChanged <-chan params.StateServingInfo

	// The config to update with any new certificate.
	config *tls.Config
}
    91  
    92  func newChangeCertListener(lis net.Listener, certChanged <-chan params.StateServingInfo, config *tls.Config) *changeCertListener {
    93  	cl := &changeCertListener{
    94  		Listener:    lis,
    95  		certChanged: certChanged,
    96  		config:      config,
    97  	}
    98  	go func() {
    99  		defer cl.tomb.Done()
   100  		cl.tomb.Kill(cl.processCertChanges())
   101  	}()
   102  	return cl
   103  }
   104  
   105  // Accept waits for and returns the next connection to the listener.
   106  func (cl *changeCertListener) Accept() (net.Conn, error) {
   107  	conn, err := cl.Listener.Accept()
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  	cl.m.Lock()
   112  	defer cl.m.Unlock()
   113  	config := cl.config
   114  	return tls.Server(conn, config), nil
   115  }
   116  
// Close closes the listener. It first kills the tomb, which makes the
// processCertChanges loop return, and then closes the wrapped
// listener, returning that call's result. It does not wait for the
// certificate-update goroutine to finish.
func (cl *changeCertListener) Close() error {
	cl.tomb.Kill(nil)
	return cl.Listener.Close()
}
   122  
   123  // processCertChanges receives new certificate information and
   124  // calls a method to update the listener's certificate.
   125  func (cl *changeCertListener) processCertChanges() error {
   126  	for {
   127  		select {
   128  		case info := <-cl.certChanged:
   129  			if info.Cert != "" {
   130  				cl.updateCertificate([]byte(info.Cert), []byte(info.PrivateKey))
   131  			}
   132  		case <-cl.tomb.Dying():
   133  			return tomb.ErrDying
   134  		}
   135  	}
   136  }
   137  
   138  // updateCertificate generates a new TLS certificate and assigns it
   139  // to the TLS listener.
   140  func (cl *changeCertListener) updateCertificate(cert, key []byte) {
   141  	cl.m.Lock()
   142  	defer cl.m.Unlock()
   143  	if tlsCert, err := tls.X509KeyPair(cert, key); err != nil {
   144  		logger.Errorf("cannot create new TLS certificate: %v", err)
   145  	} else {
   146  		logger.Infof("updating api server certificate")
   147  		x509Cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
   148  		if err == nil {
   149  			var addr []string
   150  			for _, ip := range x509Cert.IPAddresses {
   151  				addr = append(addr, ip.String())
   152  			}
   153  			logger.Infof("new certificate addresses: %v", strings.Join(addr, ", "))
   154  		}
   155  		cl.config.Certificates = []tls.Certificate{tlsCert}
   156  	}
   157  }
   158  
   159  // NewServer serves the given state by accepting requests on the given
   160  // listener, using the given certificate and key (in PEM format) for
   161  // authentication.
   162  //
   163  // The Server will close the listener when it exits, even if returns an error.
   164  func NewServer(s *state.State, lis net.Listener, cfg ServerConfig) (*Server, error) {
   165  	// Important note:
   166  	// Do not manipulate the state within NewServer as the API
   167  	// server needs to run before mongo upgrades have happened and
   168  	// any state manipulation may be be relying on features of the
   169  	// database added by upgrades. Here be dragons.
   170  	l, ok := lis.(*net.TCPListener)
   171  	if !ok {
   172  		return nil, errors.Errorf("listener is not of type *net.TCPListener: %T", lis)
   173  	}
   174  	srv, err := newServer(s, l, cfg)
   175  	if err != nil {
   176  		// There is no running server around to close the listener.
   177  		lis.Close()
   178  		return nil, errors.Trace(err)
   179  	}
   180  	return srv, nil
   181  }
   182  
   183  func newServer(s *state.State, lis *net.TCPListener, cfg ServerConfig) (_ *Server, err error) {
   184  	tlsCert, err := tls.X509KeyPair(cfg.Cert, cfg.Key)
   185  	if err != nil {
   186  		return nil, err
   187  	}
   188  	// TODO(rog) check that *srvRoot is a valid type for using
   189  	// as an RPC server.
   190  	tlsConfig := utils.SecureTLSConfig()
   191  	tlsConfig.Certificates = []tls.Certificate{tlsCert}
   192  
   193  	stPool := cfg.StatePool
   194  	if stPool == nil {
   195  		stPool = state.NewStatePool(s)
   196  	}
   197  
   198  	srv := &Server{
   199  		state:     s,
   200  		statePool: stPool,
   201  		lis:       newChangeCertListener(lis, cfg.CertChanged, tlsConfig),
   202  		tag:       cfg.Tag,
   203  		dataDir:   cfg.DataDir,
   204  		logDir:    cfg.LogDir,
   205  		limiter:   utils.NewLimiter(loginRateLimit),
   206  		validator: cfg.Validator,
   207  		adminApiFactories: map[int]adminApiFactory{
   208  			3: newAdminApiV3,
   209  		},
   210  	}
   211  	srv.authCtxt, err = newAuthContext(s)
   212  	if err != nil {
   213  		return nil, errors.Trace(err)
   214  	}
   215  	go srv.run()
   216  	return srv, nil
   217  }
   218  
// Dead returns a channel that signals when the server has exited.
// It is the Dead channel of the server's tomb, which run marks as
// done at the end of its shutdown sequence.
func (srv *Server) Dead() <-chan struct{} {
	return srv.tomb.Dead()
}
   223  
// Stop stops the server and returns when all running requests
// have completed. It kills the server's tomb and then waits on it,
// returning whatever the tomb's Wait reports.
func (srv *Server) Stop() error {
	srv.tomb.Kill(nil)
	return srv.tomb.Wait()
}
   230  
// Kill implements worker.Worker.Kill. It asks the server to shut down
// by killing its tomb, and does not wait for the shutdown to finish.
func (srv *Server) Kill() {
	srv.tomb.Kill(nil)
}
   235  
// Wait implements worker.Worker.Wait. It blocks on the server's tomb
// and returns the result of the tomb's Wait.
func (srv *Server) Wait() error {
	return srv.tomb.Wait()
}
   240  
// requestNotifier holds the per-connection information used when
// logging API traffic: a connection id, the connection start time, the
// tag of the logged-in entity and a pointer to the shared count of
// active connections.
type requestNotifier struct {
	// id is taken from globalCounter and start is the creation time;
	// both are set by newRequestNotifier.
	id    int64
	start time.Time

	// mu guards tag_, which is written by login and read by tag.
	mu   sync.Mutex
	tag_ string

	// count is incremented by calls to join, and decremented
	// by calls to leave.
	count *int32
}
   252  
// globalCounter is the source of requestNotifier ids. It is only
// modified with atomic.AddInt64 (see newRequestNotifier).
var globalCounter int64
   254  
   255  func newRequestNotifier(count *int32) *requestNotifier {
   256  	return &requestNotifier{
   257  		id:   atomic.AddInt64(&globalCounter, 1),
   258  		tag_: "<unknown>",
   259  		// TODO(fwereade): 2016-03-17 lp:1558657
   260  		start: time.Now(),
   261  		count: count,
   262  	}
   263  }
   264  
   265  func (n *requestNotifier) login(tag string) {
   266  	n.mu.Lock()
   267  	n.tag_ = tag
   268  	n.mu.Unlock()
   269  }
   270  
   271  func (n *requestNotifier) tag() (tag string) {
   272  	n.mu.Lock()
   273  	tag = n.tag_
   274  	n.mu.Unlock()
   275  	return
   276  }
   277  
   278  func (n *requestNotifier) ServerRequest(hdr *rpc.Header, body interface{}) {
   279  	if hdr.Request.Type == "Pinger" && hdr.Request.Action == "Ping" {
   280  		return
   281  	}
   282  	// TODO(rog) 2013-10-11 remove secrets from some requests.
   283  	// Until secrets are removed, we only log the body of the requests at trace level
   284  	// which is below the default level of debug.
   285  	if logger.IsTraceEnabled() {
   286  		logger.Tracef("<- [%X] %s %s", n.id, n.tag(), jsoncodec.DumpRequest(hdr, body))
   287  	} else {
   288  		logger.Debugf("<- [%X] %s %s", n.id, n.tag(), jsoncodec.DumpRequest(hdr, "'params redacted'"))
   289  	}
   290  }
   291  
   292  func (n *requestNotifier) ServerReply(req rpc.Request, hdr *rpc.Header, body interface{}, timeSpent time.Duration) {
   293  	if req.Type == "Pinger" && req.Action == "Ping" {
   294  		return
   295  	}
   296  	// TODO(rog) 2013-10-11 remove secrets from some responses.
   297  	// Until secrets are removed, we only log the body of the requests at trace level
   298  	// which is below the default level of debug.
   299  	if logger.IsTraceEnabled() {
   300  		logger.Tracef("-> [%X] %s %s", n.id, n.tag(), jsoncodec.DumpRequest(hdr, body))
   301  	} else {
   302  		logger.Debugf("-> [%X] %s %s %s %s[%q].%s", n.id, n.tag(), timeSpent, jsoncodec.DumpRequest(hdr, "'body redacted'"), req.Type, req.Id, req.Action)
   303  	}
   304  }
   305  
   306  func (n *requestNotifier) join(req *http.Request) {
   307  	active := atomic.AddInt32(n.count, 1)
   308  	logger.Infof("[%X] API connection from %s, active connections: %d", n.id, req.RemoteAddr, active)
   309  }
   310  
   311  func (n *requestNotifier) leave() {
   312  	active := atomic.AddInt32(n.count, -1)
   313  	logger.Infof("[%X] %s API connection terminated after %v, active connections: %d", n.id, n.tag(), time.Since(n.start), active)
   314  }
   315  
// ClientRequest is deliberately a no-op: nothing is logged for
// requests made in the client direction.
func (n *requestNotifier) ClientRequest(hdr *rpc.Header, body interface{}) {
}
   318  
// ClientReply is deliberately a no-op: nothing is logged for replies
// received in the client direction.
func (n *requestNotifier) ClientReply(req rpc.Request, hdr *rpc.Header, body interface{}) {
}
   321  
// run is the server's main loop; newServer starts it in its own
// goroutine. It starts the mongo pinger and the HTTP server, then
// blocks until the server's tomb starts dying. The deferred function
// performs the shutdown sequence.
func (srv *Server) run() {
	logger.Infof("listening on %q", srv.lis.Addr())

	defer func() {
		addr := srv.lis.Addr().String() // Addr not valid after close
		err := srv.lis.Close()
		logger.Infof("closed listening socket %q with final error: %v", addr, err)

		srv.state.HackLeadership() // Break deadlocks caused by BlockUntil... calls.
		srv.wg.Wait()              // wait for any outstanding requests to complete.
		// The tomb is marked done (which releases Stop, Wait and Dead)
		// before the state pool and the state are closed.
		srv.tomb.Done()
		srv.statePool.Close()
		srv.state.Close()
	}()

	// The pinger is tracked by the WaitGroup so shutdown waits for it.
	// Whatever mongoPinger returns is used to kill the tomb, so a
	// failed ping brings the whole server down.
	srv.wg.Add(1)
	go func() {
		defer srv.wg.Done()
		srv.tomb.Kill(srv.mongoPinger())
	}()

	// for pat based handlers, they are matched in-order of being
	// registered, first match wins. So more specific ones have to be
	// registered first.
	mux := pat.New()
	for _, endpoint := range srv.endpoints() {
		registerEndpoint(endpoint, mux)
	}

	// The HTTP server goroutine is not tracked by the WaitGroup.
	go func() {
		addr := srv.lis.Addr() // not valid after the listener is closed
		logger.Debugf("Starting API http server on address %q", addr)
		err := http.Serve(srv.lis, mux)
		// normally logging an error at debug level would be grounds for a beating,
		// however in this case the error is *expected* to be non nil, and does not
		// affect the operation of the apiserver, but for completeness log it anyway.
		logger.Debugf("API http server exited, final error was: %v", err)
	}()

	<-srv.tomb.Dying()
}
   363  
// endpoints returns every HTTP endpoint the server exposes, in the
// order in which they should be registered. The list starts with the
// endpoints resolved from the common registry and is followed by the
// handlers defined in this package: first the model-specific paths,
// then the GUI paths, then the legacy paths kept for backwards
// compatibility, ending with the catch-all "/" API handler.
//
// The order matters because the caller registers the endpoints with a
// first-match-wins mux.
func (srv *Server) endpoints() []apihttp.Endpoint {
	httpCtxt := httpContext{
		srv: srv,
	}

	endpoints := common.ResolveAPIEndpoints(srv.newHandlerArgs)

	// TODO(ericsnow) Add the following to the registry instead.

	// add appends one endpoint per default HTTP method for the given
	// pattern, all sharing the same handler.
	add := func(pattern string, handler http.Handler) {
		// TODO: We can switch from all methods to specific ones for entries
		// where we only want to support specific request methods. However, our
		// tests currently assert that errors come back as application/json and
		// pat only does "text/plain" responses.
		for _, method := range common.DefaultHTTPMethods {
			endpoints = append(endpoints, apihttp.Endpoint{
				Pattern: pattern,
				Method:  method,
				Handler: handler,
			})
		}
	}

	// These handlers are wrapped with trackRequests so that server
	// shutdown waits for them; each is registered under both a
	// model-specific and a legacy path, except logSinkHandler which
	// only has a model-specific path.
	mainAPIHandler := srv.trackRequests(http.HandlerFunc(srv.apiHandler))
	logSinkHandler := srv.trackRequests(newLogSinkHandler(httpCtxt, srv.logDir))
	debugLogHandler := srv.trackRequests(newDebugLogDBHandler(httpCtxt))

	add("/model/:modeluuid/logsink", logSinkHandler)
	add("/model/:modeluuid/log", debugLogHandler)
	add("/model/:modeluuid/charms",
		&charmsHandler{
			ctxt:    httpCtxt,
			dataDir: srv.dataDir},
	)
	add("/model/:modeluuid/tools",
		&toolsUploadHandler{
			ctxt: httpCtxt,
		},
	)
	add("/model/:modeluuid/tools/:version",
		&toolsDownloadHandler{
			ctxt: httpCtxt,
		},
	)
	// The backups handler gets a copy of the context with strict
	// validation and controller-model-only access switched on.
	strictCtxt := httpCtxt
	strictCtxt.strictValidation = true
	strictCtxt.controllerModelOnly = true
	add("/model/:modeluuid/backups",
		&backupHandler{
			ctxt: strictCtxt,
		},
	)
	add("/model/:modeluuid/api", mainAPIHandler)

	add("/model/:modeluuid/images/:kind/:series/:arch/:filename",
		&imagesDownloadHandler{
			ctxt:    httpCtxt,
			dataDir: srv.dataDir,
			state:   srv.state,
		},
	)

	endpoints = append(endpoints, guiEndpoints("/gui/:modeluuid/", srv.dataDir, httpCtxt)...)
	add("/gui-archive", &guiArchiveHandler{
		ctxt: httpCtxt,
	})
	add("/gui-version", &guiVersionHandler{
		ctxt: httpCtxt,
	})

	// For backwards compatibility we register all the old paths
	add("/log", debugLogHandler)

	add("/charms",
		&charmsHandler{
			ctxt:    httpCtxt,
			dataDir: srv.dataDir,
		},
	)
	add("/tools",
		&toolsUploadHandler{
			ctxt: httpCtxt,
		},
	)
	add("/tools/:version",
		&toolsDownloadHandler{
			ctxt: httpCtxt,
		},
	)
	add("/register",
		&registerUserHandler{
			httpCtxt,
			srv.authCtxt.userAuth.CreateLocalLoginMacaroon,
		},
	)
	// The catch-all pattern goes last so it cannot shadow the
	// patterns registered above.
	add("/", mainAPIHandler)

	return endpoints
}
   463  
   464  func (srv *Server) newHandlerArgs(spec apihttp.HandlerConstraints) apihttp.NewHandlerArgs {
   465  	ctxt := httpContext{
   466  		srv:                 srv,
   467  		strictValidation:    spec.StrictValidation,
   468  		controllerModelOnly: spec.ControllerModelOnly,
   469  	}
   470  
   471  	var args apihttp.NewHandlerArgs
   472  	switch spec.AuthKind {
   473  	case names.UserTagKind:
   474  		args.Connect = ctxt.stateForRequestAuthenticatedUser
   475  	case names.UnitTagKind:
   476  		args.Connect = ctxt.stateForRequestAuthenticatedAgent
   477  	case "":
   478  		logger.Tracef(`no access level specified; proceeding with "unauthenticated"`)
   479  		args.Connect = func(req *http.Request) (*state.State, state.Entity, error) {
   480  			st, err := ctxt.stateForRequestUnauthenticated(req)
   481  			return st, nil, err
   482  		}
   483  	default:
   484  		logger.Warningf(`unrecognized access level %q; proceeding with "unauthenticated"`, spec.AuthKind)
   485  		args.Connect = func(req *http.Request) (*state.State, state.Entity, error) {
   486  			st, err := ctxt.stateForRequestUnauthenticated(req)
   487  			return st, nil, err
   488  		}
   489  	}
   490  	return args
   491  }
   492  
   493  // trackRequests wraps a http.Handler, incrementing and decrementing
   494  // the apiserver's WaitGroup and blocking request when the apiserver
   495  // is shutting down.
   496  //
   497  // Note: It is only safe to use trackRequests with API handlers which
   498  // are interruptible (i.e. they pay attention to the apiserver tomb)
   499  // or are guaranteed to be short-lived. If it's used with long running
   500  // API handlers which don't watch the apiserver's tomb, apiserver
   501  // shutdown will be blocked until the API handler returns.
   502  func (srv *Server) trackRequests(handler http.Handler) http.Handler {
   503  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   504  		srv.wg.Add(1)
   505  		defer srv.wg.Done()
   506  		// If we've got to this stage and the tomb is still
   507  		// alive, we know that any tomb.Kill must occur after we
   508  		// have called wg.Add, so we avoid the possibility of a
   509  		// handler goroutine running after Stop has returned.
   510  		if srv.tomb.Err() != tomb.ErrStillAlive {
   511  			return
   512  		}
   513  
   514  		handler.ServeHTTP(w, r)
   515  	})
   516  }
   517  
   518  func registerEndpoint(ep apihttp.Endpoint, mux *pat.PatternServeMux) {
   519  	switch ep.Method {
   520  	case "GET":
   521  		mux.Get(ep.Pattern, ep.Handler)
   522  	case "POST":
   523  		mux.Post(ep.Pattern, ep.Handler)
   524  	case "HEAD":
   525  		mux.Head(ep.Pattern, ep.Handler)
   526  	case "PUT":
   527  		mux.Put(ep.Pattern, ep.Handler)
   528  	case "DEL":
   529  		mux.Del(ep.Pattern, ep.Handler)
   530  	case "OPTIONS":
   531  		mux.Options(ep.Pattern, ep.Handler)
   532  	}
   533  }
   534  
   535  func (srv *Server) apiHandler(w http.ResponseWriter, req *http.Request) {
   536  	reqNotifier := newRequestNotifier(&srv.connections)
   537  	reqNotifier.join(req)
   538  	defer reqNotifier.leave()
   539  	wsServer := websocket.Server{
   540  		Handler: func(conn *websocket.Conn) {
   541  			modelUUID := req.URL.Query().Get(":modeluuid")
   542  			logger.Tracef("got a request for model %q", modelUUID)
   543  			if err := srv.serveConn(conn, reqNotifier, modelUUID); err != nil {
   544  				logger.Errorf("error serving RPCs: %v", err)
   545  			}
   546  		},
   547  	}
   548  	wsServer.ServeHTTP(w, req)
   549  }
   550  
   551  func (srv *Server) serveConn(wsConn *websocket.Conn, reqNotifier *requestNotifier, modelUUID string) error {
   552  	codec := jsoncodec.NewWebsocket(wsConn)
   553  	if loggo.GetLogger("juju.rpc.jsoncodec").EffectiveLogLevel() <= loggo.TRACE {
   554  		codec.SetLogging(true)
   555  	}
   556  	var notifier rpc.RequestNotifier
   557  	if logger.EffectiveLogLevel() <= loggo.DEBUG {
   558  		// Incur request monitoring overhead only if we
   559  		// know we'll need it.
   560  		notifier = reqNotifier
   561  	}
   562  	conn := rpc.NewConn(codec, notifier)
   563  
   564  	h, err := srv.newAPIHandler(conn, reqNotifier, modelUUID)
   565  	if err != nil {
   566  		conn.ServeFinder(&errRoot{err}, serverError)
   567  	} else {
   568  		adminApis := make(map[int]interface{})
   569  		for apiVersion, factory := range srv.adminApiFactories {
   570  			adminApis[apiVersion] = factory(srv, h, reqNotifier)
   571  		}
   572  		conn.ServeFinder(newAnonRoot(h, adminApis), serverError)
   573  	}
   574  	conn.Start()
   575  	select {
   576  	case <-conn.Dead():
   577  	case <-srv.tomb.Dying():
   578  	}
   579  	return conn.Close()
   580  }
   581  
   582  func (srv *Server) newAPIHandler(conn *rpc.Conn, reqNotifier *requestNotifier, modelUUID string) (*apiHandler, error) {
   583  	// Note that we don't overwrite modelUUID here because
   584  	// newAPIHandler treats an empty modelUUID as signifying
   585  	// the API version used.
   586  	resolvedModelUUID, err := validateModelUUID(validateArgs{
   587  		statePool: srv.statePool,
   588  		modelUUID: modelUUID,
   589  	})
   590  	if err != nil {
   591  		return nil, errors.Trace(err)
   592  	}
   593  	st, err := srv.statePool.Get(resolvedModelUUID)
   594  	if err != nil {
   595  		return nil, errors.Trace(err)
   596  	}
   597  	return newApiHandler(srv, st, conn, reqNotifier, modelUUID)
   598  }
   599  
   600  func (srv *Server) mongoPinger() error {
   601  	// TODO(fwereade): 2016-03-17 lp:1558657
   602  	timer := time.NewTimer(0)
   603  	session := srv.state.MongoSession().Copy()
   604  	defer session.Close()
   605  	for {
   606  		select {
   607  		case <-timer.C:
   608  		case <-srv.tomb.Dying():
   609  			return tomb.ErrDying
   610  		}
   611  		if err := session.Ping(); err != nil {
   612  			logger.Infof("got error pinging mongo: %v", err)
   613  			return errors.Annotate(err, "error pinging mongo")
   614  		}
   615  		timer.Reset(mongoPingInterval)
   616  	}
   617  }
   618  
   619  func serverError(err error) error {
   620  	if err := common.ServerError(err); err != nil {
   621  		return err
   622  	}
   623  	return nil
   624  }