github.com/keybase/client/go@v0.0.0-20241007131713-f10651d043c8/kbfs/libhttpserver/server.go (about)

     1  // Copyright 2018 Keybase Inc. All rights reserved.
     2  // Use of this source code is governed by a BSD
     3  // license that can be found in the LICENSE file.
     4  
     5  package libhttpserver
     6  
import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/base64"
	"io"
	"net/http"
	"path"
	"strings"
	"sync"
	"time"

	lru "github.com/hashicorp/golang-lru"
	"github.com/keybase/client/go/kbfs/data"
	"github.com/keybase/client/go/kbfs/env"
	"github.com/keybase/client/go/kbfs/libfs"
	"github.com/keybase/client/go/kbfs/libkbfs"
	"github.com/keybase/client/go/kbfs/libmime"
	"github.com/keybase/client/go/kbfs/tlf"
	"github.com/keybase/client/go/kbhttp"
	"github.com/keybase/client/go/libkb"
	"github.com/keybase/client/go/logger"
	"github.com/keybase/client/go/protocol/keybase1"
)
    30  
    31  const fsCacheSize = 64
    32  
    33  // Server is a local HTTP server for serving KBFS content over HTTP.
    34  type Server struct {
    35  	config          libkbfs.Config
    36  	logger          logger.Logger
    37  	vlog            *libkb.VDebugLog
    38  	appStateUpdater env.AppStateUpdater
    39  	cancel          func()
    40  
    41  	tokenLock       sync.RWMutex
    42  	token           string
    43  	tokenExpireTime time.Time
    44  
    45  	fs *lru.Cache
    46  
    47  	serverLock sync.RWMutex
    48  	server     *kbhttp.Srv
    49  }
    50  
    51  const tokenByteSize = 32
    52  const tokenValidTime = 10 * time.Minute
    53  
    54  // CurrentToken returns the currently valid token that a HTTP client can use to
    55  // load content from the server.
    56  func (s *Server) CurrentToken() (token string, err error) {
    57  	s.tokenLock.RLock()
    58  	if s.config.Clock().Now().Before(s.tokenExpireTime) {
    59  		defer s.tokenLock.RUnlock()
    60  		return s.token, nil
    61  	}
    62  
    63  	s.tokenLock.RUnlock()
    64  
    65  	buf := make([]byte, tokenByteSize)
    66  	if _, err = rand.Read(buf); err != nil {
    67  		return "", err
    68  	}
    69  	token = base64.URLEncoding.EncodeToString(buf)
    70  
    71  	s.tokenLock.Lock()
    72  	defer s.tokenLock.Unlock()
    73  
    74  	if s.config.Clock().Now().Before(s.tokenExpireTime) {
    75  		return s.token, nil
    76  	}
    77  
    78  	s.token = token
    79  	s.tokenExpireTime = s.config.Clock().Now().Add(tokenValidTime)
    80  
    81  	return token, nil
    82  }
    83  
    84  func (s *Server) handleInvalidToken(w http.ResponseWriter) {
    85  	w.WriteHeader(http.StatusForbidden)
    86  	_, _ = io.WriteString(w, `
    87      <html>
    88          <head>
    89              <title>KBFS HTTP Token Invalid</title>
    90          </head>
    91          <body>
    92              token invalid
    93          </body>
    94      </html>
    95      `)
    96  }
    97  
    98  func (s *Server) handleBadRequest(w http.ResponseWriter) {
    99  	w.WriteHeader(http.StatusBadRequest)
   100  }
   101  
   102  func (s *Server) handleInternalServerError(w http.ResponseWriter) {
   103  	w.WriteHeader(http.StatusInternalServerError)
   104  }
   105  
   106  type obsoleteTrackingFS struct {
   107  	fs *libfs.FS
   108  	ch <-chan struct{}
   109  }
   110  
   111  func (e obsoleteTrackingFS) isObsolete() bool {
   112  	select {
   113  	case <-e.ch:
   114  		return true
   115  	default:
   116  		return false
   117  	}
   118  }
   119  
   120  func (s *Server) getHTTPFileSystem(ctx context.Context, requestPath string) (
   121  	toStrip string, fs http.FileSystem, err error) {
   122  	fields := strings.Split(requestPath, "/")
   123  	if len(fields) < 2 {
   124  		return "", libfs.NewRootFS(s.config).ToHTTPFileSystem(ctx), nil
   125  	}
   126  
   127  	tlfType, err := tlf.ParseTlfTypeFromPath(fields[0])
   128  	if err != nil {
   129  		return "", nil, err
   130  	}
   131  
   132  	toStrip = path.Join(fields[0], fields[1])
   133  
   134  	if fsCached, ok := s.fs.Get(toStrip); ok {
   135  		if fsCachedTyped, ok := fsCached.(obsoleteTrackingFS); ok {
   136  			if !fsCachedTyped.isObsolete() {
   137  				return toStrip, fsCachedTyped.fs.ToHTTPFileSystem(ctx), nil
   138  			}
   139  		}
   140  	}
   141  
   142  	tlfHandle, err := libkbfs.GetHandleFromFolderNameAndType(ctx,
   143  		s.config.KBPKI(), s.config.MDOps(), s.config, fields[1], tlfType)
   144  	if err != nil {
   145  		return "", nil, err
   146  	}
   147  
   148  	tlfFS, err := libfs.NewFS(ctx,
   149  		s.config, tlfHandle, data.MasterBranch, "", "",
   150  		keybase1.MDPriorityNormal)
   151  	if err != nil {
   152  		return "", nil, err
   153  	}
   154  
   155  	fsLifeCh, err := tlfFS.SubscribeToObsolete()
   156  	if err != nil {
   157  		return "", nil, err
   158  	}
   159  
   160  	s.fs.Add(toStrip, obsoleteTrackingFS{fs: tlfFS, ch: fsLifeCh})
   161  
   162  	return toStrip, tlfFS.ToHTTPFileSystem(ctx), nil
   163  }
   164  
   165  // serve accepts "/<fs path>?token=<token>"
   166  // For example:
   167  //
   168  //	/team/keybase/file.txt?token=1234567890abcdef1234567890abcdef
   169  func (s *Server) serve(w http.ResponseWriter, req *http.Request) {
   170  	s.vlog.Log(libkb.VLog1, "Incoming request from %q: %s", req.UserAgent(), req.URL)
   171  	addr, err := s.server.Addr()
   172  	if err != nil {
   173  		s.logger.Error("serve: failed to get HTTP server address: %s", err)
   174  		s.handleInternalServerError(w)
   175  		return
   176  	}
   177  	if req.Host != addr {
   178  		s.logger.Warning("Host %s didn't match addr %s, failing request to protect against DNS rebinding", req.Host, addr)
   179  		s.handleBadRequest(w)
   180  		return
   181  	}
   182  	token := req.URL.Query().Get("token")
   183  	currentToken, err := s.CurrentToken()
   184  	if err != nil {
   185  		s.logger.Error("serve: failed to get current token: %s", err)
   186  		s.handleInternalServerError(w)
   187  		return
   188  	}
   189  	if len(token) == 0 || token != currentToken {
   190  		s.vlog.Log(libkb.VLog1, "Invalid token %q", token)
   191  		s.handleInvalidToken(w)
   192  		return
   193  	}
   194  	toStrip, fs, err := s.getHTTPFileSystem(req.Context(), req.URL.Path)
   195  	if err != nil {
   196  		s.logger.Warning("Bad request; error=%v", err)
   197  		s.handleBadRequest(w)
   198  		return
   199  	}
   200  	viewTypeInvariance := req.URL.Query().Get("viewTypeInvariance")
   201  	if len(viewTypeInvariance) == 0 {
   202  		s.logger.Warning("Bad request; missing viewTypeInvariance")
   203  		s.handleBadRequest(w)
   204  		return
   205  	}
   206  	wrappedW := newContentTypeOverridingResponseWriter(w,
   207  		viewTypeInvariance)
   208  	http.StripPrefix(toStrip, http.FileServer(fs)).ServeHTTP(wrappedW, req)
   209  }
   210  
   211  const portStart = 16723
   212  const portEnd = 60000
   213  const requestPathRoot = "/files/"
   214  
   215  func (s *Server) restart() (err error) {
   216  	s.serverLock.Lock()
   217  	defer s.serverLock.Unlock()
   218  	if s.server != nil {
   219  		s.server.Stop()
   220  		err = s.server.Start()
   221  	}
   222  	if s.server == nil ||
   223  		// If pinned port is in use, just pick a new one like we never had a
   224  		// server before.
   225  		err == kbhttp.ErrPinnedPortInUse {
   226  		s.server = kbhttp.NewSrv(s.logger,
   227  			kbhttp.NewRandomPortRangeListenerSource(portStart, portEnd))
   228  		err = s.server.Start()
   229  	}
   230  	if err != nil {
   231  		return err
   232  	}
   233  	// Have to start this first to populate the ServeMux object.
   234  	s.server.Handle(requestPathRoot,
   235  		http.StripPrefix(requestPathRoot, http.HandlerFunc(s.serve)))
   236  	return nil
   237  }
   238  
   239  func (s *Server) monitorAppState(ctx context.Context) {
   240  	state := keybase1.MobileAppState_FOREGROUND
   241  	for {
   242  		select {
   243  		case <-ctx.Done():
   244  			return
   245  		case state = <-s.appStateUpdater.NextAppStateUpdate(&state):
   246  			// Due to the way NextUpdate is designed, it's possible we miss an
   247  			// update if processing the last update takes too long. So it's
   248  			// possible to get consecutive FOREGROUND updates even if there are
   249  			// other states in-between. Since libkb/appstate.go already
   250  			// deduplicates, it'll never actually send consecutive identical
   251  			// states to us. In addition, apart from FOREGROUND/BACKGROUND,
   252  			// there are other possible states too, and potentially more in the
   253  			// future. So, we just restart the server under FOREGROUND instead
   254  			// of trying to listen on all state updates.
   255  			if state != keybase1.MobileAppState_FOREGROUND {
   256  				continue
   257  			}
   258  			if err := s.restart(); err != nil {
   259  				s.logger.Error("(Re)starting server failed: %v", err)
   260  			}
   261  		}
   262  	}
   263  }
   264  
   265  // New creates and starts a new server.
   266  func New(appStateUpdater env.AppStateUpdater, config libkbfs.Config) (
   267  	s *Server, err error) {
   268  	logger := config.MakeLogger("HTTP")
   269  	s = &Server{
   270  		appStateUpdater: appStateUpdater,
   271  		config:          config,
   272  		logger:          logger,
   273  		vlog:            config.MakeVLogger(logger),
   274  	}
   275  	if s.fs, err = lru.New(fsCacheSize); err != nil {
   276  		return nil, err
   277  	}
   278  	if err = s.restart(); err != nil {
   279  		return nil, err
   280  	}
   281  	ctx, cancel := context.WithCancel(context.Background())
   282  	go s.monitorAppState(ctx)
   283  	s.cancel = cancel
   284  	libmime.Patch(additionalMimeTypes)
   285  	return s, nil
   286  }
   287  
   288  // Address returns the address that the server is listening on.
   289  func (s *Server) Address() (string, error) {
   290  	s.serverLock.RLock()
   291  	defer s.serverLock.RUnlock()
   292  	return s.server.Addr()
   293  }
   294  
   295  // Shutdown shuts down the server.
   296  func (s *Server) Shutdown() {
   297  	s.serverLock.Lock()
   298  	defer s.serverLock.Unlock()
   299  	s.server.Stop()
   300  	s.cancel()
   301  }