github.com/ethanhsieh/snapd@v0.0.0-20210615102523-3db9b8e4edc5/usersession/agent/session_agent.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2019 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package agent
    21  
    22  import (
    23  	"context"
    24  	"fmt"
    25  	"net"
    26  	"net/http"
    27  	"os"
    28  	"sync"
    29  	"syscall"
    30  	"time"
    31  
    32  	"github.com/godbus/dbus"
    33  	"github.com/gorilla/mux"
    34  	"gopkg.in/tomb.v2"
    35  
    36  	"github.com/snapcore/snapd/dbusutil"
    37  	"github.com/snapcore/snapd/dirs"
    38  	"github.com/snapcore/snapd/logger"
    39  	"github.com/snapcore/snapd/netutil"
    40  	"github.com/snapcore/snapd/osutil/sys"
    41  	"github.com/snapcore/snapd/systemd"
    42  )
    43  
// SessionAgent bundles the state of the session agent: an optional D-Bus
// connection, the listener and HTTP server used for the REST API, the tomb
// supervising its goroutines, and the idle tracking used to exit when unused.
type SessionAgent struct {
	// NOTE(review): Version is not read or written in this file; presumably
	// set by the caller and reported through the REST API - confirm.
	Version  string
	// bus is assigned by tryConnectSessionBus; that function tolerates a
	// failed connection, so this may be unset.
	bus      *dbus.Conn
	// listener is the socket the REST API is served on (wrapped in a
	// closeOnceListener by Init).
	listener net.Listener
	// serve is the HTTP server created by Init.
	serve    *http.Server
	// tomb supervises the goroutines launched by Start.
	tomb     tomb.Tomb
	// router dispatches requests to the commands registered by addRoutes.
	router   *mux.Router

	// idle records which connections are in use and when the last one went away.
	idle        *idleTracker
	// IdleTimeout is how long the agent may stay idle before exitOnIdle
	// kills the tomb. Init sets it to defaultIdleTimeout.
	IdleTimeout time.Duration
}
    55  
// sessionAgentBusName is the well-known D-Bus name that
// tryConnectSessionBus requests on the session bus.
const sessionAgentBusName = "io.snapcraft.SessionAgent"
    57  
// A ResponseFunc handles one HTTP verb (GET, PUT, POST or DELETE) of a
// Command. It receives the Command being served and the incoming request
// and returns the Response to send back.
type ResponseFunc func(*Command, *http.Request) Response
    60  
// A Command routes a request to an individual per-verb ResponseFunc
type Command struct {
	// Path is the route the command is registered under (see addRoutes).
	Path string

	// Per-verb handlers. A nil entry means the verb is not supported and
	// ServeHTTP answers with a "method not allowed" response.
	GET    ResponseFunc
	PUT    ResponseFunc
	POST   ResponseFunc
	DELETE ResponseFunc

	// s is the agent serving this command; it is filled in by addRoutes.
	s *SessionAgent
}
    72  
    73  func (c *Command) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    74  	var rspf ResponseFunc
    75  	var rsp = MethodNotAllowed("method %q not allowed", r.Method)
    76  
    77  	switch r.Method {
    78  	case "GET":
    79  		rspf = c.GET
    80  	case "PUT":
    81  		rspf = c.PUT
    82  	case "POST":
    83  		rspf = c.POST
    84  	case "DELETE":
    85  		rspf = c.DELETE
    86  	}
    87  
    88  	if rspf != nil {
    89  		rsp = rspf(c, r)
    90  	}
    91  	rsp.ServeHTTP(w, r)
    92  }
    93  
