github.com/makyo/juju@v0.0.0-20160425123129-2608902037e9/apiserver/apiserver.go (about)

     1  // Copyright 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package apiserver
     5  
     6  import (
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"net"
    10  	"net/http"
    11  	"strings"
    12  	"sync"
    13  	"sync/atomic"
    14  	"time"
    15  
    16  	"github.com/bmizerany/pat"
    17  	"github.com/juju/errors"
    18  	"github.com/juju/loggo"
    19  	"github.com/juju/names"
    20  	"github.com/juju/utils"
    21  	"golang.org/x/net/websocket"
    22  	"launchpad.net/tomb"
    23  
    24  	"github.com/juju/juju/apiserver/common"
    25  	"github.com/juju/juju/apiserver/common/apihttp"
    26  	"github.com/juju/juju/apiserver/params"
    27  	"github.com/juju/juju/rpc"
    28  	"github.com/juju/juju/rpc/jsoncodec"
    29  	"github.com/juju/juju/state"
    30  )
    31  
// logger is the package-wide logger for the API server, registered
// under the name "juju.apiserver".
var logger = loggo.GetLogger("juju.apiserver")

// loginRateLimit defines how many concurrent Login requests we will
// accept. It is used to size the utils.Limiter created in newServer.
const loginRateLimit = 10
    37  
// Server holds the server side of the API.
type Server struct {
	// tomb drives the server's lifecycle; Kill, Wait, Stop and Dead are
	// thin wrappers around it.
	tomb              tomb.Tomb
	// wg counts in-flight work (requests wrapped by trackRequests and the
	// mongo pinger goroutine); run waits on it during shutdown.
	wg                sync.WaitGroup
	// state is the State passed to NewServer; run closes it on shutdown.
	state             *state.State
	// statePool supplies per-model States (see newAPIHandler); either
	// ServerConfig.StatePool or a pool created by newServer.
	statePool         *state.StatePool
	// lis is the TLS-wrapping changeCertListener around the TCP listener.
	lis               net.Listener
	tag               names.Tag
	dataDir           string
	logDir            string
	// limiter is created with loginRateLimit slots.
	limiter           utils.Limiter
	validator         LoginValidator
	// adminApiFactories maps an admin API version to the factory that
	// builds it for each connection (see serveConn).
	adminApiFactories map[int]adminApiFactory
	// NOTE(review): modelUUID is not assigned by newServer; confirm
	// whether anything still sets or reads it.
	modelUUID         string
	authCtxt          *authContext
	connections       int32 // count of active websocket connections
}
    55  
// LoginValidator functions are used to decide whether login requests
// are to be allowed. The validator is called before credentials are
// checked.
//
// NOTE(review): presumably a non-nil error rejects the login request;
// the call site is not visible here — confirm.
type LoginValidator func(params.LoginRequest) error
    60  
// ServerConfig holds parameters required to set up an API server.
type ServerConfig struct {
	// Cert and Key are the PEM-encoded certificate and private key the
	// server presents to clients; they are parsed with tls.X509KeyPair.
	Cert        []byte
	Key         []byte
	Tag         names.Tag
	// DataDir is handed to the charms and images HTTP handlers.
	DataDir     string
	// LogDir is handed to the logsink HTTP handler.
	LogDir      string
	Validator   LoginValidator
	// CertChanged delivers updated serving info; any non-empty Cert it
	// carries replaces the certificate used for new TLS connections.
	CertChanged chan params.StateServingInfo

	// This field only exists to support testing. When nil, newServer
	// creates a StatePool for the given State.
	StatePool *state.StatePool
}
    74  
