
     1  /*
     3  Copyright (c) 2024 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     6  */
     8  package web
    10  import (
    11  	"bytes"
    12  	"io"
    13  	"net/http"
    14  	"os"
    15  	"regexp"
    16  	"sync"
    18  	""
    19  	""
    20  )
    22  // NewStaticFileServer returns a new static file cache.
    23  func NewStaticFileServer(options ...StaticFileserverOption) *StaticFileServer {
    24  	var sfs StaticFileServer
    25  	for _, opt := range options {
    26  		opt(&sfs)
    27  	}
    28  	return &sfs
    29  }
// StaticFileserverOption is a functional option that mutates a *StaticFileServer.
// NewStaticFileServer applies options in the order they are passed.
type StaticFileserverOption func(*StaticFileServer)
    34  // OptStaticFileServerSearchPaths sets the static fileserver search paths.
    35  func OptStaticFileServerSearchPaths(searchPaths ...http.FileSystem) StaticFileserverOption {
    36  	return func(sfs *StaticFileServer) {
    37  		sfs.SearchPaths = searchPaths
    38  	}
    39  }
    41  // OptStaticFileServerHeaders sets the static fileserver default headers..
    42  func OptStaticFileServerHeaders(headers http.Header) StaticFileserverOption {
    43  	return func(sfs *StaticFileServer) {
    44  		sfs.Headers = headers
    45  	}
    46  }
    48  // OptStaticFileServerCacheDisabled sets the static fileserver should read from disk for each request.
    49  func OptStaticFileServerCacheDisabled(cacheDisabled bool) StaticFileserverOption {
    50  	return func(sfs *StaticFileServer) {
    51  		sfs.CacheDisabled = cacheDisabled
    52  	}
    53  }
