
     1  /*
     3  Copyright (c) 2024 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     6  */
     8  package status
    10  import (
    11  	"context"
    12  	"fmt"
    13  	"math"
    14  	"sync"
    15  	"time"
    17  	""
    18  	""
    19  )
    21  var (
    22  	_ async.Interceptor = (*TrackedAction)(nil)
    23  )
    25  // NewTrackedAction returns a new tracked action.
    26  func NewTrackedAction(serviceName string, opts ...TrackedActionOption) *TrackedAction {
    27  	ta := &TrackedAction{
    28  		ServiceName: serviceName,
    29  	}
    30  	_ = (&ta.TrackedActionConfig).Resolve(context.Background())
    31  	for _, opt := range opts {
    32  		opt(ta)
    33  	}
    34  	return ta
    35  }
// TrackedActionOption mutates a tracked action.
//
// NewTrackedAction invokes each option, in the order given, with the
// tracked action it is constructing.
type TrackedActionOption func(*TrackedAction)
    40  // OptTrackedActionConfig sets the tracked action config.
    41  func OptTrackedActionConfig(cfg TrackedActionConfig) TrackedActionOption {
    42  	return func(ta *TrackedAction) {
    43  		ta.TrackedActionConfig = cfg
    44  	}
    45  }
// TrackedAction is a wrapper for action that tracks a rolling
// window of history based on the configured expiration.
//
// Failed calls and successful calls are recorded in separate slices;
// records whose age reaches the configured expiration are dropped
// whenever the history is cleaned.
type TrackedAction struct {
	// TrackedActionConfig supplies the red/yellow thresholds and the
	// expiration used when evaluating and pruning the history.
	TrackedActionConfig
	// Mutex guards errors and requests; the exported methods acquire it.
	sync.Mutex

	// ServiceName is reported as the Name of the Info returned by GetStatus.
	ServiceName string

	// nowProvider optionally overrides the clock; when nil, now() falls
	// back to time.Now().UTC().
	nowProvider func() time.Time
	// errors holds one record per failed call still inside the window.
	errors []ErrorInfo
	// requests holds one record per successful call still inside the window.
	requests []RequestInfo
}
    60  // Intercept implements async.Interceptor.
    61  func (t *TrackedAction) Intercept(action Actioner) Actioner {
    62  	return ActionerFunc(func(ctx context.Context, args interface{}) (output interface{}, err error) {
    63  		defer func() {
    64  			if r := recover(); r != nil {
    65  				err = ex.Append(err, ex.New(r))
    66  			}
    67  			t.CleanOldRequests()
    68  			if err != nil {
    69  				t.AddErroredRequest(args)
    70  			} else {
    71  				t.AddSuccessfulRequest()
    72  			}
    73  		}()
    74  		output, err = action.Action(ctx, args)
    75  		return
    76  	})
    77  }
    79  // GetStatus gets the status for the tracker.
    80  //
    81  // It is safe to call concurrently from multiple goroutines.
    82  func (t *TrackedAction) GetStatus() (info Info) {
    83  	t.Lock()
    84  	defer t.Unlock()
    86  	t.cleanOldRequestsUnsafe()
    87  	info.Name = t.ServiceName
    88  	info.Status = t.getStatusSignalUnsafe()
    90  	errorBreakdown := make(map[string]int)
    91  	if info.Status == SignalYellow || info.Status == SignalRed {
    92  		for _, errorInfo := range t.errors {
    93  			errorBreakdown[t.formatArgs(errorInfo.Args)]++
    94  		}
    95  	}
    96  	info.Details = Details{
    97  		ErrorCount:     len(t.errors),
    98  		RequestCount:   len(t.requests),
    99  		ErrorBreakdown: errorBreakdown,
   100  	}
   101  	return
   102  }
   104  // GetStatusSignal returns the current status signal.
   105  //
   106  // It is safe to call concurrently from multiple goroutines.
   107  func (t *TrackedAction) GetStatusSignal() (status Signal) {
   108  	t.Lock()
   109  	status = t.getStatusSignalUnsafe()
   110  	t.Unlock()
   111  	return
   112  }
   114  // AddErroredRequest adds an errored request.
   115  //
   116  // It is safe to call concurrently from multiple goroutines.
   117  func (t *TrackedAction) AddErroredRequest(args interface{}) {
   118  	t.Lock()
   119  	defer t.Unlock()
   120  	t.errors = append(t.errors, ErrorInfo{
   121  		Args: args,
   122  		RequestInfo: RequestInfo{
   123  			RequestTime:,
   124  		},
   125  	})
   126  }
   128  // AddSuccessfulRequest adds a successful request.
   129  //
   130  // It is safe to call concurrently from multiple goroutines.
   131  func (t *TrackedAction) AddSuccessfulRequest() {
   132  	t.Lock()
   133  	defer t.Unlock()
   134  	t.requests = append(t.requests, RequestInfo{RequestTime:})
   135  }
   137  // CleanOldRequests is an action delegate that removes expired requests
   138  // from the tracker
   139  //
   140  // It is safe to call concurrently from multiple goroutines.
   141  func (t *TrackedAction) CleanOldRequests() {
   142  	t.Lock()
   143  	defer t.Unlock()
   144  	t.cleanOldRequestsUnsafe()
   145  }
   147  //
   148  // Private - Internal
   149  //
   151  func (t *TrackedAction) formatArgs(args interface{}) string {
   152  	switch typed := args.(type) {
   153  	case string:
   154  		return typed
   155  	case []byte:
   156  		return string(typed)
   157  	case []rune:
   158  		return string(typed)
   159  	case fmt.Stringer:
   160  		return typed.String()
   161  	default:
   162  		return "unknown"
   163  	}
   164  }
   166  // getStatusSignalUnsafe gets the specific signal (green, yellow, or red)
   167  // for the tracker.
   168  func (t *TrackedAction) getStatusSignalUnsafe() (status Signal) {
   169  	status = SignalGreen
   170  	requestCount := len(t.requests)
   171  	errorCount := float64(len(t.errors))
   172  	if errorCount >= t.redErrorCount(requestCount) {
   173  		status = SignalRed
   174  	} else if errorCount >= t.yellowErrorCount(requestCount) {
   175  		status = SignalYellow
   176  	}
   177  	return status
   178  }
   180  func (t *TrackedAction) cleanOldRequestsUnsafe() {
   181  	nowUTC :=
   182  	var filteredErrors []ErrorInfo
   183  	for _, errorInfo := range t.errors {
   184  		if nowUTC.Sub(errorInfo.RequestTime) < t.ExpirationOrDefault() {
   185  			filteredErrors = append(filteredErrors, errorInfo)
   186  		}
   187  	}
   189  	t.errors = filteredErrors
   190  	var filteredRequests []RequestInfo
   191  	for _, requestInfo := range t.requests {
   192  		if nowUTC.Sub(requestInfo.RequestTime) < t.ExpirationOrDefault() {
   193  			filteredRequests = append(filteredRequests, requestInfo)
   194  		}
   195  	}
   196  	t.requests = filteredRequests
   197  }
   199  // redErrorCount returns the expected threshold for what is
   200  // considered a "red" signal status based on either the baseline `RedRequestCount`
   201  // or the RedRequestPercentage applied to the current request count.
   202  //
   203  // It is meant to scale the threshold to the volume of the calls
   204  // to the tracked action.
   205  func (t *TrackedAction) redErrorCount(requestCount int) float64 {
   206  	return math.Max(
   207  		float64(t.RedRequestCount),
   208  		t.RedRequestPercentage*float64(requestCount),
   209  	)
   210  }
   212  // yellowErrorCount returns the expected threshold for what is
   213  // considered a "yellow" signal status based on either the baseline `YellowRequestCount`
   214  // or the YellowRequestPercentage applied to the current request count.
   215  //
   216  // It is meant to scale the threshold to the volume of the calls
   217  // to the tracked action
   218  func (t *TrackedAction) yellowErrorCount(requestCount int) float64 {
   219  	return math.Max(
   220  		float64(t.YellowRequestCount),
   221  		t.YellowRequestPercentage*float64(requestCount),
   222  	)
   223  }
   225  // now returns the current time.
   226  func (t *TrackedAction) now() time.Time {
   227  	if t.nowProvider != nil {
   228  		return t.nowProvider()
   229  	}
   230  	return time.Now().UTC()
   231  }