// changeCertListener wraps a TLS net.Listener.
// It allows connection handshakes to be
// blocked while the TLS certificate is updated.
type changeCertListener struct {
	// Listener is the underlying (plain) listener; Accept wraps each
	// connection it returns with tls.Server.
	net.Listener
	// tomb controls the goroutine running processCertChanges.
	tomb tomb.Tomb

	// A mutex used to block accept operations.
	// It is held by Accept while reading config and by
	// updateCertificate while replacing the certificate.
	m sync.Mutex

	// A channel used to pass in new certificate information.
	certChanged <-chan params.StateServingInfo

	// The config to update with any new certificate. The same pointer
	// is passed to every connection returned by Accept.
	config *tls.Config
}
    91  
    92  func newChangeCertListener(lis net.Listener, certChanged <-chan params.StateServingInfo, config *tls.Config) *changeCertListener {
    93  	cl := &changeCertListener{
    94  		Listener:    lis,
    95  		certChanged: certChanged,
    96  		config:      config,
    97  	}
    98  	go func() {
    99  		defer cl.tomb.Done()
   100  		cl.tomb.Kill(cl.processCertChanges())
   101  	}()
   102  	return cl
   103  }
   104  
   105  // Accept waits for and returns the next connection to the listener.
   106  func (cl *changeCertListener) Accept() (net.Conn, error) {
   107  	conn, err := cl.Listener.Accept()
   108  	if err != nil {
   109  		return nil, err
   110  	}
   111  	cl.m.Lock()
   112  	defer cl.m.Unlock()
   113  	config := cl.config
   114  	return tls.Server(conn, config), nil
   115  }
   116  
   117  // Close closes the listener.
   118  func (cl *changeCertListener) Close() error {
   119  	cl.tomb.Kill(nil)
   120  	return cl.Listener.Close()
   121  }
   122  
   123  // processCertChanges receives new certificate information and
   124  // calls a method to update the listener's certificate.
   125  func (cl *changeCertListener) processCertChanges() error {
   126  	for {
   127  		select {
   128  		case info := <-cl.certChanged:
   129  			if info.Cert != "" {
   130  				cl.updateCertificate([]byte(info.Cert), []byte(info.PrivateKey))
   131  			}
   132  		case <-cl.tomb.Dying():
   133  			return tomb.ErrDying
   134  		}
   135  	}
   136  }
   137  
   138  // updateCertificate generates a new TLS certificate and assigns it
   139  // to the TLS listener.
   140  func (cl *changeCertListener) updateCertificate(cert, key []byte) {
   141  	cl.m.Lock()
   142  	defer cl.m.Unlock()
   143  	if tlsCert, err := tls.X509KeyPair(cert, key); err != nil {
   144  		logger.Errorf("cannot create new TLS certificate: %v", err)
   145  	} else {
   146  		logger.Infof("updating api server certificate")
   147  		x509Cert, err := x509.ParseCertificate(tlsCert.Certificate[0])
   148  		if err == nil {
   149  			var addr []string
   150  			for _, ip := range x509Cert.IPAddresses {
   151  				addr = append(addr, ip.String())
   152  			}
   153  			logger.Infof("new certificate addresses: %v", strings.Join(addr, ", "))
   154  		}
   155  		cl.config.Certificates = []tls.Certificate{tlsCert}
   156  	}
   157  }
   158  
   159  // NewServer serves the given state by accepting requests on the given
   160  // listener, using the given certificate and key (in PEM format) for
   161  // authentication.
   162  //
   163  // The Server will close the listener when it exits, even if returns an error.
   164  func NewServer(s *state.State, lis net.Listener, cfg ServerConfig) (*Server, error) {
   165  	// Important note:
   166  	// Do not manipulate the state within NewServer as the API
   167  	// server needs to run before mongo upgrades have happened and
   168  	// any state manipulation may be be relying on features of the
   169  	// database added by upgrades. Here be dragons.
   170  	l, ok := lis.(*net.TCPListener)
   171  	if !ok {
   172  		return nil, errors.Errorf("listener is not of type *net.TCPListener: %T", lis)
   173  	}
   174  	srv, err := newServer(s, l, cfg)
   175  	if err != nil {
   176  		// There is no running server around to close the listener.
   177  		lis.Close()
   178  		return nil, errors.Trace(err)
   179  	}
   180  	return srv, nil
   181  }
   182  
   183  func newServer(s *state.State, lis *net.TCPListener, cfg ServerConfig) (_ *Server, err error) {
   184  	tlsCert, err := tls.X509KeyPair(cfg.Cert, cfg.Key)
   185  	if err != nil {
   186  		return nil, err
   187  	}
   188  	// TODO(rog) check that *srvRoot is a valid type for using
   189  	// as an RPC server.
   190  	tlsConfig := &tls.Config{
   191  		Certificates: []tls.Certificate{tlsCert},
   192  		MinVersion:   tls.VersionTLS10,
   193  	}
   194  
   195  	stPool := cfg.StatePool
   196  	if stPool == nil {
   197  		stPool = state.NewStatePool(s)
   198  	}
   199  
   200  	srv := &Server{
   201  		state:     s,
   202  		statePool: stPool,
   203  		lis:       newChangeCertListener(lis, cfg.CertChanged, tlsConfig),
   204  		tag:       cfg.Tag,
   205  		dataDir:   cfg.DataDir,
   206  		logDir:    cfg.LogDir,
   207  		limiter:   utils.NewLimiter(loginRateLimit),
   208  		validator: cfg.Validator,
   209  		adminApiFactories: map[int]adminApiFactory{
   210  			3: newAdminApiV3,
   211  		},
   212  	}
   213  	srv.authCtxt, err = newAuthContext(s)
   214  	if err != nil {
   215  		return nil, errors.Trace(err)
   216  	}
   217  	go srv.run()
   218  	return srv, nil
   219  }
   220  
   221  // Dead returns a channel that signals when the server has exited.
   222  func (srv *Server) Dead() <-chan struct{} {
   223  	return srv.tomb.Dead()
   224  }
   225  
   226  // Stop stops the server and returns when all running requests
   227  // have completed.
   228  func (srv *Server) Stop() error {
   229  	srv.tomb.Kill(nil)
   230  	return srv.tomb.Wait()
   231  }
   232  
   233  // Kill implements worker.Worker.Kill.
   234  func (srv *Server) Kill() {
   235  	srv.tomb.Kill(nil)
   236  }
   237  
   238  // Wait implements worker.Worker.Wait.
   239  func (srv *Server) Wait() error {
   240  	return srv.tomb.Wait()
   241  }
   242  
