github.com/grailbio/base@v0.0.11/eventlog/cloudwatch/cloudwatch_test.go (about)

     1  // Copyright 2020 GRAIL, Inc. All rights reserved.
     2  // Use of this source code is governed by the Apache 2.0
     3  // license that can be found in the LICENSE file.
     4  
     5  package cloudwatch
     6  
     7  import (
     8  	"context"
     9  	"encoding/json"
    10  	"fmt"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/aws/aws-sdk-go/aws/request"
    16  	"github.com/aws/aws-sdk-go/service/cloudwatchlogs"
    17  	"github.com/aws/aws-sdk-go/service/cloudwatchlogs/cloudwatchlogsiface"
    18  	"github.com/grailbio/base/eventlog/internal/marshal"
    19  )
    20  
    21  const testGroup = "testGroup"
    22  const testStream = "testStream"
    23  const typ = "testEventType"
    24  const k = "testFieldKey"
    25  
    26  type logsAPIFake struct {
    27  	cloudwatchlogsiface.CloudWatchLogsAPI
    28  
    29  	groupInput   *cloudwatchlogs.CreateLogGroupInput
    30  	streamInput  *cloudwatchlogs.CreateLogStreamInput
    31  	eventsInputs []*cloudwatchlogs.PutLogEventsInput
    32  
    33  	sequenceMu sync.Mutex
    34  	sequence   int
    35  }
    36  
    37  func (f *logsAPIFake) CreateLogGroupWithContext(ctx context.Context,
    38  	input *cloudwatchlogs.CreateLogGroupInput,
    39  	opts ...request.Option) (*cloudwatchlogs.CreateLogGroupOutput, error) {
    40  
    41  	f.groupInput = input
    42  	return nil, nil
    43  }
    44  
    45  func (f *logsAPIFake) CreateLogStreamWithContext(ctx context.Context,
    46  	input *cloudwatchlogs.CreateLogStreamInput,
    47  	opts ...request.Option) (*cloudwatchlogs.CreateLogStreamOutput, error) {
    48  
    49  	f.streamInput = input
    50  	return nil, nil
    51  }
    52  
    53  func (f *logsAPIFake) PutLogEventsWithContext(ctx context.Context,
    54  	input *cloudwatchlogs.PutLogEventsInput,
    55  	opts ...request.Option) (*cloudwatchlogs.PutLogEventsOutput, error) {
    56  
    57  	var ts *int64
    58  	for _, event := range input.LogEvents {
    59  		if ts != nil && *event.Timestamp < *ts {
    60  			return nil, &cloudwatchlogs.InvalidParameterException{}
    61  		}
    62  		ts = event.Timestamp
    63  	}
    64  
    65  	nextSequenceToken, err := func() (*string, error) {
    66  		f.sequenceMu.Lock()
    67  		defer f.sequenceMu.Unlock()
    68  		if f.sequence != 0 {
    69  			sequenceToken := fmt.Sprintf("%d", f.sequence)
    70  			if input.SequenceToken == nil || sequenceToken != *input.SequenceToken {
    71  				return nil, &cloudwatchlogs.InvalidSequenceTokenException{
    72  					ExpectedSequenceToken: &sequenceToken,
    73  				}
    74  			}
    75  		}
    76  		f.sequence++
    77  		nextSequenceToken := fmt.Sprintf("%d", f.sequence)
    78  		return &nextSequenceToken, nil
    79  	}()
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	f.eventsInputs = append(f.eventsInputs, input)
    85  	return &cloudwatchlogs.PutLogEventsOutput{
    86  		NextSequenceToken: nextSequenceToken,
    87  	}, nil
    88  }
    89  
    90  func (f *logsAPIFake) logEvents() []*cloudwatchlogs.InputLogEvent {
    91  	var events []*cloudwatchlogs.InputLogEvent
    92  	for _, input := range f.eventsInputs {
    93  		events = append(events, input.LogEvents...)
    94  	}
    95  	return events
    96  }
    97  
    98  func (f *logsAPIFake) incrNextSequence() {
    99  	f.sequenceMu.Lock()
   100  	defer f.sequenceMu.Unlock()
   101  	f.sequence++
   102  }
   103  
   104  // TestEvent verifies that logged events are sent to CloudWatch correctly.
   105  func TestEvent(t *testing.T) {
   106  	const N = 1000
   107  
   108  	if eventBufferSize < N {
   109  		panic("keep N <= eventBufferSize to make sure no events are dropped")
   110  	}
   111  
   112  	// Note: Access to nowUnixMillis is unsynchronized because now() is only called in Event(),
   113  	// not in any background or asynchronous goroutine.
   114  	var nowUnixMillis int64 = 1600000000000 // Arbitrary time in 2020.
   115  	now := func() time.Time {
   116  		return time.UnixMilli(nowUnixMillis)
   117  	}
   118  
   119  	// Log events.
   120  	cw := &logsAPIFake{}
   121  	e := NewEventer(cw, testGroup, testStream, OptNow(now))
   122  	wantTimestamps := make([]time.Time, N)
   123  	for i := 0; i < N; i++ {
   124  		k := fmt.Sprintf("k%d", i)
   125  		e.Event(typ, k, i)
   126  		wantTimestamps[i] = now()
   127  		nowUnixMillis += time.Hour.Milliseconds()
   128  	}
   129  	e.Close()
   130  
   131  	// Make sure events get to CloudWatch with the right contents and in order.
   132  	events := cw.logEvents()
   133  	if got, want := len(events), N; got != want {
   134  		t.Errorf("got %v, want %v", got, want)
   135  	}
   136  	for i, event := range events {
   137  		k := fmt.Sprintf("k%d", i)
   138  		m, err := marshal.Marshal(typ, []interface{}{k, i})
   139  		if err != nil {
   140  			t.Fatalf("error marshaling event: %v", err)
   141  		}
   142  		if got, want := *event.Message, m; got != want {
   143  			t.Errorf("got %v, want %v", got, want)
   144  			continue
   145  		}
   146  		if got, want := time.UnixMilli(*event.Timestamp), wantTimestamps[i]; !want.Equal(got) {
   147  			t.Errorf("got %v, want %v", got, want)
   148  			continue
   149  		}
   150  	}
   151  }
   152  
   153  // TestBufferFull verifies that exceeding the event buffer leads to, at worst,
   154  // dropped events. Events that are not dropped should still be logged in order.
   155  func TestBufferFull(t *testing.T) {
   156  	const N = 100 * 1000
   157  
   158  	// Log many events, overwhelming buffer.
   159  	cw := &logsAPIFake{}
   160  	e := NewEventer(cw, testGroup, testStream)
   161  	for i := 0; i < N; i++ {
   162  		e.Event(typ, k, i)
   163  	}
   164  	e.Close()
   165  
   166  	events := cw.logEvents()
   167  	if N < len(events) {
   168  		t.Fatalf("more events sent to CloudWatch than were logged: %d < %d", N, len(events))
   169  	}
   170  	assertOrdered(t, events)
   171  }
   172  
   173  // TestInvalidSequenceToken verifies that we recover if our sequence token gets
   174  // out of sync. This should not happen, as we should be the only thing writing
   175  // to a given log stream, but we try to recover anyway.
   176  func TestInvalidSequenceToken(t *testing.T) {
   177  	cw := &logsAPIFake{}
   178  	e := NewEventer(cw, testGroup, testStream)
   179  
   180  	e.Event(typ, k, 0)
   181  	e.sync()
   182  	cw.incrNextSequence()
   183  	e.Event(typ, k, 1)
   184  	e.sync()
   185  	e.Event(typ, k, 2)
   186  	e.sync()
   187  	e.Close()
   188  
   189  	events := cw.logEvents()
   190  	if 3 < len(events) {
   191  		t.Fatalf("more events sent to CloudWatch than were logged: 3 < %d", len(events))
   192  	}
   193  	if len(events) < 2 {
   194  		t.Errorf("did not successfully re-sync sequence token")
   195  	}
   196  	assertOrdered(t, events)
   197  }
   198  
   199  // assertOrdered asserts that the values of field k are increasing for events.
   200  // This is how we construct events sent to the Eventer, so we use this
   201  // verify that the events sent to the CloudWatch Logs API are ordered correctly.
   202  func assertOrdered(t *testing.T, events []*cloudwatchlogs.InputLogEvent) {
   203  	t.Helper()
   204  	last := -1
   205  	for _, event := range events {
   206  		var m map[string]interface{}
   207  		if err := json.Unmarshal([]byte(*event.Message), &m); err != nil {
   208  			t.Fatalf("could not unmarshal event message: %v", err)
   209  		}
   210  		v, ok := m[k]
   211  		if !ok {
   212  			t.Errorf("event message does not contain test key %q: %s", k, *event.Message)
   213  			continue
   214  		}
   215  		// All numeric values are unmarshaled as float64, so we need to convert
   216  		// back to int.
   217  		vi := int(v.(float64))
   218  		if vi <= last {
   219  			t.Errorf("event out of order; expected %d < %d", last, vi)
   220  			continue
   221  		}
   222  	}
   223  }