// idleTracker keeps track of the connections currently in use so that the
// agent can tell how long it has been idle.
type idleTracker struct {
	// mu guards active and lastActive.
	mu         sync.Mutex
	// active is the set of connections in the new or active state.
	active     map[net.Conn]struct{}
	// lastActive is the time at which the active set last became empty
	// (or the time the tracker was created).
	lastActive time.Time
}
    99  
   100  var sysGetsockoptUcred = syscall.GetsockoptUcred
   101  
   102  func getUcred(conn net.Conn) (*syscall.Ucred, error) {
   103  	if uconn, ok := conn.(*net.UnixConn); ok {
   104  		f, err := uconn.File()
   105  		if err != nil {
   106  			return nil, err
   107  		}
   108  		defer f.Close()
   109  		return sysGetsockoptUcred(int(f.Fd()), syscall.SOL_SOCKET, syscall.SO_PEERCRED)
   110  	}
   111  	return nil, fmt.Errorf("expected a net.UnixConn, but got a %T", conn)
   112  }
   113  
   114  func (it *idleTracker) trackConn(conn net.Conn, state http.ConnState) {
   115  	// Perform peer credentials check
   116  	if state == http.StateNew {
   117  		ucred, err := getUcred(conn)
   118  		if err != nil {
   119  			logger.Noticef("Failed to retrieve peer credentials: %v", err)
   120  			conn.Close()
   121  			return
   122  		}
   123  		if ucred.Uid != 0 && ucred.Uid != uint32(sys.Geteuid()) {
   124  			logger.Noticef("Blocking request from user ID %v", ucred.Uid)
   125  			conn.Close()
   126  			return
   127  		}
   128  	}
   129  
   130  	it.mu.Lock()
   131  	defer it.mu.Unlock()
   132  	oldActive := len(it.active)
   133  	if state == http.StateNew || state == http.StateActive {
   134  		it.active[conn] = struct{}{}
   135  	} else {
   136  		delete(it.active, conn)
   137  	}
   138  	if len(it.active) == 0 && oldActive != 0 {
   139  		it.lastActive = time.Now()
   140  	}
   141  }
   142  
   143  // idleDuration returns the duration of time the server has been idle
   144  func (it *idleTracker) idleDuration() time.Duration {
   145  	it.mu.Lock()
   146  	defer it.mu.Unlock()
   147  	if len(it.active) != 0 {
   148  		return 0
   149  	}
   150  	return time.Since(it.lastActive)
   151  }
   152  
