github.com/grailbio/base@v0.0.11/cloud/spotadvisor/spotadvisor.go (about)

     1  // Copyright 2021 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package spotadvisor provides an interface for utilizing spot instance
     6  // interrupt rate data and savings data from AWS.
     7  package spotadvisor
     8  
     9  import (
    10  	"encoding/json"
    11  	"fmt"
    12  	"net/http"
    13  	"sync"
    14  	"time"
    15  )
    16  
// spotAdvisorDataUrl is the location of the spot advisor JSON document that
// refresh downloads and decodes into advisorData.
var spotAdvisorDataUrl = "https://spot-bid-advisor.s3.amazonaws.com/spot-advisor-data.json"
    18  
const (
	// Spot Advisor data is only updated a few times a day, so we just refresh once an hour.
	// Might need to revisit this value if data is updated more frequently in the future.
	// defaultRequestTimeout is the timeout of the HTTP client used for each fetch.
	defaultRefreshInterval = 1 * time.Hour
	defaultRequestTimeout  = 10 * time.Second

	// Linux and Windows are the pre-defined OsType values. Their spelling
	// matches the per-OS keys ("Linux", "Windows") of the spot advisor JSON.
	Linux   = OsType("Linux")
	Windows = OsType("Windows")
)
    28  
// Pre-defined InterruptRange values. GetInterruptRange converts the range
// index ("r") of the spot advisor data directly into one of these, so their
// numeric values must line up with that index.
// These need to be in their own const block to ensure iota starts at 0.
const (
	// ZeroToFivePct is the 0-5% interrupt range.
	ZeroToFivePct InterruptRange = iota
	// FiveToTenPct is the 5-10% interrupt range.
	FiveToTenPct
	// TenToFifteenPct is the 10-15% interrupt range.
	TenToFifteenPct
	// FifteenToTwentyPct is the 15-20% interrupt range.
	FifteenToTwentyPct
	// GreaterThanTwentyPct is the > 20% interrupt range.
	GreaterThanTwentyPct
)
    37  
// Pre-defined InterruptProbability values. Each one has the same numeric value
// as the InterruptRange that forms its upper bound (for example LessThanTenPct
// and FiveToTenPct are both 1); refresh and GetMaxInterruptProbability rely on
// that correspondence when converting between the two types.
// These need to be in their own const block to ensure iota starts at 0.
const (
	// LessThanFivePct covers only the ZeroToFivePct range.
	LessThanFivePct InterruptProbability = iota
	// LessThanTenPct covers the ranges up to and including FiveToTenPct.
	LessThanTenPct
	// LessThanFifteenPct covers the ranges up to and including TenToFifteenPct.
	LessThanFifteenPct
	// LessThanTwentyPct covers the ranges up to and including FifteenToTwentyPct.
	LessThanTwentyPct
	// Any covers every interrupt range, i.e. it imposes no upper bound.
	Any
)
    46  
// interruptRange is one element of the "ranges" array of the spot advisor
// JSON. Its fields mirror the JSON keys "label", "index", "dots" and "max";
// nothing in this file reads them after decoding.
// NOTE(review): presumably Index is the range index referenced by
// instanceData.RangeIdx and Max is the upper bound of the range in percent;
// confirm against the published data.
type interruptRange struct {
	Label string `json:"label"`
	Index int    `json:"index"`
	Dots  int    `json:"dots"`
	Max   int    `json:"max"`
}
    53  
// instanceType is a value of the "instance_types" map of the spot advisor
// JSON, which is keyed by EC2 instance type name. Its fields mirror the JSON
// keys "cores", "emr" and "ram_gb"; nothing in this file reads them after
// decoding.
// NOTE(review): presumably Cores is the vCPU count, Emr whether the type is
// available on EMR, and RamGb the memory size; confirm against the data.
type instanceType struct {
	Cores int     `json:"cores"`
	Emr   bool    `json:"emr"`
	RamGb float32 `json:"ram_gb"`
}
    59  
