github.com/grailbio/base@v0.0.11/cloud/spotfeed/spotfeed.go

// Package spotfeed is used for querying spot data feeds provided by AWS.
// See https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/spot-data-feeds.html for a description of the
// spot data feed format.
//
// This package provides two interfaces for interacting with the AWS spot data feed format for files hosted
// on S3:
//
// 1. Fetch  - makes a single blocking call to fetch feed files for some historical period, then parses and
//             returns the results as a single slice.
// 2. Stream - creates a goroutine that asynchronously checks (once every 30 minutes by default) the specified S3
//             location for new spot data feed files (and sends parsed entries into a channel provided to
//             the user at invocation).
//
// This package also provides a LocalLoader which can perform a Fetch operation against feed files already
// downloaded to local disk. This is often useful for analyzing spot usage over long periods of time, since
// the download phase can take some time.
package spotfeed

import (
	"compress/gzip"
	"context"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"path"
	"strings"
	"sync"
	"time"

	"github.com/aws/aws-sdk-go/aws"
	"github.com/aws/aws-sdk-go/aws/request"
	"github.com/aws/aws-sdk-go/aws/session"
	"github.com/aws/aws-sdk-go/service/ec2"
	"github.com/aws/aws-sdk-go/service/s3"
	"github.com/aws/aws-sdk-go/service/s3/s3iface"
	"github.com/grailbio/base/errors"
	"github.com/grailbio/base/retry"
	"golang.org/x/sync/errgroup"
	"golang.org/x/time/rate"
)
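
// The following is a usage sketch, not part of the original package surface:
// it shows how a caller might wire NewS3Loader (defined below) to Fetch. The
// region, bucket name, root prefix, and account id are hypothetical
// placeholders.
func exampleS3Fetch() {
	sess := session.Must(session.NewSession(&aws.Config{Region: aws.String("us-west-2")}))
	start := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.AddDate(0, 1, 0) // fetch one month of feed data
	loader := NewS3Loader("my-spot-feed-bucket", "feed-prefix", s3.New(sess),
		log.New(os.Stderr, "spotfeed: ", log.LstdFlags),
		"123456789012", &start, &end, 0)
	entries, err := loader.Fetch(context.Background(), true /* tolerateErr */)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("fetched %d spot feed entries\n", len(entries))
}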

var (
	// retryPolicy is used to retry failed S3 API calls.
	retryPolicy = retry.Backoff(time.Second, 10*time.Second, 2)

	// limiter is used to rate limit S3 API calls.
	limiter = rate.NewLimiter(rate.Limit(16), 4)
)

type filterable interface {
	accountId() string
	timestamp() time.Time
	version() int64
}

type filters struct {
	// AccountId configures the Loader to only return Entry objects that belong to the specified
	// 12-digit AWS account number (ID). If empty, no AccountId filter is applied.
	AccountId string

	// StartTime configures the Loader to only return Entry objects with timestamps at or after
	// StartTime (inclusive). If nil, no StartTime filter is applied.
	StartTime *time.Time

	// EndTime configures the Loader to only return Entry objects with timestamps strictly before
	// EndTime (exclusive). If nil, no EndTime filter is applied.
	EndTime *time.Time

	// Version configures the Loader to only return Entry objects with version equal to Version.
	// If zero, no Version filter is applied, and if multiple feed versions declare the same
	// instance-hour, de-duping based on the maximum value seen for that hour will be applied.
	Version int64
}
    75  
    76  // filter returns true if the entry does not match loader criteria and should be filtered out.
    77  func (l *filters) filter(f filterable) bool {
    78  	if l.AccountId != "" && f.accountId() != l.AccountId {
    79  		return true
    80  	}
    81  	if l.StartTime != nil && f.timestamp().Before(*l.StartTime) { // inclusive
    82  		return true
    83  	}
    84  	if l.EndTime != nil && !f.timestamp().Before(*l.EndTime) { // exclusive
    85  		return true
    86  	}
    87  	if l.Version != 0 && f.version() != l.Version {
    88  		return true
    89  	}
    90  	return false
    91  }
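
// A minimal sketch (hypothetical timestamps) of the boundary semantics above:
// StartTime is inclusive and EndTime is exclusive, so an entry stamped exactly
// at StartTime is kept while one stamped exactly at EndTime is dropped.
func exampleFilterBounds() {
	start := time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)
	end := start.Add(time.Hour)
	f := filters{StartTime: &start, EndTime: &end}
	fmt.Println(f.filter(&Entry{Timestamp: start})) // false: entry is kept
	fmt.Println(f.filter(&Entry{Timestamp: end}))   // true: entry is filtered out
}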

// filterTruncatedStartTime performs the same checks as filter but truncates the start boundary down to the hour.
func (l *filters) filterTruncatedStartTime(f filterable) bool {
	if l.AccountId != "" && f.accountId() != l.AccountId {
		return true
	}
	if l.StartTime != nil {
		truncatedStart := l.StartTime.Truncate(time.Hour)
		if f.timestamp().Before(truncatedStart) { // inclusive
			return true
		}
	}
	if l.EndTime != nil && !f.timestamp().Before(*l.EndTime) { // exclusive
		return true
	}
	if l.Version != 0 && f.version() != l.Version {
		return true
	}
	return false
}

type localFile struct {
	*fileMeta
	path string
}

func (f *localFile) read() ([]*Entry, error) {
	fd, err := os.Open(f.path)
	if err != nil {
		return nil, errors.E(err, fmt.Sprintf("failed to open local spot feed data file %s", f.path))
	}
	defer func() { _ = fd.Close() }()

	if f.IsGzip {
		gz, err := gzip.NewReader(fd)
		if err != nil {
			return nil, fmt.Errorf("failed to read gzipped file %s", f.Name)
		}
		defer func() { _ = gz.Close() }()
		return ParseFeedFile(gz, f.AccountId)
	}

	return ParseFeedFile(fd, f.AccountId)
}

type s3File struct {
	*fileMeta
	bucket, key string
	client      s3iface.S3API
}

func (s *s3File) read(ctx context.Context) ([]*Entry, error) {
	// Pull feed file from S3 with rate limiting and retries.
	var output *s3.GetObjectOutput
	for retries := 0; ; {
		if err := limiter.Wait(ctx); err != nil {
			return nil, err
		}
		var getObjErr error
		if output, getObjErr = s.client.GetObjectWithContext(ctx, &s3.GetObjectInput{
			Bucket: aws.String(s.bucket),
			Key:    aws.String(s.key),
		}); getObjErr != nil {
			if !request.IsErrorThrottle(getObjErr) {
				return nil, getObjErr
			}
			if err := retry.Wait(ctx, retryPolicy, retries); err != nil {
				return nil, err
			}
			retries++
			continue
		}
		break
	}
	// If the file is gzipped, unpack before attempting to read.
	if s.IsGzip {
		gz, err := gzip.NewReader(output.Body)
		if err != nil {
			return nil, fmt.Errorf("failed to read gzipped file s3://%s/%s", s.bucket, s.key)
		}
		defer func() { _ = gz.Close() }()
		return ParseFeedFile(gz, s.AccountId)
	}

	return ParseFeedFile(output.Body, s.AccountId)
}

// Loader provides an API for pulling Spot Data Feed Entry objects from some repository.
// The tolerateErr parameter configures how the Loader responds to errors while parsing
// individual files or entries; if true, the Loader will continue to parse and yield Entry
// objects when an error is encountered during parsing.
type Loader interface {
	// Fetch performs a single blocking call to fetch a discrete set of Entry objects.
	Fetch(ctx context.Context, tolerateErr bool) ([]*Entry, error)

	// Stream asynchronously retrieves, parses and sends Entry objects on the returned channel.
	// To gracefully terminate the goroutine managing the Stream, the client cancels the given context.
	Stream(ctx context.Context, tolerateErr bool) (<-chan *Entry, error)
}

type s3Loader struct {
	Loader
	filters

	log     *log.Logger
	client  s3iface.S3API
	bucket  string
	rootURI string
}

// commonFilePrefix returns the most specific prefix common to all spot feed data files that
// match the loader criteria.
func (s *s3Loader) commonFilePrefix() string {
	if s.AccountId == "" {
		return ""
	}

	if s.StartTime == nil || s.EndTime == nil || s.StartTime.Year() != s.EndTime.Year() {
		return s.AccountId
	}

	if s.StartTime.Month() != s.EndTime.Month() {
		return fmt.Sprintf("%s.%d", s.AccountId, s.StartTime.Year())
	}

	if s.StartTime.Day() != s.EndTime.Day() {
		return fmt.Sprintf("%s.%d-%02d", s.AccountId, s.StartTime.Year(), s.StartTime.Month())
	}

	if s.StartTime.Hour() != s.EndTime.Hour() {
		return fmt.Sprintf("%s.%d-%02d-%02d", s.AccountId, s.StartTime.Year(), s.StartTime.Month(), s.StartTime.Day())
	}

	return fmt.Sprintf("%s.%d-%02d-%02d-%02d", s.AccountId, s.StartTime.Year(), s.StartTime.Month(), s.StartTime.Day(), s.StartTime.Hour())
}
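
// A short sketch (hypothetical account id and window) of the narrowing above:
// a window contained within a single UTC day shares a date prefix, while a
// window spanning years degrades to the bare account id.
func exampleCommonFilePrefix() {
	start := time.Date(2021, 1, 1, 9, 0, 0, 0, time.UTC)
	end := start.Add(2 * time.Hour)
	s := &s3Loader{filters: filters{AccountId: "123456789012", StartTime: &start, EndTime: &end}}
	fmt.Println(s.commonFilePrefix()) // "123456789012.2021-01-01"
}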

// timePrefix returns a prefix which matches the given time in UTC.
func (s *s3Loader) timePrefix(t time.Time) string {
	if s.AccountId == "" {
		panic("timePrefix cannot be computed without an account id")
	}

	t = t.UTC()
	return fmt.Sprintf("%s.%d-%02d-%02d-%02d", s.AccountId, t.Year(), t.Month(), t.Day(), t.Hour())
}

// path joins the loader rootURI with the given uri.
func (s *s3Loader) path(uri string) string {
	if s.rootURI == "" {
		return uri
	}
	return fmt.Sprintf("%s/%s", s.rootURI, uri)
}

// list queries the AWS S3 ListObjectsV2 API for feed files.
func (s *s3Loader) list(ctx context.Context, startAfter string, tolerateErr bool) ([]*s3File, error) {
	prefix := s.path(s.commonFilePrefix())

	s3Files := make([]*s3File, 0)
	var parseMetaErr error
	if err := s.client.ListObjectsV2PagesWithContext(ctx, &s3.ListObjectsV2Input{
		Bucket:     aws.String(s.bucket),
		Prefix:     aws.String(prefix),
		StartAfter: aws.String(startAfter),
	}, func(output *s3.ListObjectsV2Output, lastPage bool) bool {
		for _, object := range output.Contents {
			filename := aws.StringValue(object.Key)
			fileMeta, err := parseFeedFileName(filename)
			if err != nil {
				parseMetaErr = errors.E(err, fmt.Sprintf("failed to parse spot feed data file name %s", filename))
				if tolerateErr {
					s.log.Print(parseMetaErr)
					continue
				}
				return false
			}

			// Skip files that do not match the loader criteria. Truncate the startTime of the filter to ensure that
			// we do not skip files at hour HH:00 with a startTime of (e.g.) HH:30.
			if s.filterTruncatedStartTime(fileMeta) {
				s.log.Printf("%s does not pass fileMeta filter, skipping", filename)
				continue
			}
			s3Files = append(s3Files, &s3File{
				fileMeta,
				s.bucket,
				filename,
				s.client,
			})
		}
		return true
	}); err != nil {
		return nil, fmt.Errorf("list on path %s failed with error: %s", prefix, err)
	}
	if !tolerateErr && parseMetaErr != nil {
		return nil, parseMetaErr
	}
	return s3Files, nil
}

// fetchAfter builds a list of S3 feed file objects using the S3 ListObjectsV2 API. It then concurrently
// fetches and parses the feed files, observing rate and concurrency limits.
func (s *s3Loader) fetchAfter(ctx context.Context, startAfter string, tolerateErr bool) ([]*Entry, error) {
	s3Files, err := s.list(ctx, startAfter, tolerateErr)
	if err != nil {
		return nil, err
	}

	mu := &sync.Mutex{}
	spotDataEntries := make([]*Entry, 0)
	group, groupCtx := errgroup.WithContext(ctx)
	for _, file := range s3Files {
		file := file
		group.Go(func() error {
			entries, err := file.read(groupCtx)
			if err != nil {
				err = errors.E(err, fmt.Sprintf("failed to parse spot feed data file s3://%s/%s", file.bucket, file.key))
				if tolerateErr {
					s.log.Printf("encountered error %s, tolerating and skipping file s3://%s/%s", err, file.bucket, file.key)
					return nil
				}
				return err
			}
			mu.Lock()
			spotDataEntries = append(spotDataEntries, entries...)
			mu.Unlock()
			return nil
		})
	}
	if err := group.Wait(); err != nil {
		return nil, err
	}

	filteredEntries := make([]*Entry, 0)
	for _, e := range spotDataEntries {
		if !s.filter(e) {
			filteredEntries = append(filteredEntries, e)
		}
	}

	return filteredEntries, nil
}

// Fetch makes a single blocking call to fetch feed files for some historical period,
// then parses and returns the results as a single slice. The listing starts after the
// most specific common prefix implied by the configured filters, and entries outside
// the [StartTime, EndTime) window are filtered out.
func (s *s3Loader) Fetch(ctx context.Context, tolerateErr bool) ([]*Entry, error) {
	prefix := s.path(s.commonFilePrefix())
	return s.fetchAfter(ctx, prefix, tolerateErr)
}

// streamSleepDuration specifies how long to wait between successive S3 list calls.
var streamSleepDuration = 30 * time.Minute

// Stream creates a goroutine that asynchronously checks (once every 30 minutes by default) the specified S3
// location for new spot data feed files (and sends parsed entries into a channel provided to the user at invocation).
// s3Loader must be configured with an account id to support the Stream interface. To stream events for multiple account ids
// which share a feed bucket, create multiple s3Loader objects.
// TODO: Allow caller to pass channel, allowing a single reader to manage multiple s3Loader.Stream calls.
func (s *s3Loader) Stream(ctx context.Context, tolerateErr bool) (<-chan *Entry, error) {
	if s.AccountId == "" {
		return nil, fmt.Errorf("s3Loader must be configured with an account id to provide asynchronous event streaming")
	}

	entryChan := make(chan *Entry)
	go func() {
		defer close(entryChan)
		startAfter := s.timePrefix(time.Now())
		for {
			if ctx.Err() != nil {
				return
			}

			entries, err := s.fetchAfter(ctx, startAfter, tolerateErr)
			if err != nil {
				s.log.Printf("stream fetch failed, closing channel: %s", err)
				return
			}

			// Send entries to the consumer, bailing out if the context is canceled.
			for _, entry := range entries {
				select {
				case entryChan <- entry:
				case <-ctx.Done():
					return
				}
			}

			if len(entries) != 0 {
				finalEntry := entries[len(entries)-1]
				startAfter = s.timePrefix(finalEntry.Timestamp)
			}

			// Sleep until the next poll, waking early on context cancellation.
			select {
			case <-ctx.Done():
				return
			case <-time.After(streamSleepDuration):
			}
		}
	}()

	return entryChan, nil
}
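
// A consumption sketch for Stream (loader construction elided): the caller
// ranges over the returned channel and cancels the context to stop the
// background poller, which closes the channel.
func exampleStreamConsume(loader Loader) {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	entryChan, err := loader.Stream(ctx, true /* tolerateErr */)
	if err != nil {
		log.Fatal(err)
	}
	for entry := range entryChan {
		fmt.Printf("spot feed entry at %s\n", entry.Timestamp)
	}
}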

// NewSpotFeedLoader queries the EC2 API for the spot data feed subscription associated with the
// given session and returns a Loader which queries the S3 API for the feed files it describes.
// NewSpotFeedLoader returns an error if the spot data feed subscription is missing.
func NewSpotFeedLoader(sess *session.Session, log *log.Logger, startTime, endTime *time.Time, version int64) (Loader, error) {
	ec2api := ec2.New(sess)
	resp, err := ec2api.DescribeSpotDatafeedSubscription(&ec2.DescribeSpotDatafeedSubscriptionInput{})
	if err != nil {
		return nil, errors.E("DescribeSpotDatafeedSubscription", err)
	}
	bucket := aws.StringValue(resp.SpotDatafeedSubscription.Bucket)
	rootURI := aws.StringValue(resp.SpotDatafeedSubscription.Prefix)
	accountID := aws.StringValue(resp.SpotDatafeedSubscription.OwnerId)
	return NewS3Loader(bucket, rootURI, s3.New(sess), log, accountID, startTime, endTime, version), nil
}
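
// A sketch of the subscription-discovery entry point above; the region is a
// hypothetical placeholder. The bucket, prefix and account id come from the
// account's existing spot data feed subscription rather than from the caller.
func exampleSpotFeedLoader() {
	sess := session.Must(session.NewSession(&aws.Config{Region: aws.String("us-west-2")}))
	loader, err := NewSpotFeedLoader(sess, log.New(os.Stderr, "spotfeed: ", log.LstdFlags), nil, nil, 0)
	if err != nil {
		log.Fatal(err) // e.g. no spot data feed subscription exists for this account
	}
	entries, err := loader.Fetch(context.Background(), true /* tolerateErr */)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("fetched %d entries\n", len(entries))
}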

// NewS3Loader returns a Loader which queries the S3 API for feed files. It supports the Fetch and Stream APIs.
func NewS3Loader(bucket, rootURI string, client s3iface.S3API, log *log.Logger, accountId string, startTime, endTime *time.Time, version int64) Loader {
	// Remove any trailing slash from bucket and trailing/leading slash from rootURI.
	bucket = strings.TrimSuffix(bucket, "/")
	rootURI = strings.TrimPrefix(rootURI, "/")
	rootURI = strings.TrimSuffix(rootURI, "/")

	return &s3Loader{
		filters: filters{
			AccountId: accountId,
			StartTime: startTime,
			EndTime:   endTime,
			Version:   version,
		},
		log:     log,
		client:  client,
		bucket:  bucket,
		rootURI: rootURI,
	}
}

type localLoader struct {
	Loader
	filters

	log      *log.Logger
	rootPath string
}

// Fetch queries the local filesystem for feed files at the given path which match the given filename filters.
// It then parses, filters again and returns the Entry objects.
func (l *localLoader) Fetch(ctx context.Context, tolerateErr bool) ([]*Entry, error) {
	// Iterate over files in the directory, filter and build a slice of feed files.
	spotFiles := make([]*localFile, 0)
	items, err := ioutil.ReadDir(l.rootPath)
	if err != nil {
		return nil, errors.E(err, fmt.Sprintf("failed to read local spot feed directory %s", l.rootPath))
	}
	for _, item := range items {
		// Skip subdirectories.
		if item.IsDir() {
			continue
		}

		p := path.Join(l.rootPath, item.Name())
		fileMeta, err := parseFeedFileName(item.Name())
		if err != nil {
			err = errors.E(err, fmt.Sprintf("failed to parse spot feed data file name %s", p))
			if tolerateErr {
				l.log.Printf("encountered error %s, tolerating and skipping file %s", err, p)
				continue
			}
			return nil, err
		}

		// Skip files that do not match the loader criteria. Truncate the startTime of the filter to ensure that
		// we do not skip files at hour HH:00 with a startTime of (e.g.) HH:30.
		if l.filterTruncatedStartTime(fileMeta) {
			l.log.Printf("%s does not pass fileMeta filter, skipping", p)
			continue
		}

		spotFiles = append(spotFiles, &localFile{
			fileMeta,
			p,
		})
	}

	// Concurrently iterate over spot data feed files and build a slice of entries.
	mu := &sync.Mutex{}
	spotDataEntries := make([]*Entry, 0)
	group, _ := errgroup.WithContext(ctx)
	for _, file := range spotFiles {
		file := file
		group.Go(func() error {
			entries, err := file.read()
			if err != nil {
				err = errors.E(err, fmt.Sprintf("failed to parse spot feed data file %s", file.path))
				if tolerateErr {
					l.log.Printf("encountered error %s, tolerating and skipping file %s", err, file.path)
					return nil
				}
				return err
			}
			mu.Lock()
			spotDataEntries = append(spotDataEntries, entries...)
			mu.Unlock()
			return nil
		})
	}
	if err := group.Wait(); err != nil {
		return nil, err
	}

	// Filter entries.
	filteredEntries := make([]*Entry, 0)
	for _, e := range spotDataEntries {
		if !l.filter(e) {
			filteredEntries = append(filteredEntries, e)
		}
	}

	return filteredEntries, nil
}

// NewLocalLoader returns a Loader which fetches feed files from a path on the local filesystem. It does not support
// the Stream API.
func NewLocalLoader(path string, log *log.Logger, accountId string, startTime, endTime *time.Time, version int64) Loader {
	return &localLoader{
		filters: filters{
			AccountId: accountId,
			StartTime: startTime,
			EndTime:   endTime,
			Version:   version,
		},
		log:      log,
		rootPath: path,
	}
}
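
// A hypothetical pairing of NewLocalLoader with Fetch against feed files
// already downloaded to local disk; the directory and account id below are
// placeholder assumptions.
func exampleLocalFetch() {
	loader := NewLocalLoader("/tmp/spotfeed",
		log.New(os.Stderr, "spotfeed: ", log.LstdFlags),
		"123456789012", nil, nil, 0)
	entries, err := loader.Fetch(context.Background(), false /* tolerateErr */)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("parsed %d entries from local feed files\n", len(entries))
}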