const (
	// defaultIdleTimeout is the IdleTimeout that Init configures.
	defaultIdleTimeout = 30 * time.Second
	// shutdownTimeout bounds how long shutdownServerOnKill waits for the
	// HTTP server to shut down.
	shutdownTimeout    = 5 * time.Second
)
   157  
   158  type closeOnceListener struct {
   159  	net.Listener
   160  
   161  	idempotClose sync.Once
   162  	closeErr     error
   163  }
   164  
   165  func (l *closeOnceListener) Close() error {
   166  	l.idempotClose.Do(func() {
   167  		l.closeErr = l.Listener.Close()
   168  	})
   169  	return l.closeErr
   170  }
   171  
   172  func (s *SessionAgent) Init() error {
   173  	// Set up D-Bus connection
   174  	if err := s.tryConnectSessionBus(); err != nil {
   175  		return err
   176  	}
   177  
   178  	// Set up REST API server
   179  	listenerMap, err := netutil.ActivationListeners()
   180  	if err != nil {
   181  		return err
   182  	}
   183  	agentSocket := fmt.Sprintf("%s/%d/snapd-session-agent.socket", dirs.XdgRuntimeDirBase, os.Getuid())
   184  	if l, err := netutil.GetListener(agentSocket, listenerMap); err != nil {
   185  		return fmt.Errorf("cannot listen on socket %s: %v", agentSocket, err)
   186  	} else {
   187  		s.listener = &closeOnceListener{Listener: l}
   188  	}
   189  	s.idle = &idleTracker{
   190  		active:     make(map[net.Conn]struct{}),
   191  		lastActive: time.Now(),
   192  	}
   193  	s.IdleTimeout = defaultIdleTimeout
   194  	s.addRoutes()
   195  	s.serve = &http.Server{
   196  		Handler:   s.router,
   197  		ConnState: s.idle.trackConn,
   198  	}
   199  	return nil
   200  }
   201  
   202  func (s *SessionAgent) tryConnectSessionBus() (err error) {
   203  	s.bus, err = dbusutil.SessionBusPrivate()
   204  	if err != nil {
   205  		// ssh sessions on Ubuntu 16.04 may have a user
   206  		// instance of systemd but no D-Bus session bus.  So
   207  		// don't treat this as an error.
   208  		logger.Noticef("Could not connect to session bus: %v", err)
   209  		return nil
   210  	}
   211  	defer func() {
   212  		if err != nil {
   213  			s.bus.Close()
   214  			s.bus = nil
   215  		}
   216  	}()
   217  
   218  	reply, err := s.bus.RequestName(sessionAgentBusName, dbus.NameFlagDoNotQueue)
   219  	if err != nil {
   220  		return err
   221  	}
   222  	if reply != dbus.RequestNameReplyPrimaryOwner {
   223  		return fmt.Errorf("cannot obtain bus name %q: %v", sessionAgentBusName, reply)
   224  	}
   225  	return nil
   226  }
   227  
   228  func (s *SessionAgent) addRoutes() {
   229  	s.router = mux.NewRouter()
   230  	for _, c := range restApi {
   231  		c.s = s
   232  		s.router.Handle(c.Path, c).Name(c.Path)
   233  	}
   234  	s.router.NotFoundHandler = NotFound("not found")
   235  }
   236  
   237  func (s *SessionAgent) Start() {
   238  	s.tomb.Go(s.runServer)
   239  	s.tomb.Go(s.shutdownServerOnKill)
   240  	s.tomb.Go(s.exitOnIdle)
   241  	systemd.SdNotify("READY=1")
   242  }
   243  
   244  func (s *SessionAgent) runServer() error {
   245  	err := s.serve.Serve(s.listener)
   246  	if err == http.ErrServerClosed {
   247  		err = nil
   248  	}
   249  	if s.tomb.Err() == tomb.ErrStillAlive {
   250  		return err
   251  	}
   252  	return nil
   253  }
   254  
   255  func (s *SessionAgent) shutdownServerOnKill() error {
   256  	<-s.tomb.Dying()
   257  	systemd.SdNotify("STOPPING=1")
   258  	// closing the listener (but then it needs wrapping in
   259  	// closeOnceListener) before actually calling Shutdown, to
   260  	// workaround https://github.com/golang/go/issues/20239, we
   261  	// can in some cases (e.g. tests) end up calling Shutdown
   262  	// before runServer calls Serve and in go <1.11 this can be
   263  	// racy and the shutdown blocks.
   264  	// Historically We do something similar in the main daemon
   265  	// logic as well.
   266  	s.listener.Close()
   267  	s.bus.Close()
   268  	ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
   269  	defer cancel()
   270  	return s.serve.Shutdown(ctx)
   271  }
   272  
   273  func (s *SessionAgent) exitOnIdle() error {
   274  	timer := time.NewTimer(s.IdleTimeout)
   275  Loop:
   276  	for {
   277  		select {
   278  		case <-s.tomb.Dying():
   279  			break Loop
   280  		case <-timer.C:
   281  			// Have we been idle
   282  			idleDuration := s.idle.idleDuration()
   283  			if idleDuration >= s.IdleTimeout {
   284  				s.tomb.Kill(nil)
   285  				break Loop
   286  			} else {
   287  				timer.Reset(s.IdleTimeout - idleDuration)
   288  			}
   289  		}
   290  	}
   291  	return nil
   292  }
   293  
   294  // Stop performs a graceful shutdown of the session agent and waits up to 5
   295  // seconds for it to complete.
   296  func (s *SessionAgent) Stop() error {
   297  	s.tomb.Kill(nil)
   298  	return s.tomb.Wait()
   299  }
   300  
   301  func (s *SessionAgent) Dying() <-chan struct{} {
   302  	return s.tomb.Dying()
   303  }
   304  
   305  func New() (*SessionAgent, error) {
   306  	agent := &SessionAgent{}
   307  	if err := agent.Init(); err != nil {
   308  		return nil, err
   309  	}
   310  	return agent, nil
   311  }