// instanceData is the per-instance-type entry of the spot advisor JSON for one
// region and OS. RangeIdx (JSON key "r") is the interrupt range index, which
// this package converts directly to an InterruptRange. Savings (JSON key "s")
// is decoded but not otherwise used in this file.
// NOTE(review): Savings is presumably the percent saved relative to on-demand
// pricing; confirm against the data.
type instanceData struct {
	RangeIdx int `json:"r"`
	Savings  int `json:"s"`
}
    64  
// osGroups holds the spot advisor entries of a single region, split by
// operating system. Each map is keyed by instance type name (it is indexed with
// an InstanceType converted to string in GetInterruptRange).
type osGroups struct {
	Windows map[string]instanceData `json:"Windows"`
	Linux   map[string]instanceData `json:"Linux"`
}
    69  
// advisorData is the top-level structure that the spot advisor JSON document
// is decoded into by refresh.
type advisorData struct {
	// Ranges is the decoded "ranges" array; it is not read elsewhere in this file.
	Ranges []interruptRange `json:"ranges"`
	// key is an EC2 instance type name like "r5a.large"
	InstanceTypes map[string]instanceType `json:"instance_types"`
	// key is an AWS region name like "us-west-2"
	SpotAdvisor map[string]osGroups `json:"spot_advisor"`
}
    77  
// aggKey is the key of SpotAdvisor.aggData: an OS (ot), a region (ar) and an
// interrupt probability (ip). The field order matters because refresh builds
// keys with positional composite literals.
type aggKey struct {
	ot OsType
	ar AwsRegion
	ip InterruptProbability
}
    83  
    84  func (k aggKey) String() string {
    85  	return fmt.Sprintf("{%s, %s, %s}", k.ot, k.ar, k.ip)
    86  }
    87  
// OsType names the operating system of a spot instance. It should only be used
// via the pre-defined constants in this package (Linux and Windows).
type OsType string
    90  
// AwsRegion is an AWS region name like "us-west-2". It is matched against the
// region keys of the spot advisor data.
type AwsRegion string
    93  
// InstanceType is an EC2 instance type name like "r5a.large". It is matched
// against the instance type keys of the spot advisor data.
type InstanceType string
    96  
// InterruptRange is the AWS defined interrupt range for an instance type; it
// should only be used via the pre-defined constants in this package. Its
// numeric value is the range index found in the spot advisor data.
type InterruptRange int
   100  
   101  func (ir InterruptRange) String() string {
   102  	switch ir {
   103  	case ZeroToFivePct:
   104  		return "O-5%"
   105  	case FiveToTenPct:
   106  		return "5-10%"
   107  	case TenToFifteenPct:
   108  		return "10-15%"
   109  	case FifteenToTwentyPct:
   110  		return "15-20%"
   111  	case GreaterThanTwentyPct:
   112  		return "> 20%"
   113  	default:
   114  		return "invalid interrupt range"
   115  	}
   116  }
   117  
// InterruptProbability is an upper bound used to indicate multiple interrupt
// ranges; it should only be used via the pre-defined constants in this package.
// For example, LessThanTenPct stands for the ZeroToFivePct and FiveToTenPct
// interrupt ranges together.
type InterruptProbability int
   121  
   122  func (ir InterruptProbability) String() string {
   123  	switch ir {
   124  	case LessThanFivePct:
   125  		return "< 5%"
   126  	case LessThanTenPct:
   127  		return "< 10%"
   128  	case LessThanFifteenPct:
   129  		return "< 15%"
   130  	case LessThanTwentyPct:
   131  		return "< 20%"
   132  	case Any:
   133  		return "Any"
   134  	default:
   135  		return "invalid interrupt probability"
   136  	}
   137  }
   138  