// requestNotifier logs the RPC traffic and lifetime of a single API
// connection. It is used as the connection's rpc.RequestNotifier.
type requestNotifier struct {
	// id identifies the connection in log lines; taken from globalCounter.
	id    int64
	// start is when the notifier was created; leave logs time.Since(start).
	start time.Time

	// mu guards tag_, which is written by login and read by tag.
	mu   sync.Mutex
	tag_ string

	// count is incremented by calls to join, and decremented
	// by calls to leave.
	count *int32
}
   254  
// globalCounter is incremented atomically to give each requestNotifier a
// distinct id.
var globalCounter int64
   256  
   257  func newRequestNotifier(count *int32) *requestNotifier {
   258  	return &requestNotifier{
   259  		id:   atomic.AddInt64(&globalCounter, 1),
   260  		tag_: "<unknown>",
   261  		// TODO(fwereade): 2016-03-17 lp:1558657
   262  		start: time.Now(),
   263  		count: count,
   264  	}
   265  }
   266  
   267  func (n *requestNotifier) login(tag string) {
   268  	n.mu.Lock()
   269  	n.tag_ = tag
   270  	n.mu.Unlock()
   271  }
   272  
   273  func (n *requestNotifier) tag() (tag string) {
   274  	n.mu.Lock()
   275  	tag = n.tag_
   276  	n.mu.Unlock()
   277  	return
   278  }
   279  
   280  func (n *requestNotifier) ServerRequest(hdr *rpc.Header, body interface{}) {
   281  	if hdr.Request.Type == "Pinger" && hdr.Request.Action == "Ping" {
   282  		return
   283  	}
   284  	// TODO(rog) 2013-10-11 remove secrets from some requests.
   285  	// Until secrets are removed, we only log the body of the requests at trace level
   286  	// which is below the default level of debug.
   287  	if logger.IsTraceEnabled() {
   288  		logger.Tracef("<- [%X] %s %s", n.id, n.tag(), jsoncodec.DumpRequest(hdr, body))
   289  	} else {
   290  		logger.Debugf("<- [%X] %s %s", n.id, n.tag(), jsoncodec.DumpRequest(hdr, "'params redacted'"))
   291  	}
   292  }
   293  
   294  func (n *requestNotifier) ServerReply(req rpc.Request, hdr *rpc.Header, body interface{}, timeSpent time.Duration) {
   295  	if req.Type == "Pinger" && req.Action == "Ping" {
   296  		return
   297  	}
   298  	// TODO(rog) 2013-10-11 remove secrets from some responses.
   299  	// Until secrets are removed, we only log the body of the requests at trace level
   300  	// which is below the default level of debug.
   301  	if logger.IsTraceEnabled() {
   302  		logger.Tracef("-> [%X] %s %s", n.id, n.tag(), jsoncodec.DumpRequest(hdr, body))
   303  	} else {
   304  		logger.Debugf("-> [%X] %s %s %s %s[%q].%s", n.id, n.tag(), timeSpent, jsoncodec.DumpRequest(hdr, "'body redacted'"), req.Type, req.Id, req.Action)
   305  	}
   306  }
   307  
   308  func (n *requestNotifier) join(req *http.Request) {
   309  	active := atomic.AddInt32(n.count, 1)
   310  	logger.Infof("[%X] API connection from %s, active connections: %d", n.id, req.RemoteAddr, active)
   311  }
   312  
   313  func (n *requestNotifier) leave() {
   314  	active := atomic.AddInt32(n.count, -1)
   315  	logger.Infof("[%X] %s API connection terminated after %v, active connections: %d", n.id, n.tag(), time.Since(n.start), active)
   316  }
   317  
   318  func (n *requestNotifier) ClientRequest(hdr *rpc.Header, body interface{}) {
   319  }
   320  
   321  func (n *requestNotifier) ClientReply(req rpc.Request, hdr *rpc.Header, body interface{}) {
   322  }
   323  
