github.com/crowdsecurity/crowdsec@v1.6.1/pkg/cwhub/dataset.go (about)

     1  package cwhub
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"io"
     7  	"io/fs"
     8  	"net/http"
     9  	"os"
    10  	"path/filepath"
    11  	"runtime"
    12  	"time"
    13  
    14  	"github.com/sirupsen/logrus"
    15  	"gopkg.in/yaml.v3"
    16  
    17  	"github.com/crowdsecurity/crowdsec/pkg/types"
    18  )
    19  
    20  // The DataSet is a list of data sources required by an item (built from the data: section in the yaml).
    21  type DataSet struct {
    22  	Data []types.DataSource `yaml:"data,omitempty"`
    23  }
    24  
    25  // downloadFile downloads a file and writes it to disk, with no hash verification.
    26  func downloadFile(url string, destPath string) error {
    27  	resp, err := hubClient.Get(url)
    28  	if err != nil {
    29  		return fmt.Errorf("while downloading %s: %w", url, err)
    30  	}
    31  	defer resp.Body.Close()
    32  
    33  	if resp.StatusCode != http.StatusOK {
    34  		return fmt.Errorf("bad http code %d for %s", resp.StatusCode, url)
    35  	}
    36  
    37  	// Download to a temporary location to avoid corrupting files
    38  	// that are currently in use or memory mapped.
    39  
    40  	tmpFile, err := os.CreateTemp(filepath.Dir(destPath), filepath.Base(destPath)+".*.tmp")
    41  	if err != nil {
    42  		return err
    43  	}
    44  
    45  	tmpFileName := tmpFile.Name()
    46  	defer func() {
    47  		tmpFile.Close()
    48  		os.Remove(tmpFileName)
    49  	}()
    50  
    51  	// avoid reading the whole file in memory
    52  	_, err = io.Copy(tmpFile, resp.Body)
    53  	if err != nil {
    54  		return err
    55  	}
    56  
    57  	if err = tmpFile.Sync(); err != nil {
    58  		return err
    59  	}
    60  
    61  	if err = tmpFile.Close(); err != nil {
    62  		return err
    63  	}
    64  
    65  	// a check on stdout is used while scripting to know if the hub has been upgraded
    66  	// and a configuration reload is required
    67  	// TODO: use a better way to communicate this
    68  	fmt.Printf("updated %s\n", filepath.Base(destPath))
    69  
    70  	if runtime.GOOS == "windows" {
    71  		// On Windows, rename will fail if the destination file already exists
    72  		// so we remove it first.
    73  		err = os.Remove(destPath)
    74  		switch {
    75  		case errors.Is(err, fs.ErrNotExist):
    76  			break
    77  		case err != nil:
    78  			return err
    79  		}
    80  	}
    81  
    82  	if err = os.Rename(tmpFileName, destPath); err != nil {
    83  		return err
    84  	}
    85  
    86  	return nil
    87  }
    88  
    89  // needsUpdate checks if a data file has to be downloaded (or updated).
    90  // if the local file doesn't exist, update.
    91  // if the remote is newer than the local file, update.
    92  // if the remote has no modification date, but local file has been modified > a week ago, update.
    93  func needsUpdate(destPath string, url string, logger *logrus.Logger) bool {
    94  	fileInfo, err := os.Stat(destPath)
    95  
    96  	switch {
    97  	case os.IsNotExist(err):
    98  		return true
    99  	case err != nil:
   100  		logger.Errorf("while getting %s: %s", destPath, err)
   101  		return true
   102  	}
   103  
   104  	resp, err := hubClient.Head(url)
   105  	if err != nil {
   106  		logger.Errorf("while getting %s: %s", url, err)
   107  		// Head failed, Get would likely fail too -> no update
   108  		return false
   109  	}
   110  	defer resp.Body.Close()
   111  
   112  	if resp.StatusCode != http.StatusOK {
   113  		logger.Errorf("bad http code %d for %s", resp.StatusCode, url)
   114  		return false
   115  	}
   116  
   117  	// update if local file is older than this
   118  	shelfLife := 7 * 24 * time.Hour
   119  
   120  	lastModify := fileInfo.ModTime()
   121  
   122  	localIsOld := lastModify.Add(shelfLife).Before(time.Now())
   123  
   124  	remoteLastModified := resp.Header.Get("Last-Modified")
   125  	if remoteLastModified == "" {
   126  		if localIsOld {
   127  			logger.Infof("no last modified date for %s, but local file is older than %s", url, shelfLife)
   128  		}
   129  
   130  		return localIsOld
   131  	}
   132  
   133  	lastAvailable, err := time.Parse(time.RFC1123, remoteLastModified)
   134  	if err != nil {
   135  		logger.Warningf("while parsing last modified date for %s: %s", url, err)
   136  		return localIsOld
   137  	}
   138  
   139  	if lastModify.Before(lastAvailable) {
   140  		logger.Infof("new version available, updating %s", destPath)
   141  		return true
   142  	}
   143  
   144  	return false
   145  }
   146  
   147  // downloadDataSet downloads all the data files for an item.
   148  func downloadDataSet(dataFolder string, force bool, reader io.Reader, logger *logrus.Logger) error {
   149  	dec := yaml.NewDecoder(reader)
   150  
   151  	for {
   152  		data := &DataSet{}
   153  
   154  		if err := dec.Decode(data); err != nil {
   155  			if errors.Is(err, io.EOF) {
   156  				break
   157  			}
   158  
   159  			return fmt.Errorf("while reading file: %w", err)
   160  		}
   161  
   162  		for _, dataS := range data.Data {
   163  			destPath, err := safePath(dataFolder, dataS.DestPath)
   164  			if err != nil {
   165  				return err
   166  			}
   167  
   168  			if force || needsUpdate(destPath, dataS.SourceURL, logger) {
   169  				logger.Debugf("downloading %s in %s", dataS.SourceURL, destPath)
   170  
   171  				if err := downloadFile(dataS.SourceURL, destPath); err != nil {
   172  					return fmt.Errorf("while getting data: %w", err)
   173  				}
   174  			}
   175  		}
   176  	}
   177  
   178  	return nil
   179  }