// SpotAdvisor provides an interface for utilizing spot instance interrupt rate
// data and savings data from AWS. Use NewSpotAdvisor to create one.
type SpotAdvisor struct {
	// mu guards rawData and aggData: refresh replaces both while holding the
	// write lock, and the exported query methods read them under the read lock.
	mu sync.RWMutex
	// rawData is the decoded spot advisor json response
	rawData advisorData
	// aggData maps each aggKey to a slice of instance types aggregated by interrupt
	// probability. For example, if aggKey.ip=LessThanTenPct, then the mapped value
	// would contain all instance types which have an interrupt range of
	// ZeroToFivePct or FiveToTenPct.
	aggData map[aggKey][]string

	// TODO: incorporate spot advisor savings data
}
   153  
// SimpleLogger is a bare-bones logger interface which allows many logger
// implementations to be used with SpotAdvisor. The default Go log.Logger and
// grailbio/base/log.Logger implement this interface.
type SimpleLogger interface {
	// Printf is called with a format string and its arguments. In this file
	// it is only used to report a failed background refresh.
	Printf(string, ...interface{})
}
   160  
   161  // NewSpotAdvisor initializes and returns a SpotAdvisor instance. If
   162  // initialization fails, a nil SpotAdvisor is returned with an error. The
   163  // underlying data is asynchronously updated, until the done channel is closed.
   164  // Errors during updates are non-fatal and will not prevent future updates.
   165  func NewSpotAdvisor(log SimpleLogger, done <-chan struct{}) (*SpotAdvisor, error) {
   166  	sa := SpotAdvisor{}
   167  	// initial load
   168  	if err := sa.refresh(); err != nil {
   169  		return nil, fmt.Errorf("error fetching spot advisor data: %s", err)
   170  	}
   171  
   172  	go func() {
   173  		ticker := time.NewTicker(defaultRefreshInterval)
   174  		defer ticker.Stop()
   175  		for {
   176  			select {
   177  			case <-done:
   178  				return
   179  			case <-ticker.C:
   180  				if err := sa.refresh(); err != nil {
   181  					log.Printf("error refreshing spot advisor data (will try again later): %s", err)
   182  				}
   183  			}
   184  		}
   185  	}()
   186  	return &sa, nil
   187  }
   188  
   189  func (sa *SpotAdvisor) refresh() (err error) {
   190  	// fetch
   191  	client := &http.Client{Timeout: defaultRequestTimeout}
   192  	resp, err := client.Get(spotAdvisorDataUrl)
   193  	if err != nil {
   194  		return err
   195  	}
   196  	if resp.StatusCode != http.StatusOK {
   197  		return fmt.Errorf("GET %s response StatusCode: %s", spotAdvisorDataUrl, http.StatusText(resp.StatusCode))
   198  	}
   199  	var rawData advisorData
   200  	err = json.NewDecoder(resp.Body).Decode(&rawData)
   201  	if err != nil {
   202  		return err
   203  	}
   204  	err = resp.Body.Close()
   205  	if err != nil {
   206  		return err
   207  	}
   208  
   209  	// update internal data structures
   210  	aggData := make(map[aggKey][]string)
   211  	for r, o := range rawData.SpotAdvisor {
   212  		region := AwsRegion(r)
   213  		// transform the raw data so that the values of aggData will contain just the instances in a given range
   214  		for instance, data := range o.Linux {
   215  			k := aggKey{Linux, region, InterruptProbability(data.RangeIdx)}
   216  			aggData[k] = append(aggData[k], instance)
   217  		}
   218  		for instance, data := range o.Windows {
   219  			k := aggKey{Windows, region, InterruptProbability(data.RangeIdx)}
   220  			aggData[k] = append(aggData[k], instance)
   221  		}
   222  
   223  		// aggregate instances by the upper bound interrupt probability of each key
   224  		for i := 1; i <= int(Any); i++ {
   225  			{
   226  				lk := aggKey{Linux, region, InterruptProbability(i)}
   227  				lprevk := aggKey{Linux, region, InterruptProbability(i - 1)}
   228  				aggData[lk] = append(aggData[lk], aggData[lprevk]...)
   229  			}
   230  			{
   231  				wk := aggKey{Windows, region, InterruptProbability(i)}
   232  				wprevk := aggKey{Windows, region, InterruptProbability(i - 1)}
   233  				aggData[wk] = append(aggData[wk], aggData[wprevk]...)
   234  			}
   235  		}
   236  	}
   237  	sa.mu.Lock()
   238  	sa.rawData = rawData
   239  	sa.aggData = aggData
   240  	sa.mu.Unlock()
   241  	return nil
   242  }
   243  
   244  // FilterByMaxInterruptProbability returns a subset of the input candidates by
   245  // removing instance types which have a probability of interruption greater than ip.
   246  func (sa *SpotAdvisor) FilterByMaxInterruptProbability(ot OsType, ar AwsRegion, candidates []string, ip InterruptProbability) (filtered []string, err error) {
   247  	if ip == Any {
   248  		// There's a chance we may not have spot advisor data for some instances in
   249  		// the candidates, so just return as is without doing a set difference.
   250  		return candidates, nil
   251  	}
   252  	allowed, err := sa.GetInstancesWithMaxInterruptProbability(ot, ar, ip)
   253  	if err != nil {
   254  		return nil, err
   255  	}
   256  	for _, c := range candidates {
   257  		if allowed[c] {
   258  			filtered = append(filtered, c)
   259  		}
   260  	}
   261  	return filtered, nil
   262  }
   263  
   264  // GetInstancesWithMaxInterruptProbability returns the set of spot instance types
   265  // with an interrupt probability less than or equal to ip, with the given OS and region.
   266  func (sa *SpotAdvisor) GetInstancesWithMaxInterruptProbability(ot OsType, region AwsRegion, ip InterruptProbability) (map[string]bool, error) {
   267  	if ip < LessThanFivePct || ip > Any {
   268  		return nil, fmt.Errorf("invalid InterruptProbability: %d", ip)
   269  	}
   270  	k := aggKey{ot, region, ip}
   271  	sa.mu.RLock()
   272  	defer sa.mu.RUnlock()
   273  	ts, ok := sa.aggData[k]
   274  	if !ok {
   275  		return nil, fmt.Errorf("no spot advisor data for: %s", k)
   276  	}
   277  	tsMap := make(map[string]bool, len(ts))
   278  	for _, t := range ts {
   279  		tsMap[t] = true
   280  	}
   281  	return tsMap, nil
   282  }
   283  
   284  // GetInterruptRange returns the interrupt range for the instance type with the
   285  // given OS and region.
   286  func (sa *SpotAdvisor) GetInterruptRange(ot OsType, ar AwsRegion, it InstanceType) (InterruptRange, error) {
   287  	sa.mu.RLock()
   288  	defer sa.mu.RUnlock()
   289  	osg, ok := sa.rawData.SpotAdvisor[string(ar)]
   290  	if !ok {
   291  		return -1, fmt.Errorf("no spot advisor data for: %s", ar)
   292  	}
   293  	var m map[string]instanceData
   294  	switch ot {
   295  	case Linux:
   296  		m = osg.Linux
   297  	case Windows:
   298  		m = osg.Windows
   299  	default:
   300  		return -1, fmt.Errorf("invalid OS: %s", ot)
   301  	}
   302  
   303  	d, ok := m[string(it)]
   304  	if !ok {
   305  		return -1, fmt.Errorf("no spot advisor data for %s instance type '%s' in %s", ot, it, ar)
   306  	}
   307  	return InterruptRange(d.RangeIdx), nil
   308  }
   309  
   310  // GetMaxInterruptProbability is a helper method to easily get the max interrupt
   311  // probability of an instance type (i.e. the upper bound of the interrupt range
   312  // for that instance type).
   313  func (sa *SpotAdvisor) GetMaxInterruptProbability(ot OsType, ar AwsRegion, it InstanceType) (InterruptProbability, error) {
   314  	ir, err := sa.GetInterruptRange(ot, ar, it)
   315  	return InterruptProbability(ir), err
   316  }