github.com/chipaca/snappy@v0.0.0-20210104084008-1f06296fe8ad/usersession/agent/session_agent.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2019 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package agent
    21  
    22  import (
    23  	"context"
    24  	"fmt"
    25  	"net"
    26  	"net/http"
    27  	"os"
    28  	"sync"
    29  	"syscall"
    30  	"time"
    31  
    32  	"github.com/godbus/dbus"
    33  	"github.com/gorilla/mux"
    34  	"gopkg.in/tomb.v2"
    35  
    36  	"github.com/snapcore/snapd/dirs"
    37  	"github.com/snapcore/snapd/logger"
    38  	"github.com/snapcore/snapd/netutil"
    39  	"github.com/snapcore/snapd/osutil/sys"
    40  	"github.com/snapcore/snapd/systemd"
    41  )
    42  
    43  type SessionAgent struct {
    44  	Version  string
    45  	bus      *dbus.Conn
    46  	listener net.Listener
    47  	serve    *http.Server
    48  	tomb     tomb.Tomb
    49  	router   *mux.Router
    50  
    51  	idle        *idleTracker
    52  	IdleTimeout time.Duration
    53  }
    54  
    55  const sessionAgentBusName = "io.snapcraft.SessionAgent"
    56  
    57  func dbusSessionBus() (*dbus.Conn, error) {
    58  	// use a private connection to the session bus, this way we can manage
    59  	// its lifetime without worrying of breaking other code
    60  	conn, err := dbus.SessionBusPrivate()
    61  	if err != nil {
    62  		return nil, err
    63  	}
    64  	if err := conn.Auth(nil); err != nil {
    65  		conn.Close()
    66  		return nil, err
    67  	}
    68  	if err := conn.Hello(); err != nil {
    69  		conn.Close()
    70  		return nil, err
    71  	}
    72  	return conn, nil
    73  }
    74  
    75  // A ResponseFunc handles one of the individual verbs for a method
    76  type ResponseFunc func(*Command, *http.Request) Response
    77  
    78  // A Command routes a request to an individual per-verb ResponseFunc
    79  type Command struct {
    80  	Path string
    81  
    82  	GET    ResponseFunc
    83  	PUT    ResponseFunc
    84  	POST   ResponseFunc
    85  	DELETE ResponseFunc
    86  
    87  	s *SessionAgent
    88  }
    89  
    90  func (c *Command) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    91  	var rspf ResponseFunc
    92  	var rsp = MethodNotAllowed("method %q not allowed", r.Method)
    93  
    94  	switch r.Method {
    95  	case "GET":
    96  		rspf = c.GET
    97  	case "PUT":
    98  		rspf = c.PUT
    99  	case "POST":
   100  		rspf = c.POST
   101  	case "DELETE":
   102  		rspf = c.DELETE
   103  	}
   104  
   105  	if rspf != nil {
   106  		rsp = rspf(c, r)
   107  	}
   108  	rsp.ServeHTTP(w, r)
   109  }
   110  
   111  type idleTracker struct {
   112  	mu         sync.Mutex
   113  	active     map[net.Conn]struct{}
   114  	lastActive time.Time
   115  }
   116  
   117  var sysGetsockoptUcred = syscall.GetsockoptUcred
   118  
   119  func getUcred(conn net.Conn) (*syscall.Ucred, error) {
   120  	if uconn, ok := conn.(*net.UnixConn); ok {
   121  		f, err := uconn.File()
   122  		if err != nil {
   123  			return nil, err
   124  		}
   125  		defer f.Close()
   126  		return sysGetsockoptUcred(int(f.Fd()), syscall.SOL_SOCKET, syscall.SO_PEERCRED)
   127  	}
   128  	return nil, fmt.Errorf("expected a net.UnixConn, but got a %T", conn)
   129  }
   130  
   131  func (it *idleTracker) trackConn(conn net.Conn, state http.ConnState) {
   132  	// Perform peer credentials check
   133  	if state == http.StateNew {
   134  		ucred, err := getUcred(conn)
   135  		if err != nil {
   136  			logger.Noticef("Failed to retrieve peer credentials: %v", err)
   137  			conn.Close()
   138  			return
   139  		}
   140  		if ucred.Uid != 0 && ucred.Uid != uint32(sys.Geteuid()) {
   141  			logger.Noticef("Blocking request from user ID %v", ucred.Uid)
   142  			conn.Close()
   143  			return
   144  		}
   145  	}
   146  
   147  	it.mu.Lock()
   148  	defer it.mu.Unlock()
   149  	oldActive := len(it.active)
   150  	if state == http.StateNew || state == http.StateActive {
   151  		it.active[conn] = struct{}{}
   152  	} else {
   153  		delete(it.active, conn)
   154  	}
   155  	if len(it.active) == 0 && oldActive != 0 {
   156  		it.lastActive = time.Now()
   157  	}
   158  }
   159  
   160  // idleDuration returns the duration of time the server has been idle
   161  func (it *idleTracker) idleDuration() time.Duration {
   162  	it.mu.Lock()
   163  	defer it.mu.Unlock()
   164  	if len(it.active) != 0 {
   165  		return 0
   166  	}
   167  	return time.Since(it.lastActive)
   168  }
   169  
   170  const (
   171  	defaultIdleTimeout = 30 * time.Second
   172  	shutdownTimeout    = 5 * time.Second
   173  )
   174  
   175  type closeOnceListener struct {
   176  	net.Listener
   177  
   178  	idempotClose sync.Once
   179  	closeErr     error
   180  }
   181  
   182  func (l *closeOnceListener) Close() error {
   183  	l.idempotClose.Do(func() {
   184  		l.closeErr = l.Listener.Close()
   185  	})
   186  	return l.closeErr
   187  }
   188  
   189  func (s *SessionAgent) Init() error {
   190  	// Set up D-Bus connection
   191  	if err := s.tryConnectSessionBus(); err != nil {
   192  		return err
   193  	}
   194  
   195  	// Set up REST API server
   196  	listenerMap, err := netutil.ActivationListeners()
   197  	if err != nil {
   198  		return err
   199  	}
   200  	agentSocket := fmt.Sprintf("%s/%d/snapd-session-agent.socket", dirs.XdgRuntimeDirBase, os.Getuid())
   201  	if l, err := netutil.GetListener(agentSocket, listenerMap); err != nil {
   202  		return fmt.Errorf("cannot listen on socket %s: %v", agentSocket, err)
   203  	} else {
   204  		s.listener = &closeOnceListener{Listener: l}
   205  	}
   206  	s.idle = &idleTracker{
   207  		active:     make(map[net.Conn]struct{}),
   208  		lastActive: time.Now(),
   209  	}
   210  	s.IdleTimeout = defaultIdleTimeout
   211  	s.addRoutes()
   212  	s.serve = &http.Server{
   213  		Handler:   s.router,
   214  		ConnState: s.idle.trackConn,
   215  	}
   216  	return nil
   217  }
   218  
   219  func (s *SessionAgent) tryConnectSessionBus() (err error) {
   220  	s.bus, err = dbusSessionBus()
   221  	if err != nil {
   222  		// ssh sessions on Ubuntu 16.04 may have a user
   223  		// instance of systemd but no D-Bus session bus.  So
   224  		// don't treat this as an error.
   225  		logger.Noticef("Could not connect to session bus: %v", err)
   226  		return nil
   227  	}
   228  	defer func() {
   229  		if err != nil {
   230  			s.bus.Close()
   231  			s.bus = nil
   232  		}
   233  	}()
   234  
   235  	reply, err := s.bus.RequestName(sessionAgentBusName, dbus.NameFlagDoNotQueue)
   236  	if err != nil {
   237  		return err
   238  	}
   239  	if reply != dbus.RequestNameReplyPrimaryOwner {
   240  		return fmt.Errorf("cannot obtain bus name %q: %v", sessionAgentBusName, reply)
   241  	}
   242  	return nil
   243  }
   244  
   245  func (s *SessionAgent) addRoutes() {
   246  	s.router = mux.NewRouter()
   247  	for _, c := range restApi {
   248  		c.s = s
   249  		s.router.Handle(c.Path, c).Name(c.Path)
   250  	}
   251  	s.router.NotFoundHandler = NotFound("not found")
   252  }
   253  
   254  func (s *SessionAgent) Start() {
   255  	s.tomb.Go(s.runServer)
   256  	s.tomb.Go(s.shutdownServerOnKill)
   257  	s.tomb.Go(s.exitOnIdle)
   258  	systemd.SdNotify("READY=1")
   259  }
   260  
   261  func (s *SessionAgent) runServer() error {
   262  	err := s.serve.Serve(s.listener)
   263  	if err == http.ErrServerClosed {
   264  		err = nil
   265  	}
   266  	if s.tomb.Err() == tomb.ErrStillAlive {
   267  		return err
   268  	}
   269  	return nil
   270  }
   271  
   272  func (s *SessionAgent) shutdownServerOnKill() error {
   273  	<-s.tomb.Dying()
   274  	systemd.SdNotify("STOPPING=1")
   275  	// closing the listener (but then it needs wrapping in
   276  	// closeOnceListener) before actually calling Shutdown, to
   277  	// workaround https://github.com/golang/go/issues/20239, we
   278  	// can in some cases (e.g. tests) end up calling Shutdown
   279  	// before runServer calls Serve and in go <1.11 this can be
   280  	// racy and the shutdown blocks.
   281  	// Historically We do something similar in the main daemon
   282  	// logic as well.
   283  	s.listener.Close()
   284  	s.bus.Close()
   285  	ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
   286  	defer cancel()
   287  	return s.serve.Shutdown(ctx)
   288  }
   289  
   290  func (s *SessionAgent) exitOnIdle() error {
   291  	timer := time.NewTimer(s.IdleTimeout)
   292  Loop:
   293  	for {
   294  		select {
   295  		case <-s.tomb.Dying():
   296  			break Loop
   297  		case <-timer.C:
   298  			// Have we been idle
   299  			idleDuration := s.idle.idleDuration()
   300  			if idleDuration >= s.IdleTimeout {
   301  				s.tomb.Kill(nil)
   302  				break Loop
   303  			} else {
   304  				timer.Reset(s.IdleTimeout - idleDuration)
   305  			}
   306  		}
   307  	}
   308  	return nil
   309  }
   310  
   311  // Stop performs a graceful shutdown of the session agent and waits up to 5
   312  // seconds for it to complete.
   313  func (s *SessionAgent) Stop() error {
   314  	s.tomb.Kill(nil)
   315  	return s.tomb.Wait()
   316  }
   317  
   318  func (s *SessionAgent) Dying() <-chan struct{} {
   319  	return s.tomb.Dying()
   320  }
   321  
   322  func New() (*SessionAgent, error) {
   323  	agent := &SessionAgent{}
   324  	if err := agent.Init(); err != nil {
   325  		return nil, err
   326  	}
   327  	return agent, nil
   328  }