go.chromium.org/luci@v0.0.0-20240309015107-7cdc2e660f33/common/exec/execmock/context.go (about)

     1  // Copyright 2023 The LUCI Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package execmock
    16  
    17  import (
    18  	"context"
    19  	"os"
    20  	"sort"
    21  	"sync"
    22  	"sync/atomic"
    23  	"testing"
    24  
    25  	"go.chromium.org/luci/common/errors"
    26  	"go.chromium.org/luci/common/exec/internal/execmockctx"
    27  	"go.chromium.org/luci/common/exec/internal/execmockserver"
    28  	"go.chromium.org/luci/common/system/environ"
    29  )
    30  
// mockEntry is a single registered mock: a filter that decides which
// invocations it applies to, the input data handed to a matching invocation,
// and a record of the invocations it has matched so far.
type mockEntry struct {
	// uses records every invocation this entry has matched (see matches).
	uses uses

	// inputData is what gets registered with the mock server for a matching
	// invocation. createMockInvocation returns (nil, nil) when this is nil.
	inputData *execmockserver.InvocationInput

	// f decides whether this entry applies to a given MockCriteria and
	// provides the primary ordering between entries (see less).
	f filter

	// mockEntryID is taken from the package-level mockEntryID counter in
	// addMockEntry and breaks ties in less.
	mockEntryID uint64
}
    40  
// mockEntryID is a package-level counter whose values are stored in
// mockEntry.mockEntryID and used as the tie-breaker in mockEntry.less.
// addMockEntry advances it with mockEntryID.Add(1).
var mockEntryID atomic.Uint64
    44  
    45  func (me *mockEntry) matches(mc *execmockctx.MockCriteria, proc **os.Process) (match bool, lastUse bool, usage usage) {
    46  	match = me.f.matches(mc)
    47  	if match {
    48  		usage = me.uses.addUsage(mc, proc)
    49  		if me.f.limit > 0 && uint64(me.uses.len()) >= uint64(me.f.limit) {
    50  			lastUse = true
    51  		}
    52  	}
    53  	return
    54  }
    55  
    56  func (me *mockEntry) less(other *mockEntry) bool {
    57  	if less, ok := me.f.less(other.f); ok {
    58  		return less
    59  	}
    60  	return me.mockEntryID < other.mockEntryID
    61  }
    62  
// mockState is stored in a context; it holds all of the mock entries that the
// user's test set in the context, and it can be interrogated by the user test
// to list all the missed MockCriteria at the end of the test.
//
// Tests will populate the mockState with all of the mock entries.
type mockState struct {
	// mu guards mocks, misses and sealed.
	//
	// mocks holds the registered entries; it is sorted on the first mock
	// invocation, and entries that reach their use limit are removed.
	// misses accumulates the criteria of invocations that matched no entry
	// and is handed back by ResetState.
	mu     sync.Mutex
	mocks  []*mockEntry
	misses []*MockCriteria

	// chatty is set from testing.Verbose() by Init and is returned to
	// callers of getMocker.
	chatty bool

	// flipped to `true` once the first mock invocation has happened.
	// addMockEntry panics while this is set; ResetState clears it.
	sealed bool
}
    78  
    79  func addMockEntry[Out any](ctx context.Context, f filter, i *execmockserver.InvocationInput) *Uses[Out] {
    80  	u := &Uses[Out]{}
    81  
    82  	state := mustGetState(ctx)
    83  
    84  	me := &mockEntry{uses: u, inputData: i, f: f}
    85  	me.mockEntryID = mockEntryID.Add(1)
    86  
    87  	state.mu.Lock()
    88  	defer state.mu.Unlock()
    89  
    90  	if state.sealed {
    91  		panic(errors.New("Cannot add Mock to sealed mocking context. Call ResetState to reset the mocking state on this context."))
    92  	}
    93  
    94  	state.mocks = append(state.mocks, me)
    95  	return u
    96  }
    97  
    98  func getMocker(ctx context.Context) (mocker execmockctx.CreateMockInvocation, chatty bool) {
    99  	state := getState(ctx)
   100  	if state == nil {
   101  		return func(mc *execmockctx.MockCriteria, proc **os.Process) (*execmockctx.MockInvocation, error) {
   102  			return nil, errors.Annotate(execmockctx.ErrNoMatchingMock, "execmock.Init not called on context").Err()
   103  		}, false
   104  	}
   105  	return state.createMockInvocation, state.chatty
   106  }
   107  
   108  // createMockInvocation is the main way that execmock interacts with both the
   109  // global state as well as the state in the context.
   110  func (state *mockState) createMockInvocation(mc *execmockctx.MockCriteria, proc **os.Process) (*execmockctx.MockInvocation, error) {
   111  	state.mu.Lock()
   112  	defer state.mu.Unlock()
   113  
   114  	if !state.sealed {
   115  		sort.SliceStable(state.mocks, func(i, j int) bool {
   116  			return state.mocks[i].less(state.mocks[j])
   117  		})
   118  		state.sealed = true
   119  	}
   120  
   121  	for i, ent := range state.mocks {
   122  		matches, lastUse, usage := ent.matches(mc, proc)
   123  		if matches {
   124  			if lastUse {
   125  				state.mocks = append(state.mocks[:i], state.mocks[i+1:]...)
   126  			}
   127  			if ent.inputData == nil {
   128  				return nil, nil
   129  			}
   130  			if err := ent.inputData.StartError; err != nil {
   131  				return nil, err
   132  			}
   133  			envvar, id := execmockserver.RegisterInvocation(server, ent.inputData, usage.setOutput)
   134  			return &execmockctx.MockInvocation{
   135  				ID:             id,
   136  				EnvVar:         envvar,
   137  				GetErrorOutput: usage.getErrorOutput,
   138  			}, nil
   139  		}
   140  	}
   141  	state.misses = append(state.misses, mcFromInternal(mc))
   142  
   143  	return nil, errors.Annotate(execmockctx.ErrNoMatchingMock, "%s", mc).Err()
   144  }
   145  