// run is the server's main loop; newServer starts it in its own goroutine.
// It does the following, then blocks until srv.tomb starts dying:
//   - starts the mongo pinger;
//   - registers every endpoint from srv.endpoints on a pat mux;
//   - serves HTTP on srv.lis in a background goroutine.
//
// All shutdown work happens in the deferred function below.
func (srv *Server) run() {
	logger.Infof("listening on %q", srv.lis.Addr())

	// Shutdown sequence, in order:
	//  1. close the listener, which makes http.Serve below return;
	//  2. HackLeadership, to unblock requests stuck in BlockUntil... calls;
	//  3. wait for tracked requests and the mongo pinger (srv.wg);
	//  4. mark the tomb done, which releases Wait/Stop/Dead waiters;
	//  5. close the state pool and then the state.
	// NOTE(review): because tomb.Done runs before step 5, Wait/Stop can
	// return while statePool and state are still being closed. Confirm
	// that callers do not rely on them being closed already.
	defer func() {
		addr := srv.lis.Addr().String() // Addr not valid after close
		err := srv.lis.Close()
		logger.Infof("closed listening socket %q with final error: %v", addr, err)

		srv.state.HackLeadership() // Break deadlocks caused by BlockUntil... calls.
		srv.wg.Wait()              // wait for any outstanding requests to complete.
		srv.tomb.Done()
		srv.statePool.Close()
		srv.state.Close()
	}()

	// The pinger is tracked by wg. Its return value (a ping error, or
	// tomb.ErrDying on shutdown) is fed to the tomb, so a failed ping
	// shuts the whole server down.
	srv.wg.Add(1)
	go func() {
		defer srv.wg.Done()
		srv.tomb.Kill(srv.mongoPinger())
	}()

	// for pat based handlers, they are matched in-order of being
	// registered, first match wins. So more specific ones have to be
	// registered first.
	mux := pat.New()
	for _, endpoint := range srv.endpoints() {
		registerEndpoint(endpoint, mux)
	}

	// This goroutine is not tracked by wg; it ends when the deferred
	// shutdown closes srv.lis.
	go func() {
		addr := srv.lis.Addr() // not valid after addr closed
		logger.Debugf("Starting API http server on address %q", addr)
		err := http.Serve(srv.lis, mux)
		// normally logging an error at debug level would be grounds for a beating,
		// however in this case the error is *expected* to be non nil, and does not
		// affect the operation of the apiserver, but for completeness log it anyway.
		logger.Debugf("API http server exited, final error was: %v", err)
	}()

	<-srv.tomb.Dying()
}
   365  
