github.com/nektos/act@v0.2.63/pkg/artifactcache/handler.go (about)

     1  package artifactcache
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"net"
     9  	"net/http"
    10  	"os"
    11  	"path/filepath"
    12  	"regexp"
    13  	"strconv"
    14  	"strings"
    15  	"sync/atomic"
    16  	"time"
    17  
    18  	"github.com/julienschmidt/httprouter"
    19  	"github.com/sirupsen/logrus"
    20  	"github.com/timshannon/bolthold"
    21  	"go.etcd.io/bbolt"
    22  
    23  	"github.com/nektos/act/pkg/common"
    24  )
    25  
    26  const (
    27  	urlBase = "/_apis/artifactcache"
    28  )
    29  
    30  type Handler struct {
    31  	dir      string
    32  	storage  *Storage
    33  	router   *httprouter.Router
    34  	listener net.Listener
    35  	server   *http.Server
    36  	logger   logrus.FieldLogger
    37  
    38  	gcing atomic.Bool
    39  	gcAt  time.Time
    40  
    41  	outboundIP string
    42  }
    43  
    44  func StartHandler(dir, outboundIP string, port uint16, logger logrus.FieldLogger) (*Handler, error) {
    45  	h := &Handler{}
    46  
    47  	if logger == nil {
    48  		discard := logrus.New()
    49  		discard.Out = io.Discard
    50  		logger = discard
    51  	}
    52  	logger = logger.WithField("module", "artifactcache")
    53  	h.logger = logger
    54  
    55  	if dir == "" {
    56  		home, err := os.UserHomeDir()
    57  		if err != nil {
    58  			return nil, err
    59  		}
    60  		dir = filepath.Join(home, ".cache", "actcache")
    61  	}
    62  	if err := os.MkdirAll(dir, 0o755); err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	h.dir = dir
    67  
    68  	storage, err := NewStorage(filepath.Join(dir, "cache"))
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  	h.storage = storage
    73  
    74  	if outboundIP != "" {
    75  		h.outboundIP = outboundIP
    76  	} else if ip := common.GetOutboundIP(); ip == nil {
    77  		return nil, fmt.Errorf("unable to determine outbound IP address")
    78  	} else {
    79  		h.outboundIP = ip.String()
    80  	}
    81  
    82  	router := httprouter.New()
    83  	router.GET(urlBase+"/cache", h.middleware(h.find))
    84  	router.POST(urlBase+"/caches", h.middleware(h.reserve))
    85  	router.PATCH(urlBase+"/caches/:id", h.middleware(h.upload))
    86  	router.POST(urlBase+"/caches/:id", h.middleware(h.commit))
    87  	router.GET(urlBase+"/artifacts/:id", h.middleware(h.get))
    88  	router.POST(urlBase+"/clean", h.middleware(h.clean))
    89  
    90  	h.router = router
    91  
    92  	h.gcCache()
    93  
    94  	listener, err := net.Listen("tcp", fmt.Sprintf(":%d", port)) // listen on all interfaces
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  	server := &http.Server{
    99  		ReadHeaderTimeout: 2 * time.Second,
   100  		Handler:           router,
   101  	}
   102  	go func() {
   103  		if err := server.Serve(listener); err != nil && errors.Is(err, net.ErrClosed) {
   104  			logger.Errorf("http serve: %v", err)
   105  		}
   106  	}()
   107  	h.listener = listener
   108  	h.server = server
   109  
   110  	return h, nil
   111  }
   112  
   113  func (h *Handler) ExternalURL() string {
   114  	// TODO: make the external url configurable if necessary
   115  	return fmt.Sprintf("http://%s:%d",
   116  		h.outboundIP,
   117  		h.listener.Addr().(*net.TCPAddr).Port)
   118  }
   119  
   120  func (h *Handler) Close() error {
   121  	if h == nil {
   122  		return nil
   123  	}
   124  	var retErr error
   125  	if h.server != nil {
   126  		err := h.server.Close()
   127  		if err != nil {
   128  			retErr = err
   129  		}
   130  		h.server = nil
   131  	}
   132  	if h.listener != nil {
   133  		err := h.listener.Close()
   134  		if errors.Is(err, net.ErrClosed) {
   135  			err = nil
   136  		}
   137  		if err != nil {
   138  			retErr = err
   139  		}
   140  		h.listener = nil
   141  	}
   142  	return retErr
   143  }
   144  
   145  func (h *Handler) openDB() (*bolthold.Store, error) {
   146  	return bolthold.Open(filepath.Join(h.dir, "bolt.db"), 0o644, &bolthold.Options{
   147  		Encoder: json.Marshal,
   148  		Decoder: json.Unmarshal,
   149  		Options: &bbolt.Options{
   150  			Timeout:      5 * time.Second,
   151  			NoGrowSync:   bbolt.DefaultOptions.NoGrowSync,
   152  			FreelistType: bbolt.DefaultOptions.FreelistType,
   153  		},
   154  	})
   155  }
   156  
   157  // GET /_apis/artifactcache/cache
   158  func (h *Handler) find(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
   159  	keys := strings.Split(r.URL.Query().Get("keys"), ",")
   160  	// cache keys are case insensitive
   161  	for i, key := range keys {
   162  		keys[i] = strings.ToLower(key)
   163  	}
   164  	version := r.URL.Query().Get("version")
   165  
   166  	db, err := h.openDB()
   167  	if err != nil {
   168  		h.responseJSON(w, r, 500, err)
   169  		return
   170  	}
   171  	defer db.Close()
   172  
   173  	cache, err := findCache(db, keys, version)
   174  	if err != nil {
   175  		h.responseJSON(w, r, 500, err)
   176  		return
   177  	}
   178  	if cache == nil {
   179  		h.responseJSON(w, r, 204)
   180  		return
   181  	}
   182  
   183  	if ok, err := h.storage.Exist(cache.ID); err != nil {
   184  		h.responseJSON(w, r, 500, err)
   185  		return
   186  	} else if !ok {
   187  		_ = db.Delete(cache.ID, cache)
   188  		h.responseJSON(w, r, 204)
   189  		return
   190  	}
   191  	h.responseJSON(w, r, 200, map[string]any{
   192  		"result":          "hit",
   193  		"archiveLocation": fmt.Sprintf("%s%s/artifacts/%d", h.ExternalURL(), urlBase, cache.ID),
   194  		"cacheKey":        cache.Key,
   195  	})
   196  }
   197  
   198  // POST /_apis/artifactcache/caches
   199  func (h *Handler) reserve(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
   200  	api := &Request{}
   201  	if err := json.NewDecoder(r.Body).Decode(api); err != nil {
   202  		h.responseJSON(w, r, 400, err)
   203  		return
   204  	}
   205  	// cache keys are case insensitive
   206  	api.Key = strings.ToLower(api.Key)
   207  
   208  	cache := api.ToCache()
   209  	db, err := h.openDB()
   210  	if err != nil {
   211  		h.responseJSON(w, r, 500, err)
   212  		return
   213  	}
   214  	defer db.Close()
   215  
   216  	now := time.Now().Unix()
   217  	cache.CreatedAt = now
   218  	cache.UsedAt = now
   219  	if err := insertCache(db, cache); err != nil {
   220  		h.responseJSON(w, r, 500, err)
   221  		return
   222  	}
   223  	h.responseJSON(w, r, 200, map[string]any{
   224  		"cacheId": cache.ID,
   225  	})
   226  }
   227  
   228  // PATCH /_apis/artifactcache/caches/:id
   229  func (h *Handler) upload(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
   230  	id, err := strconv.ParseInt(params.ByName("id"), 10, 64)
   231  	if err != nil {
   232  		h.responseJSON(w, r, 400, err)
   233  		return
   234  	}
   235  
   236  	cache := &Cache{}
   237  	db, err := h.openDB()
   238  	if err != nil {
   239  		h.responseJSON(w, r, 500, err)
   240  		return
   241  	}
   242  	defer db.Close()
   243  	if err := db.Get(id, cache); err != nil {
   244  		if errors.Is(err, bolthold.ErrNotFound) {
   245  			h.responseJSON(w, r, 400, fmt.Errorf("cache %d: not reserved", id))
   246  			return
   247  		}
   248  		h.responseJSON(w, r, 500, err)
   249  		return
   250  	}
   251  
   252  	if cache.Complete {
   253  		h.responseJSON(w, r, 400, fmt.Errorf("cache %v %q: already complete", cache.ID, cache.Key))
   254  		return
   255  	}
   256  	db.Close()
   257  	start, _, err := parseContentRange(r.Header.Get("Content-Range"))
   258  	if err != nil {
   259  		h.responseJSON(w, r, 400, err)
   260  		return
   261  	}
   262  	if err := h.storage.Write(cache.ID, start, r.Body); err != nil {
   263  		h.responseJSON(w, r, 500, err)
   264  	}
   265  	h.useCache(id)
   266  	h.responseJSON(w, r, 200)
   267  }
   268  
   269  // POST /_apis/artifactcache/caches/:id
   270  func (h *Handler) commit(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
   271  	id, err := strconv.ParseInt(params.ByName("id"), 10, 64)
   272  	if err != nil {
   273  		h.responseJSON(w, r, 400, err)
   274  		return
   275  	}
   276  
   277  	cache := &Cache{}
   278  	db, err := h.openDB()
   279  	if err != nil {
   280  		h.responseJSON(w, r, 500, err)
   281  		return
   282  	}
   283  	defer db.Close()
   284  	if err := db.Get(id, cache); err != nil {
   285  		if errors.Is(err, bolthold.ErrNotFound) {
   286  			h.responseJSON(w, r, 400, fmt.Errorf("cache %d: not reserved", id))
   287  			return
   288  		}
   289  		h.responseJSON(w, r, 500, err)
   290  		return
   291  	}
   292  
   293  	if cache.Complete {
   294  		h.responseJSON(w, r, 400, fmt.Errorf("cache %v %q: already complete", cache.ID, cache.Key))
   295  		return
   296  	}
   297  
   298  	db.Close()
   299  
   300  	size, err := h.storage.Commit(cache.ID, cache.Size)
   301  	if err != nil {
   302  		h.responseJSON(w, r, 500, err)
   303  		return
   304  	}
   305  	// write real size back to cache, it may be different from the current value when the request doesn't specify it.
   306  	cache.Size = size
   307  
   308  	db, err = h.openDB()
   309  	if err != nil {
   310  		h.responseJSON(w, r, 500, err)
   311  		return
   312  	}
   313  	defer db.Close()
   314  
   315  	cache.Complete = true
   316  	if err := db.Update(cache.ID, cache); err != nil {
   317  		h.responseJSON(w, r, 500, err)
   318  		return
   319  	}
   320  
   321  	h.responseJSON(w, r, 200)
   322  }
   323  
   324  // GET /_apis/artifactcache/artifacts/:id
   325  func (h *Handler) get(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
   326  	id, err := strconv.ParseInt(params.ByName("id"), 10, 64)
   327  	if err != nil {
   328  		h.responseJSON(w, r, 400, err)
   329  		return
   330  	}
   331  	h.useCache(id)
   332  	h.storage.Serve(w, r, uint64(id))
   333  }
   334  
   335  // POST /_apis/artifactcache/clean
   336  func (h *Handler) clean(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
   337  	// TODO: don't support force deleting cache entries
   338  	// see: https://docs.github.com/en/actions/using-workflows/caching-dependencies-to-speed-up-workflows#force-deleting-cache-entries
   339  
   340  	h.responseJSON(w, r, 200)
   341  }
   342  
   343  func (h *Handler) middleware(handler httprouter.Handle) httprouter.Handle {
   344  	return func(w http.ResponseWriter, r *http.Request, params httprouter.Params) {
   345  		h.logger.Debugf("%s %s", r.Method, r.RequestURI)
   346  		handler(w, r, params)
   347  		go h.gcCache()
   348  	}
   349  }
   350  
   351  // if not found, return (nil, nil) instead of an error.
   352  func findCache(db *bolthold.Store, keys []string, version string) (*Cache, error) {
   353  	cache := &Cache{}
   354  	for _, prefix := range keys {
   355  		// if a key in the list matches exactly, don't return partial matches
   356  		if err := db.FindOne(cache,
   357  			bolthold.Where("Key").Eq(prefix).
   358  				And("Version").Eq(version).
   359  				And("Complete").Eq(true).
   360  				SortBy("CreatedAt").Reverse()); err == nil || !errors.Is(err, bolthold.ErrNotFound) {
   361  			if err != nil {
   362  				return nil, fmt.Errorf("find cache: %w", err)
   363  			}
   364  			return cache, nil
   365  		}
   366  		prefixPattern := fmt.Sprintf("^%s", regexp.QuoteMeta(prefix))
   367  		re, err := regexp.Compile(prefixPattern)
   368  		if err != nil {
   369  			continue
   370  		}
   371  		if err := db.FindOne(cache,
   372  			bolthold.Where("Key").RegExp(re).
   373  				And("Version").Eq(version).
   374  				And("Complete").Eq(true).
   375  				SortBy("CreatedAt").Reverse()); err != nil {
   376  			if errors.Is(err, bolthold.ErrNotFound) {
   377  				continue
   378  			}
   379  			return nil, fmt.Errorf("find cache: %w", err)
   380  		}
   381  		return cache, nil
   382  	}
   383  	return nil, nil
   384  }
   385  
   386  func insertCache(db *bolthold.Store, cache *Cache) error {
   387  	if err := db.Insert(bolthold.NextSequence(), cache); err != nil {
   388  		return fmt.Errorf("insert cache: %w", err)
   389  	}
   390  	// write back id to db
   391  	if err := db.Update(cache.ID, cache); err != nil {
   392  		return fmt.Errorf("write back id to db: %w", err)
   393  	}
   394  	return nil
   395  }
   396  
   397  func (h *Handler) useCache(id int64) {
   398  	db, err := h.openDB()
   399  	if err != nil {
   400  		return
   401  	}
   402  	defer db.Close()
   403  	cache := &Cache{}
   404  	if err := db.Get(id, cache); err != nil {
   405  		return
   406  	}
   407  	cache.UsedAt = time.Now().Unix()
   408  	_ = db.Update(cache.ID, cache)
   409  }
   410  
   411  const (
   412  	keepUsed   = 30 * 24 * time.Hour
   413  	keepUnused = 7 * 24 * time.Hour
   414  	keepTemp   = 5 * time.Minute
   415  	keepOld    = 5 * time.Minute
   416  )
   417  
   418  func (h *Handler) gcCache() {
   419  	if h.gcing.Load() {
   420  		return
   421  	}
   422  	if !h.gcing.CompareAndSwap(false, true) {
   423  		return
   424  	}
   425  	defer h.gcing.Store(false)
   426  
   427  	if time.Since(h.gcAt) < time.Hour {
   428  		h.logger.Debugf("skip gc: %v", h.gcAt.String())
   429  		return
   430  	}
   431  	h.gcAt = time.Now()
   432  	h.logger.Debugf("gc: %v", h.gcAt.String())
   433  
   434  	db, err := h.openDB()
   435  	if err != nil {
   436  		return
   437  	}
   438  	defer db.Close()
   439  
   440  	// Remove the caches which are not completed for a while, they are most likely to be broken.
   441  	var caches []*Cache
   442  	if err := db.Find(&caches, bolthold.
   443  		Where("UsedAt").Lt(time.Now().Add(-keepTemp).Unix()).
   444  		And("Complete").Eq(false),
   445  	); err != nil {
   446  		h.logger.Warnf("find caches: %v", err)
   447  	} else {
   448  		for _, cache := range caches {
   449  			h.storage.Remove(cache.ID)
   450  			if err := db.Delete(cache.ID, cache); err != nil {
   451  				h.logger.Warnf("delete cache: %v", err)
   452  				continue
   453  			}
   454  			h.logger.Infof("deleted cache: %+v", cache)
   455  		}
   456  	}
   457  
   458  	// Remove the old caches which have not been used recently.
   459  	caches = caches[:0]
   460  	if err := db.Find(&caches, bolthold.
   461  		Where("UsedAt").Lt(time.Now().Add(-keepUnused).Unix()),
   462  	); err != nil {
   463  		h.logger.Warnf("find caches: %v", err)
   464  	} else {
   465  		for _, cache := range caches {
   466  			h.storage.Remove(cache.ID)
   467  			if err := db.Delete(cache.ID, cache); err != nil {
   468  				h.logger.Warnf("delete cache: %v", err)
   469  				continue
   470  			}
   471  			h.logger.Infof("deleted cache: %+v", cache)
   472  		}
   473  	}
   474  
   475  	// Remove the old caches which are too old.
   476  	caches = caches[:0]
   477  	if err := db.Find(&caches, bolthold.
   478  		Where("CreatedAt").Lt(time.Now().Add(-keepUsed).Unix()),
   479  	); err != nil {
   480  		h.logger.Warnf("find caches: %v", err)
   481  	} else {
   482  		for _, cache := range caches {
   483  			h.storage.Remove(cache.ID)
   484  			if err := db.Delete(cache.ID, cache); err != nil {
   485  				h.logger.Warnf("delete cache: %v", err)
   486  				continue
   487  			}
   488  			h.logger.Infof("deleted cache: %+v", cache)
   489  		}
   490  	}
   491  
   492  	// Remove the old caches with the same key and version, keep the latest one.
   493  	// Also keep the olds which have been used recently for a while in case of the cache is still in use.
   494  	if results, err := db.FindAggregate(
   495  		&Cache{},
   496  		bolthold.Where("Complete").Eq(true),
   497  		"Key", "Version",
   498  	); err != nil {
   499  		h.logger.Warnf("find aggregate caches: %v", err)
   500  	} else {
   501  		for _, result := range results {
   502  			if result.Count() <= 1 {
   503  				continue
   504  			}
   505  			result.Sort("CreatedAt")
   506  			caches = caches[:0]
   507  			result.Reduction(&caches)
   508  			for _, cache := range caches[:len(caches)-1] {
   509  				if time.Since(time.Unix(cache.UsedAt, 0)) < keepOld {
   510  					// Keep it since it has been used recently, even if it's old.
   511  					// Or it could break downloading in process.
   512  					continue
   513  				}
   514  				h.storage.Remove(cache.ID)
   515  				if err := db.Delete(cache.ID, cache); err != nil {
   516  					h.logger.Warnf("delete cache: %v", err)
   517  					continue
   518  				}
   519  				h.logger.Infof("deleted cache: %+v", cache)
   520  			}
   521  		}
   522  	}
   523  }
   524  
   525  func (h *Handler) responseJSON(w http.ResponseWriter, r *http.Request, code int, v ...any) {
   526  	w.Header().Set("Content-Type", "application/json; charset=utf-8")
   527  	var data []byte
   528  	if len(v) == 0 || v[0] == nil {
   529  		data, _ = json.Marshal(struct{}{})
   530  	} else if err, ok := v[0].(error); ok {
   531  		h.logger.Errorf("%v %v: %v", r.Method, r.RequestURI, err)
   532  		data, _ = json.Marshal(map[string]any{
   533  			"error": err.Error(),
   534  		})
   535  	} else {
   536  		data, _ = json.Marshal(v[0])
   537  	}
   538  	w.WriteHeader(code)
   539  	_, _ = w.Write(data)
   540  }
   541  
   542  func parseContentRange(s string) (int64, int64, error) {
   543  	// support the format like "bytes 11-22/*" only
   544  	s, _, _ = strings.Cut(strings.TrimPrefix(s, "bytes "), "/")
   545  	s1, s2, _ := strings.Cut(s, "-")
   546  
   547  	start, err := strconv.ParseInt(s1, 10, 64)
   548  	if err != nil {
   549  		return 0, 0, fmt.Errorf("parse %q: %w", s, err)
   550  	}
   551  	stop, err := strconv.ParseInt(s2, 10, 64)
   552  	if err != nil {
   553  		return 0, 0, fmt.Errorf("parse %q: %w", s, err)
   554  	}
   555  	return start, stop, nil
   556  }