// MockCriteria are the parameters from a Command used to match an ExecMocker.
//
// The values handed back by ResetState are built by mcFromInternal, which
// copies Args into a new slice and clones Env.
type MockCriteria struct {
	// Args and Env mirror the fields of the same name on the internal
	// execmockctx.MockCriteria.
	Args []string
	Env  environ.Env
}
   151  
   152  func mcFromInternal(imc *execmockctx.MockCriteria) *MockCriteria {
   153  	return &MockCriteria{
   154  		Args: append([]string(nil), imc.Args...),
   155  		Env:  imc.Env.Clone(),
   156  	}
   157  }
   158  
   159  // ResetState returns the list of missed MockCriteria (i.e. commands which
   160  // didn't match any Mocks) and resets the state in `ctx` (wipes all existing
   161  // Mock entries, misses, etc.)
   162  //
   163  // This also `unseals` the state, allowing Mocker.Mock to be called on this
   164  // context again.
   165  func ResetState(ctx context.Context) []*MockCriteria {
   166  	state := mustGetState(ctx)
   167  	state.mu.Lock()
   168  	defer state.mu.Unlock()
   169  
   170  	ret := state.misses
   171  
   172  	state.mocks = nil
   173  	state.misses = nil
   174  	state.sealed = false
   175  
   176  	return ret
   177  }
   178  
// stateCtxKey identifies the *mockState stored in a context. The key passed to
// context.WithValue and ctx.Value is the address of this variable
// (&stateCtxKey), not its string contents.
var stateCtxKey = "context key for holding a *mockState"
   180  
   181  func getState(ctx context.Context) *mockState {
   182  	stateI := ctx.Value(&stateCtxKey)
   183  	if stateI == nil {
   184  		return nil
   185  	}
   186  
   187  	state, ok := stateI.(*mockState)
   188  	if !ok {
   189  		panic("impossible: execmock context key holds wrong type")
   190  	}
   191  	return state
   192  }
   193  
   194  func mustGetState(ctx context.Context) *mockState {
   195  	state := getState(ctx)
   196  	if state == nil {
   197  		panic("execmock: No MockState: Use execmock.Init on context first.")
   198  	}
   199  	return state
   200  }
   201  
   202  // Init adds a mockState to the context, which indicates that this context
   203  // should mock new exec calls using it.
   204  //
   205  // If `testing.Verbose()` is true (i.e. `go test -v`), this turns on "chatty
   206  // mode", which will emit a log line for every exec this library intercepts,
   207  // and will also copy and dump any stdout/stderr. This allows easier debugging
   208  // of your RunnerFunctions.
   209  //
   210  // Panics if `ctx` has already been initialized for execmock.
   211  func Init(ctx context.Context) context.Context {
   212  	if getState(ctx) != nil {
   213  		panic(errors.New("execmock.Init: called twice on the same context"))
   214  	}
   215  	return context.WithValue(ctx, &stateCtxKey, &mockState{
   216  		chatty: testing.Verbose(),
   217  	})
   218  }