go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/logdog/server/collector/utils_test.go (about)

     1  // Copyright 2016 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package collector
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"crypto/sha256"
    21  	"encoding/hex"
    22  	"errors"
    23  	"fmt"
    24  	"sort"
    25  	"strings"
    26  	"sync"
    27  
    28  	"google.golang.org/protobuf/types/known/timestamppb"
    29  
    30  	"go.chromium.org/luci/common/clock"
    31  	"go.chromium.org/luci/logdog/api/logpb"
    32  	"go.chromium.org/luci/logdog/client/pubsubprotocol"
    33  	"go.chromium.org/luci/logdog/common/storage"
    34  	"go.chromium.org/luci/logdog/common/types"
    35  	cc "go.chromium.org/luci/logdog/server/collector/coordinator"
    36  )
    37  
    38  var testSecret = bytes.Repeat([]byte{0x55}, types.PrefixSecretLength)
    39  
    40  type streamKey struct {
    41  	project string
    42  	id      string
    43  }
    44  
    45  func mkStreamKey(project, id string) streamKey {
    46  	return streamKey{project, id}
    47  }
    48  
    49  // testCoordinator is an implementation of Coordinator that can be used for
    50  // testing.
    51  type testCoordinator struct {
    52  	sync.Mutex
    53  
    54  	// registerCallback, if not nil, is called when stream registration happens.
    55  	registerCallback func(cc.LogStreamState) error
    56  	// terminateCallback, if not nil, is called when stream termination happens.
    57  	terminateCallback func(cc.TerminateRequest) error
    58  
    59  	// state is the latest tracked stream state.
    60  	state map[streamKey]*cc.LogStreamState
    61  }
    62  
    63  var _ cc.Coordinator = (*testCoordinator)(nil)
    64  
    65  func (c *testCoordinator) register(s cc.LogStreamState) cc.LogStreamState {
    66  	c.Lock()
    67  	defer c.Unlock()
    68  
    69  	// Update our state.
    70  	if c.state == nil {
    71  		c.state = make(map[streamKey]*cc.LogStreamState)
    72  	}
    73  
    74  	id := idFromPath(string(s.Path))
    75  	key := mkStreamKey(string(s.Project), id)
    76  
    77  	if sp := c.state[key]; sp != nil {
    78  		return *sp
    79  	}
    80  
    81  	s.ID = id
    82  	c.state[key] = &s
    83  	return s
    84  }
    85  
    86  func (c *testCoordinator) RegisterStream(ctx context.Context, s *cc.LogStreamState, desc []byte) (*cc.LogStreamState, error) {
    87  	if cb := c.registerCallback; cb != nil {
    88  		if err := cb(*s); err != nil {
    89  			return nil, err
    90  		}
    91  	}
    92  
    93  	sp := c.register(*s)
    94  	return &sp, nil
    95  }
    96  
    97  func (c *testCoordinator) TerminateStream(ctx context.Context, tr *cc.TerminateRequest) error {
    98  	if cb := c.terminateCallback; cb != nil {
    99  		if err := cb(*tr); err != nil {
   100  			return err
   101  		}
   102  	}
   103  
   104  	if tr.TerminalIndex < 0 {
   105  		return errors.New("submitted stream is not terminal")
   106  	}
   107  
   108  	c.Lock()
   109  	defer c.Unlock()
   110  
   111  	// Update our state.
   112  	cachedState, ok := c.state[mkStreamKey(string(tr.Project), tr.ID)]
   113  	if !ok {
   114  		return fmt.Errorf("no such stream: %s", tr.ID)
   115  	}
   116  	if cachedState.TerminalIndex >= 0 && tr.TerminalIndex != cachedState.TerminalIndex {
   117  		return fmt.Errorf("incompatible terminal indexes: %d != %d", tr.TerminalIndex, cachedState.TerminalIndex)
   118  	}
   119  
   120  	cachedState.TerminalIndex = tr.TerminalIndex
   121  	return nil
   122  }
   123  
   124  func (c *testCoordinator) stream(project, id string) (int, bool) {
   125  	c.Lock()
   126  	defer c.Unlock()
   127  
   128  	sp, ok := c.state[mkStreamKey(project, id)]
   129  	if !ok {
   130  		return 0, false
   131  	}
   132  	return int(sp.TerminalIndex), true
   133  }
   134  
   135  func (c *testCoordinator) streamForPath(project, path string) (int, bool) {
   136  	return c.stream(project, idFromPath(path))
   137  }
   138  
   139  // testStorage is a testing storage instance that returns errors.
   140  type testStorage struct {
   141  	storage.Storage
   142  	err func() error
   143  }
   144  
   145  func (s *testStorage) Put(c context.Context, r storage.PutRequest) error {
   146  	if s.err != nil {
   147  		if err := s.err(); err != nil {
   148  			return err
   149  		}
   150  	}
   151  	return s.Storage.Put(c, r)
   152  }
   153  
   154  // bundleBuilder is a set of utility functions to help test cases construct
   155  // specific logpb.ButlerLogBundle layouts.
   156  type bundleBuilder struct {
   157  	context.Context
   158  
   159  	base *logpb.ButlerLogBundle
   160  }
   161  
   162  func (b *bundleBuilder) genBase() *logpb.ButlerLogBundle {
   163  	if b.base == nil {
   164  		b.base = &logpb.ButlerLogBundle{
   165  			Timestamp: timestamppb.New(clock.Now(b)),
   166  			Project:   "test-project",
   167  			Prefix:    "foo",
   168  			Secret:    testSecret,
   169  		}
   170  	}
   171  	return b.base
   172  }
   173  
   174  func (b *bundleBuilder) addBundleEntry(be *logpb.ButlerLogBundle_Entry) {
   175  	base := b.genBase()
   176  	base.Entries = append(base.Entries, be)
   177  }
   178  
   179  func (b *bundleBuilder) genBundleEntry(name string, tidx int, idxs ...int) *logpb.ButlerLogBundle_Entry {
   180  	p, n := types.StreamPath(name).Split()
   181  	be := logpb.ButlerLogBundle_Entry{
   182  		Desc: &logpb.LogStreamDescriptor{
   183  			Prefix:      string(p),
   184  			Name:        string(n),
   185  			ContentType: "application/test-message",
   186  			StreamType:  logpb.StreamType_TEXT,
   187  			Timestamp:   timestamppb.New(clock.Now(b)),
   188  		},
   189  	}
   190  
   191  	if len(idxs) > 0 {
   192  		be.Logs = make([]*logpb.LogEntry, len(idxs))
   193  		for i, idx := range idxs {
   194  			be.Logs[i] = b.logEntry(idx)
   195  		}
   196  		if tidx >= 0 {
   197  			be.Terminal = true
   198  			be.TerminalIndex = uint64(tidx)
   199  		}
   200  	}
   201  
   202  	return &be
   203  }
   204  
   205  func (b *bundleBuilder) addStreamEntries(name string, term int, idxs ...int) {
   206  	b.addBundleEntry(b.genBundleEntry(name, term, idxs...))
   207  }
   208  
   209  func (b *bundleBuilder) addFullStream(name string, count int) {
   210  	idxs := make([]int, count)
   211  	for i := range idxs {
   212  		idxs[i] = i
   213  	}
   214  	b.addStreamEntries(name, count-1, idxs...)
   215  }
   216  
   217  func (b *bundleBuilder) logEntry(idx int) *logpb.LogEntry {
   218  	return &logpb.LogEntry{
   219  		StreamIndex: uint64(idx),
   220  		Sequence:    uint64(idx),
   221  		Content: &logpb.LogEntry_Text{
   222  			Text: &logpb.Text{
   223  				Lines: []*logpb.Text_Line{
   224  					{
   225  						Value:     []byte(fmt.Sprintf("Line #%d", idx)),
   226  						Delimiter: "\n",
   227  					},
   228  				},
   229  			},
   230  		},
   231  	}
   232  }
   233  
   234  func (b *bundleBuilder) bundle() []byte {
   235  	buf := bytes.Buffer{}
   236  	w := pubsubprotocol.Writer{Compress: true}
   237  	if err := w.Write(&buf, b.genBase()); err != nil {
   238  		panic(err)
   239  	}
   240  
   241  	b.base = nil
   242  	return buf.Bytes()
   243  }
   244  
   245  type indexRange struct {
   246  	start int
   247  	end   int
   248  }
   249  
   250  func (r *indexRange) String() string { return fmt.Sprintf("[%d..%d]", r.start, r.end) }
   251  
   252  // shouldHaveRegisteredStream asserts that a testCoordinator has
   253  // registered a stream (string) and its terminal index (int).
   254  func shouldHaveRegisteredStream(actual any, expected ...any) string {
   255  	tcc := actual.(*testCoordinator)
   256  
   257  	if len(expected) != 3 {
   258  		return "invalid number of expected arguments (should be 3)."
   259  	}
   260  	project := expected[0].(string)
   261  	path := expected[1].(string)
   262  	tidx := expected[2].(int)
   263  
   264  	cur, ok := tcc.streamForPath(project, path)
   265  	if !ok {
   266  		return fmt.Sprintf("stream %q is not registered", path)
   267  	}
   268  	if tidx >= 0 && cur < 0 {
   269  		return fmt.Sprintf("stream %q is expected to be terminated, but isn't.", path)
   270  	}
   271  	if cur >= 0 && tidx < 0 {
   272  		return fmt.Sprintf("stream %q is NOT expected to be terminated, but it is.", path)
   273  	}
   274  	return ""
   275  }
   276  
   277  // shoudNotHaveRegisteredStream asserts that a testCoordinator has not
   278  // registered a stream (string).
   279  func shouldNotHaveRegisteredStream(actual any, expected ...any) string {
   280  	tcc := actual.(*testCoordinator)
   281  	if len(expected) != 2 {
   282  		return "invalid number of expected arguments (should be 2)."
   283  	}
   284  	project := expected[0].(string)
   285  	path := expected[1].(string)
   286  
   287  	if _, ok := tcc.streamForPath(project, path); ok {
   288  		return fmt.Sprintf("stream %q is registered, but it should NOT be.", path)
   289  	}
   290  	return ""
   291  }
   292  
   293  // shouldHaveStoredStream asserts that a storage.Storage instance has contiguous
   294  // stream records in it.
   295  //
   296  // actual is the storage.Storage instance. expected is a stream name (string)
   297  // followed by a a series of records to assert. This can either be a specific
   298  // integer index or an intexRange marking a closed range of indices.
   299  func shouldHaveStoredStream(actual any, expected ...any) string {
   300  	st := actual.(storage.Storage)
   301  	project := expected[0].(string)
   302  	name := expected[1].(string)
   303  	expected = expected[2:]
   304  
   305  	// Load all entries for this stream.
   306  	req := storage.GetRequest{
   307  		Project: project,
   308  		Path:    types.StreamPath(name),
   309  	}
   310  
   311  	entries := make(map[int]*logpb.LogEntry)
   312  	var ierr error
   313  	err := st.Get(context.Background(), req, func(e *storage.Entry) bool {
   314  		var le *logpb.LogEntry
   315  		if le, ierr = e.GetLogEntry(); ierr != nil {
   316  			return false
   317  		}
   318  		entries[int(le.StreamIndex)] = le
   319  		return true
   320  	})
   321  	if ierr != nil {
   322  		err = ierr
   323  	}
   324  	if err != nil && err != storage.ErrDoesNotExist {
   325  		return fmt.Sprintf("error: %v", err)
   326  	}
   327  
   328  	assertLogEntry := func(i int) string {
   329  		le := entries[i]
   330  		if le == nil {
   331  			return fmt.Sprintf("%d", i)
   332  		}
   333  		delete(entries, i)
   334  
   335  		if le.StreamIndex != uint64(i) {
   336  			return fmt.Sprintf("*%d", i)
   337  		}
   338  		return ""
   339  	}
   340  
   341  	var failed []string
   342  	for _, exp := range expected {
   343  		switch e := exp.(type) {
   344  		case int:
   345  			if err := assertLogEntry(e); err != "" {
   346  				failed = append(failed, fmt.Sprintf("missing{%s}", err))
   347  			}
   348  
   349  		case indexRange:
   350  			var errs []string
   351  			for i := e.start; i <= e.end; i++ {
   352  				if err := assertLogEntry(i); err != "" {
   353  					errs = append(errs, err)
   354  				}
   355  			}
   356  			if len(errs) > 0 {
   357  				failed = append(failed, fmt.Sprintf("%s{%s}", e.String(), strings.Join(errs, ",")))
   358  			}
   359  
   360  		default:
   361  			panic(fmt.Errorf("unknown expected type %T", e))
   362  		}
   363  	}
   364  
   365  	// Extras?
   366  	if len(entries) > 0 {
   367  		idxs := make([]int, 0, len(entries))
   368  		for i := range entries {
   369  			idxs = append(idxs, i)
   370  		}
   371  		sort.Ints(idxs)
   372  
   373  		extra := make([]string, len(idxs))
   374  		for i, idx := range idxs {
   375  			extra[i] = fmt.Sprintf("%d", idx)
   376  		}
   377  		failed = append(failed, fmt.Sprintf("extra{%s}", strings.Join(extra, ",")))
   378  	}
   379  
   380  	if len(failed) > 0 {
   381  		return strings.Join(failed, ", ")
   382  	}
   383  	return ""
   384  }
   385  
   386  func idFromPath(path string) string {
   387  	hash := sha256.Sum256([]byte(path))
   388  	return hex.EncodeToString(hash[:])
   389  }