github.com/kubiko/snapd@v0.0.0-20201013125620-d4f3094d9ddf/usersession/agent/session_agent.go (about)

     1  // -*- Mode: Go; indent-tabs-mode: t -*-
     2  
     3  /*
     4   * Copyright (C) 2019 Canonical Ltd
     5   *
     6   * This program is free software: you can redistribute it and/or modify
     7   * it under the terms of the GNU General Public License version 3 as
     8   * published by the Free Software Foundation.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package agent
    21  
    22  import (
    23  	"context"
    24  	"fmt"
    25  	"net"
    26  	"net/http"
    27  	"os"
    28  	"sync"
    29  	"syscall"
    30  	"time"
    31  
    32  	"github.com/gorilla/mux"
    33  	"gopkg.in/tomb.v2"
    34  
    35  	"github.com/snapcore/snapd/dirs"
    36  	"github.com/snapcore/snapd/logger"
    37  	"github.com/snapcore/snapd/netutil"
    38  	"github.com/snapcore/snapd/osutil/sys"
    39  	"github.com/snapcore/snapd/systemd"
    40  )
    41  
// SessionAgent is the per-user session agent: an HTTP server bound to a
// listener, with its goroutines supervised by a tomb and an idle tracker
// that lets it exit once no connections have been active for IdleTimeout.
type SessionAgent struct {
	// Version is not referenced in this file; presumably the agent
	// version exposed through the REST API (confirm against the handlers).
	Version string
	// listener is set by Init to a closeOnceListener wrapping the agent
	// socket; runServer serves on it and shutdownServerOnKill closes it.
	listener net.Listener
	// serve is the HTTP server configured in Init with router as handler
	// and idle.trackConn as its ConnState hook.
	serve *http.Server
	// tomb tracks the goroutines launched by Start.
	tomb tomb.Tomb
	// router is built by addRoutes from the restApi commands.
	router *mux.Router

	// idle records which connections are currently active.
	idle *idleTracker
	// IdleTimeout is how long the agent may stay idle before exitOnIdle
	// kills the tomb. Init sets it to defaultIdleTimeout.
	IdleTimeout time.Duration
}
    52  
// A ResponseFunc handles a single HTTP verb (GET, PUT, POST or DELETE) of a
// Command: it receives the Command and the request, and returns the Response
// to send back.
type ResponseFunc func(*Command, *http.Request) Response
    55  
// A Command routes a request for one path to an individual per-verb
// ResponseFunc.
type Command struct {
	// Path is the route the command is registered under in the router
	// (it is also used as the route name).
	Path string

	// Per-verb handlers. A nil entry makes ServeHTTP answer that verb
	// with a "method not allowed" response.
	GET    ResponseFunc
	PUT    ResponseFunc
	POST   ResponseFunc
	DELETE ResponseFunc

	// s is the owning agent, filled in by SessionAgent.addRoutes.
	s *SessionAgent
}
    67  
    68  func (c *Command) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    69  	var rspf ResponseFunc
    70  	var rsp = MethodNotAllowed("method %q not allowed", r.Method)
    71  
    72  	switch r.Method {
    73  	case "GET":
    74  		rspf = c.GET
    75  	case "PUT":
    76  		rspf = c.PUT
    77  	case "POST":
    78  		rspf = c.POST
    79  	case "DELETE":
    80  		rspf = c.DELETE
    81  	}
    82  
    83  	if rspf != nil {
    84  		rsp = rspf(c, r)
    85  	}
    86  	rsp.ServeHTTP(w, r)
    87  }
    88  
// idleTracker keeps the set of connections that are currently in use so the
// agent can tell for how long it has been idle.
type idleTracker struct {
	// mu guards active and lastActive.
	mu sync.Mutex
	// active holds the connections last seen in the new or active state.
	active map[net.Conn]struct{}
	// lastActive is the time at which active most recently became empty
	// (Init seeds it with the agent's initialisation time).
	lastActive time.Time
}
    94  
    95  var sysGetsockoptUcred = syscall.GetsockoptUcred
    96  
    97  func getUcred(conn net.Conn) (*syscall.Ucred, error) {
    98  	if uconn, ok := conn.(*net.UnixConn); ok {
    99  		f, err := uconn.File()
   100  		if err != nil {
   101  			return nil, err
   102  		}
   103  		defer f.Close()
   104  		return sysGetsockoptUcred(int(f.Fd()), syscall.SOL_SOCKET, syscall.SO_PEERCRED)
   105  	}
   106  	return nil, fmt.Errorf("expected a net.UnixConn, but got a %T", conn)
   107  }
   108  
   109  func (it *idleTracker) trackConn(conn net.Conn, state http.ConnState) {
   110  	// Perform peer credentials check
   111  	if state == http.StateNew {
   112  		ucred, err := getUcred(conn)
   113  		if err != nil {
   114  			logger.Noticef("Failed to retrieve peer credentials: %v", err)
   115  			conn.Close()
   116  			return
   117  		}
   118  		if ucred.Uid != 0 && ucred.Uid != uint32(sys.Geteuid()) {
   119  			logger.Noticef("Blocking request from user ID %v", ucred.Uid)
   120  			conn.Close()
   121  			return
   122  		}
   123  	}
   124  
   125  	it.mu.Lock()
   126  	defer it.mu.Unlock()
   127  	oldActive := len(it.active)
   128  	if state == http.StateNew || state == http.StateActive {
   129  		it.active[conn] = struct{}{}
   130  	} else {
   131  		delete(it.active, conn)
   132  	}
   133  	if len(it.active) == 0 && oldActive != 0 {
   134  		it.lastActive = time.Now()
   135  	}
   136  }
   137  
   138  // idleDuration returns the duration of time the server has been idle
   139  func (it *idleTracker) idleDuration() time.Duration {
   140  	it.mu.Lock()
   141  	defer it.mu.Unlock()
   142  	if len(it.active) != 0 {
   143  		return 0
   144  	}
   145  	return time.Since(it.lastActive)
   146  }
   147  
const (
	// defaultIdleTimeout is the IdleTimeout that Init gives a new agent.
	defaultIdleTimeout = 30 * time.Second
	// shutdownTimeout bounds how long shutdownServerOnKill waits for the
	// HTTP server to shut down gracefully.
	shutdownTimeout = 5 * time.Second
)
   152  
   153  type closeOnceListener struct {
   154  	net.Listener
   155  
   156  	idempotClose sync.Once
   157  	closeErr     error
   158  }
   159  
   160  func (l *closeOnceListener) Close() error {
   161  	l.idempotClose.Do(func() {
   162  		l.closeErr = l.Listener.Close()
   163  	})
   164  	return l.closeErr
   165  }
   166  
   167  func (s *SessionAgent) Init() error {
   168  	listenerMap, err := netutil.ActivationListeners()
   169  	if err != nil {
   170  		return err
   171  	}
   172  	agentSocket := fmt.Sprintf("%s/%d/snapd-session-agent.socket", dirs.XdgRuntimeDirBase, os.Getuid())
   173  	if l, err := netutil.GetListener(agentSocket, listenerMap); err != nil {
   174  		return fmt.Errorf("cannot listen on socket %s: %v", agentSocket, err)
   175  	} else {
   176  		s.listener = &closeOnceListener{Listener: l}
   177  	}
   178  	s.idle = &idleTracker{
   179  		active:     make(map[net.Conn]struct{}),
   180  		lastActive: time.Now(),
   181  	}
   182  	s.IdleTimeout = defaultIdleTimeout
   183  	s.addRoutes()
   184  	s.serve = &http.Server{
   185  		Handler:   s.router,
   186  		ConnState: s.idle.trackConn,
   187  	}
   188  	return nil
   189  }
   190  
   191  func (s *SessionAgent) addRoutes() {
   192  	s.router = mux.NewRouter()
   193  	for _, c := range restApi {
   194  		c.s = s
   195  		s.router.Handle(c.Path, c).Name(c.Path)
   196  	}
   197  	s.router.NotFoundHandler = NotFound("not found")
   198  }
   199  
   200  func (s *SessionAgent) Start() {
   201  	s.tomb.Go(s.runServer)
   202  	s.tomb.Go(s.shutdownServerOnKill)
   203  	s.tomb.Go(s.exitOnIdle)
   204  	systemd.SdNotify("READY=1")
   205  }
   206  
   207  func (s *SessionAgent) runServer() error {
   208  	err := s.serve.Serve(s.listener)
   209  	if err == http.ErrServerClosed {
   210  		err = nil
   211  	}
   212  	if s.tomb.Err() == tomb.ErrStillAlive {
   213  		return err
   214  	}
   215  	return nil
   216  }
   217  
   218  func (s *SessionAgent) shutdownServerOnKill() error {
   219  	<-s.tomb.Dying()
   220  	// closing the listener (but then it needs wrapping in
   221  	// closeOnceListener) before actually calling Shutdown, to
   222  	// workaround https://github.com/golang/go/issues/20239, we
   223  	// can in some cases (e.g. tests) end up calling Shutdown
   224  	// before runServer calls Serve and in go <1.11 this can be
   225  	// racy and the shutdown blocks.
   226  	// Historically We do something similar in the main daemon
   227  	// logic as well.
   228  	s.listener.Close()
   229  	systemd.SdNotify("STOPPING=1")
   230  	ctx, cancel := context.WithTimeout(context.Background(), shutdownTimeout)
   231  	defer cancel()
   232  	return s.serve.Shutdown(ctx)
   233  }
   234  
   235  func (s *SessionAgent) exitOnIdle() error {
   236  	timer := time.NewTimer(s.IdleTimeout)
   237  Loop:
   238  	for {
   239  		select {
   240  		case <-s.tomb.Dying():
   241  			break Loop
   242  		case <-timer.C:
   243  			// Have we been idle
   244  			idleDuration := s.idle.idleDuration()
   245  			if idleDuration >= s.IdleTimeout {
   246  				s.tomb.Kill(nil)
   247  				break Loop
   248  			} else {
   249  				timer.Reset(s.IdleTimeout - idleDuration)
   250  			}
   251  		}
   252  	}
   253  	return nil
   254  }
   255  
   256  // Stop performs a graceful shutdown of the session agent and waits up to 5
   257  // seconds for it to complete.
   258  func (s *SessionAgent) Stop() error {
   259  	s.tomb.Kill(nil)
   260  	return s.tomb.Wait()
   261  }
   262  
   263  func (s *SessionAgent) Dying() <-chan struct{} {
   264  	return s.tomb.Dying()
   265  }
   266  
   267  func New() (*SessionAgent, error) {
   268  	agent := &SessionAgent{}
   269  	if err := agent.Init(); err != nil {
   270  		return nil, err
   271  	}
   272  	return agent, nil
   273  }