// StaticFileServer serves static files resolved from a list of search paths.
// It can operate in cached mode, or with `CacheDisabled` set to `true`
// it will read from disk for each request.
// In cached mode, it computes an ETag for each file it caches.
type StaticFileServer struct {
	// RWMutex guards Cache: ResolveCachedFile takes the read lock for lookups
	// and the write lock to populate entries.
	sync.RWMutex

	// SearchPaths are the file systems tried, in order, by ResolveFile.
	SearchPaths   []http.FileSystem
	// RewriteRules are applied, in order, to the requested path before the
	// search paths are consulted.
	RewriteRules  []RewriteRule
	// Headers are copied onto each response served by Action.
	Headers       http.Header
	// CacheDisabled makes Action resolve files on every request instead of
	// serving them from Cache.
	CacheDisabled bool
	// Cache maps a requested (pre-rewrite) path to its cached file. A nil
	// entry records that no file was found for that path.
	Cache         map[string]*CachedStaticFile
}
    69  // AddHeader adds a header to the static cache results.
    70  func (sc *StaticFileServer) AddHeader(key, value string) {
    71  	if sc.Headers == nil {
    72  		sc.Headers = http.Header{}
    73  	}
    74  	sc.Headers[key] = append(sc.Headers[key], value)
    75  }
    77  // AddRewriteRule adds a static re-write rule.
    78  // This is meant to modify the path of a file from what is requested by the browser
    79  // to how a file may actually be accessed on disk.
    80  // Typically re-write rules are used to enforce caching semantics.
    81  func (sc *StaticFileServer) AddRewriteRule(match string, action RewriteAction) error {
    82  	expr, err := regexp.Compile(match)
    83  	if err != nil {
    84  		return err
    85  	}
    86  	sc.RewriteRules = append(sc.RewriteRules, RewriteRule{
    87  		MatchExpression: match,
    88  		expr:            expr,
    89  		Action:          action,
    90  	})
    91  	return nil
    92  }
    94  // Action is the entrypoint for the static server.
    95  // It  adds default headers if specified, and then serves the file from disk
    96  // or from a pull-through cache if enabled.
    97  func (sc *StaticFileServer) Action(r *Ctx) Result {
    98  	filePath, err := r.RouteParam("filepath")
    99  	if err != nil {
   100  		if r.DefaultProvider != nil {
   101  			return r.DefaultProvider.BadRequest(err)
   102  		}
   103  		http.Error(r.Response, err.Error(), http.StatusBadRequest)
   104  		return nil
   105  	}
   107  	for key, values := range sc.Headers {
   108  		for _, value := range values {
   109  			r.Response.Header().Set(key, value)
   110  		}
   111  	}
   113  	if sc.CacheDisabled {
   114  		return sc.ServeFile(r, filePath)
   115  	}
   116  	return sc.ServeCachedFile(r, filePath)
   117  }
   119  // ServeFile writes the file to the response by reading from disk
   120  // for each request (i.e. skipping the cache)
   121  func (sc *StaticFileServer) ServeFile(r *Ctx, filePath string) Result {
   122  	f, finalPath, err := sc.ResolveFile(filePath)
   123  	if err != nil {
   124  		return sc.fileError(r, err)
   125  	}
   126  	defer f.Close()
   128  	finfo, err := f.Stat()
   129  	if err != nil {
   130  		return sc.fileError(r, err)
   131  	}
   132  	if finfo.IsDir() {
   133  		return r.DefaultProvider.NotFound()
   134  	}
   136  	r.WithContext(logger.WithLabel(r.Context(), "web.static_file", finalPath))
   137  	http.ServeContent(r.Response, r.Request, filePath, finfo.ModTime(), f)
   138  	return nil
   139  }
   141  // ServeCachedFile writes the file to the response, potentially
   142  // serving a cached instance of the file.
   143  func (sc *StaticFileServer) ServeCachedFile(r *Ctx, filepath string) Result {
   144  	file, err := sc.ResolveCachedFile(filepath)
   145  	if err != nil {
   146  		return sc.fileError(r, err)
   147  	}
   148  	if file == nil {
   149  		return r.DefaultProvider.NotFound()
   150  	}
   151  	_ = file.Render(r)
   152  	return nil
   153  }
   155  // ResolveFile resolves a file from rewrite rules and search paths.
   156  // First the file path is modified according to the rewrite rules.
   157  // Then each search path is checked for the resolved file path.
   158  func (sc *StaticFileServer) ResolveFile(filePath string) (f http.File, finalPath string, err error) {
   159  	for _, rule := range sc.RewriteRules {
   160  		if matched, newFilePath := rule.Apply(filePath); matched {
   161  			filePath = newFilePath
   162  		}
   163  	}
   164  	for _, searchPath := range sc.SearchPaths {
   165  		f, err = searchPath.Open(filePath)
   166  		if typed, ok := f.(*os.File); ok && typed != nil {
   167  			finalPath = typed.Name()
   168  		}
   169  		if err != nil {
   170  			if os.IsNotExist(err) {
   171  				continue
   172  			}
   173  			return
   174  		}
   175  		if f != nil {
   176  			return
   177  		}
   178  	}
   179  	return
   180  }
   182  // ResolveCachedFile returns a cached file at a given path.
   183  // It returns the cached instance of a file if it exists, and adds it to the cache if there is a miss.
   184  func (sc *StaticFileServer) ResolveCachedFile(filepath string) (*CachedStaticFile, error) {
   185  	// start in read shared mode
   186  	sc.RLock()
   187  	if sc.Cache != nil {
   188  		if file, ok := sc.Cache[filepath]; ok {
   189  			sc.RUnlock()
   190  			return file, nil
   191  		}
   192  	}
   193  	sc.RUnlock()
   195  	// transition to exclusive write mode
   196  	sc.Lock()
   197  	defer sc.Unlock()
   199  	if sc.Cache == nil {
   200  		sc.Cache = make(map[string]*CachedStaticFile)
   201  	}
   202  	// double check ftw
   203  	if file, ok := sc.Cache[filepath]; ok {
   204  		return file, nil
   205  	}
   207  	diskFile, _, err := sc.ResolveFile(filepath)
   208  	if err != nil {
   209  		return nil, err
   210  	}
   212  	if diskFile == nil {
   213  		sc.Cache[filepath] = nil
   214  		return nil, nil
   215  	}
   217  	finfo, err := diskFile.Stat()
   218  	if err != nil {
   219  		if os.IsNotExist(err) {
   220  			return nil, nil
   221  		}
   222  		return nil, err
   223  	}
   224  	if finfo.IsDir() {
   225  		return nil, nil
   226  	}
   228  	contents, err := io.ReadAll(diskFile)
   229  	if err != nil {
   230  		return nil, err
   231  	}
   233  	file := &CachedStaticFile{
   234  		Path:     filepath,
   235  		Contents: bytes.NewReader(contents),
   236  		ModTime:  finfo.ModTime(),
   237  		ETag:     webutil.ETag(contents),
   238  		Size:     len(contents),
   239  	}
   241  	sc.Cache[filepath] = file
   242  	return file, nil
   243  }
   245  func (sc *StaticFileServer) fileError(r *Ctx, err error) Result {
   246  	if os.IsNotExist(err) {
   247  		if r.DefaultProvider != nil {
   248  			return r.DefaultProvider.NotFound()
   249  		}
   250  		http.NotFound(r.Response, r.Request)
   251  		return nil
   252  	}
   253  	if r.DefaultProvider != nil {
   254  		return r.DefaultProvider.InternalError(err)
   255  	}
   256  	http.Error(r.Response, err.Error(), http.StatusInternalServerError)
   257  	return nil
   258  }