// endpoints returns every HTTP endpoint served by the API server. It
// starts with those resolved via common.ResolveAPIEndpoints and then
// appends the handlers added below. add creates one Endpoint per method
// in common.DefaultHTTPMethods.
//
// Order matters: run registers these on a pat mux, which matches in
// registration order with the first match winning. More specific patterns
// must therefore be added before more general ones, and "/" is added last.
func (srv *Server) endpoints() []apihttp.Endpoint {
	httpCtxt := httpContext{
		srv: srv,
	}

	endpoints := common.ResolveAPIEndpoints(srv.newHandlerArgs)

	// TODO(ericsnow) Add the following to the registry instead.

	// add appends an Endpoint for pattern/handler for every method in
	// common.DefaultHTTPMethods.
	add := func(pattern string, handler http.Handler) {
		// TODO: We can switch from all methods to specific ones for entries
		// where we only want to support specific request methods. However, our
		// tests currently assert that errors come back as application/json and
		// pat only does "text/plain" responses.
		for _, method := range common.DefaultHTTPMethods {
			endpoints = append(endpoints, apihttp.Endpoint{
				Pattern: pattern,
				Method:  method,
				Handler: handler,
			})
		}
	}

	// Only these three handlers are wrapped with trackRequests, so only
	// they are waited for during shutdown (see trackRequests for the
	// caveat about long-running handlers).
	mainAPIHandler := srv.trackRequests(http.HandlerFunc(srv.apiHandler))
	logSinkHandler := srv.trackRequests(newLogSinkHandler(httpCtxt, srv.logDir))
	debugLogHandler := srv.trackRequests(newDebugLogDBHandler(httpCtxt))

	// Model-scoped endpoints.
	add("/model/:modeluuid/logsink", logSinkHandler)
	add("/model/:modeluuid/log", debugLogHandler)
	add("/model/:modeluuid/charms",
		&charmsHandler{
			ctxt:    httpCtxt,
			dataDir: srv.dataDir},
	)
	add("/model/:modeluuid/tools",
		&toolsUploadHandler{
			ctxt: httpCtxt,
		},
	)
	add("/model/:modeluuid/tools/:version",
		&toolsDownloadHandler{
			ctxt: httpCtxt,
		},
	)
	// Backups use a copy of the context with strictValidation and
	// controllerModelOnly both enabled.
	strictCtxt := httpCtxt
	strictCtxt.strictValidation = true
	strictCtxt.controllerModelOnly = true
	add("/model/:modeluuid/backups",
		&backupHandler{
			ctxt: strictCtxt,
		},
	)
	add("/model/:modeluuid/api", mainAPIHandler)

	add("/model/:modeluuid/images/:kind/:series/:arch/:filename",
		&imagesDownloadHandler{
			ctxt:    httpCtxt,
			dataDir: srv.dataDir,
			state:   srv.state,
		},
	)

	// GUI endpoints.
	endpoints = append(endpoints, guiEndpoints("/gui/:modeluuid/", srv.dataDir, httpCtxt)...)
	add("/gui-archive", &guiArchiveHandler{
		ctxt: httpCtxt,
	})
	add("/gui-version", &guiVersionHandler{
		ctxt: httpCtxt,
	})

	// For backwards compatibility we register all the old paths
	add("/log", debugLogHandler)

	add("/charms",
		&charmsHandler{
			ctxt:    httpCtxt,
			dataDir: srv.dataDir,
		},
	)
	add("/tools",
		&toolsUploadHandler{
			ctxt: httpCtxt,
		},
	)
	add("/tools/:version",
		&toolsDownloadHandler{
			ctxt: httpCtxt,
		},
	)
	add("/register",
		&registerUserHandler{
			httpCtxt,
			srv.authCtxt.userAuth.CreateLocalLoginMacaroon,
		},
	)
	// Registered last so that every more specific pattern above is
	// tried first.
	add("/", mainAPIHandler)

	return endpoints
}
   465  
   466  func (srv *Server) newHandlerArgs(spec apihttp.HandlerConstraints) apihttp.NewHandlerArgs {
   467  	ctxt := httpContext{
   468  		srv:                 srv,
   469  		strictValidation:    spec.StrictValidation,
   470  		controllerModelOnly: spec.ControllerModelOnly,
   471  	}
   472  
   473  	var args apihttp.NewHandlerArgs
   474  	switch spec.AuthKind {
   475  	case names.UserTagKind:
   476  		args.Connect = ctxt.stateForRequestAuthenticatedUser
   477  	case names.UnitTagKind:
   478  		args.Connect = ctxt.stateForRequestAuthenticatedAgent
   479  	case "":
   480  		logger.Tracef(`no access level specified; proceeding with "unauthenticated"`)
   481  		args.Connect = func(req *http.Request) (*state.State, state.Entity, error) {
   482  			st, err := ctxt.stateForRequestUnauthenticated(req)
   483  			return st, nil, err
   484  		}
   485  	default:
   486  		logger.Warningf(`unrecognized access level %q; proceeding with "unauthenticated"`, spec.AuthKind)
   487  		args.Connect = func(req *http.Request) (*state.State, state.Entity, error) {
   488  			st, err := ctxt.stateForRequestUnauthenticated(req)
   489  			return st, nil, err
   490  		}
   491  	}
   492  	return args
   493  }
   494  
   495  // trackRequests wraps a http.Handler, incrementing and decrementing
   496  // the apiserver's WaitGroup and blocking request when the apiserver
   497  // is shutting down.
   498  //
   499  // Note: It is only safe to use trackRequests with API handlers which
   500  // are interruptible (i.e. they pay attention to the apiserver tomb)
   501  // or are guaranteed to be short-lived. If it's used with long running
   502  // API handlers which don't watch the apiserver's tomb, apiserver
   503  // shutdown will be blocked until the API handler returns.
   504  func (srv *Server) trackRequests(handler http.Handler) http.Handler {
   505  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   506  		srv.wg.Add(1)
   507  		defer srv.wg.Done()
   508  		// If we've got to this stage and the tomb is still
   509  		// alive, we know that any tomb.Kill must occur after we
   510  		// have called wg.Add, so we avoid the possibility of a
   511  		// handler goroutine running after Stop has returned.
   512  		if srv.tomb.Err() != tomb.ErrStillAlive {
   513  			return
   514  		}
   515  
   516  		handler.ServeHTTP(w, r)
   517  	})
   518  }
   519  
   520  func registerEndpoint(ep apihttp.Endpoint, mux *pat.PatternServeMux) {
   521  	switch ep.Method {
   522  	case "GET":
   523  		mux.Get(ep.Pattern, ep.Handler)
   524  	case "POST":
   525  		mux.Post(ep.Pattern, ep.Handler)
   526  	case "HEAD":
   527  		mux.Head(ep.Pattern, ep.Handler)
   528  	case "PUT":
   529  		mux.Put(ep.Pattern, ep.Handler)
   530  	case "DEL":
   531  		mux.Del(ep.Pattern, ep.Handler)
   532  	case "OPTIONS":
   533  		mux.Options(ep.Pattern, ep.Handler)
   534  	}
   535  }
   536  
   537  func (srv *Server) apiHandler(w http.ResponseWriter, req *http.Request) {
   538  	reqNotifier := newRequestNotifier(&srv.connections)
   539  	reqNotifier.join(req)
   540  	defer reqNotifier.leave()
   541  	wsServer := websocket.Server{
   542  		Handler: func(conn *websocket.Conn) {
   543  			modelUUID := req.URL.Query().Get(":modeluuid")
   544  			logger.Tracef("got a request for model %q", modelUUID)
   545  			if err := srv.serveConn(conn, reqNotifier, modelUUID); err != nil {
   546  				logger.Errorf("error serving RPCs: %v", err)
   547  			}
   548  		},
   549  	}
   550  	wsServer.ServeHTTP(w, req)
   551  }
   552  
   553  func (srv *Server) serveConn(wsConn *websocket.Conn, reqNotifier *requestNotifier, modelUUID string) error {
   554  	codec := jsoncodec.NewWebsocket(wsConn)
   555  	if loggo.GetLogger("juju.rpc.jsoncodec").EffectiveLogLevel() <= loggo.TRACE {
   556  		codec.SetLogging(true)
   557  	}
   558  	var notifier rpc.RequestNotifier
   559  	if logger.EffectiveLogLevel() <= loggo.DEBUG {
   560  		// Incur request monitoring overhead only if we
   561  		// know we'll need it.
   562  		notifier = reqNotifier
   563  	}
   564  	conn := rpc.NewConn(codec, notifier)
   565  
   566  	h, err := srv.newAPIHandler(conn, reqNotifier, modelUUID)
   567  	if err != nil {
   568  		conn.ServeFinder(&errRoot{err}, serverError)
   569  	} else {
   570  		adminApis := make(map[int]interface{})
   571  		for apiVersion, factory := range srv.adminApiFactories {
   572  			adminApis[apiVersion] = factory(srv, h, reqNotifier)
   573  		}
   574  		conn.ServeFinder(newAnonRoot(h, adminApis), serverError)
   575  	}
   576  	conn.Start()
   577  	select {
   578  	case <-conn.Dead():
   579  	case <-srv.tomb.Dying():
   580  	}
   581  	return conn.Close()
   582  }
   583  
   584  func (srv *Server) newAPIHandler(conn *rpc.Conn, reqNotifier *requestNotifier, modelUUID string) (*apiHandler, error) {
   585  	// Note that we don't overwrite modelUUID here because
   586  	// newAPIHandler treats an empty modelUUID as signifying
   587  	// the API version used.
   588  	resolvedModelUUID, err := validateModelUUID(validateArgs{
   589  		statePool: srv.statePool,
   590  		modelUUID: modelUUID,
   591  	})
   592  	if err != nil {
   593  		return nil, errors.Trace(err)
   594  	}
   595  	st, err := srv.statePool.Get(resolvedModelUUID)
   596  	if err != nil {
   597  		return nil, errors.Trace(err)
   598  	}
   599  	return newApiHandler(srv, st, conn, reqNotifier, modelUUID)
   600  }
   601  
   602  func (srv *Server) mongoPinger() error {
   603  	// TODO(fwereade): 2016-03-17 lp:1558657
   604  	timer := time.NewTimer(0)
   605  	session := srv.state.MongoSession().Copy()
   606  	defer session.Close()
   607  	for {
   608  		select {
   609  		case <-timer.C:
   610  		case <-srv.tomb.Dying():
   611  			return tomb.ErrDying
   612  		}
   613  		if err := session.Ping(); err != nil {
   614  			logger.Infof("got error pinging mongo: %v", err)
   615  			return errors.Annotate(err, "error pinging mongo")
   616  		}
   617  		timer.Reset(mongoPingInterval)
   618  	}
   619  }
   620  
   621  func serverError(err error) error {
   622  	if err := common.ServerError(err); err != nil {
   623  		return err
   624  	}
   625  	return nil
   626  }