github.com/nats-io/nats-server/v2@v2.11.0-preview.2/internal/testhelper/logging.go (about)

     1  // Copyright 2019-2021 The NATS Authors
     2  // Licensed under the Apache License, Version 2.0 (the "License");
     3  // you may not use this file except in compliance with the License.
     4  // You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package testhelper
    15  
    16  // These routines need to be accessible in both the server and test
    17  // directories, and tests importing a package don't get exported symbols from
    18  // _test.go files in the imported package, so we put them here where they can
    19  // be used freely.
    20  
    21  import (
    22  	"fmt"
    23  	"strings"
    24  	"sync"
    25  	"testing"
    26  )
    27  
    28  type DummyLogger struct {
    29  	sync.Mutex
    30  	Msg     string
    31  	AllMsgs []string
    32  }
    33  
    34  func (l *DummyLogger) CheckContent(t *testing.T, expectedStr string) {
    35  	t.Helper()
    36  	l.Lock()
    37  	defer l.Unlock()
    38  	if l.Msg != expectedStr {
    39  		t.Fatalf("Expected log to be: %v, got %v", expectedStr, l.Msg)
    40  	}
    41  }
    42  
    43  func (l *DummyLogger) aggregate() {
    44  	if l.AllMsgs != nil {
    45  		l.AllMsgs = append(l.AllMsgs, l.Msg)
    46  	}
    47  }
    48  
    49  func (l *DummyLogger) Noticef(format string, v ...any) {
    50  	l.Lock()
    51  	defer l.Unlock()
    52  	l.Msg = fmt.Sprintf(format, v...)
    53  	l.aggregate()
    54  }
    55  func (l *DummyLogger) Errorf(format string, v ...any) {
    56  	l.Lock()
    57  	defer l.Unlock()
    58  	l.Msg = fmt.Sprintf(format, v...)
    59  	l.aggregate()
    60  }
    61  func (l *DummyLogger) Warnf(format string, v ...any) {
    62  	l.Lock()
    63  	defer l.Unlock()
    64  	l.Msg = fmt.Sprintf(format, v...)
    65  	l.aggregate()
    66  }
    67  func (l *DummyLogger) Fatalf(format string, v ...any) {
    68  	l.Lock()
    69  	defer l.Unlock()
    70  	l.Msg = fmt.Sprintf(format, v...)
    71  	l.aggregate()
    72  }
    73  func (l *DummyLogger) Debugf(format string, v ...any) {
    74  	l.Lock()
    75  	defer l.Unlock()
    76  	l.Msg = fmt.Sprintf(format, v...)
    77  	l.aggregate()
    78  }
    79  func (l *DummyLogger) Tracef(format string, v ...any) {
    80  	l.Lock()
    81  	defer l.Unlock()
    82  	l.Msg = fmt.Sprintf(format, v...)
    83  	l.aggregate()
    84  }
    85  
    86  // NewDummyLogger creates a dummy logger and allows to ask for logs to be
    87  // retained instead of just keeping the most recent. Use retain to provide an
    88  // initial size estimate on messages (not to provide a max capacity).
    89  func NewDummyLogger(retain uint) *DummyLogger {
    90  	l := &DummyLogger{}
    91  	if retain > 0 {
    92  		l.AllMsgs = make([]string, 0, retain)
    93  	}
    94  	return l
    95  }
    96  
    97  func (l *DummyLogger) Drain() {
    98  	l.Lock()
    99  	defer l.Unlock()
   100  	if l.AllMsgs == nil {
   101  		return
   102  	}
   103  	l.AllMsgs = make([]string, 0, len(l.AllMsgs))
   104  }
   105  
   106  func (l *DummyLogger) CheckForProhibited(t *testing.T, reason, needle string) {
   107  	t.Helper()
   108  	l.Lock()
   109  	defer l.Unlock()
   110  
   111  	if l.AllMsgs == nil {
   112  		t.Fatal("DummyLogger.CheckForProhibited called without AllMsgs being collected")
   113  	}
   114  
   115  	// Collect _all_ matches, rather than have to re-test repeatedly.
   116  	// This will particularly help with less deterministic tests with multiple matches.
   117  	shouldFail := false
   118  	for i := range l.AllMsgs {
   119  		if strings.Contains(l.AllMsgs[i], needle) {
   120  			t.Errorf("log contains %s: %v", reason, l.AllMsgs[i])
   121  			shouldFail = true
   122  		}
   123  	}
   124  	if shouldFail {
   125  		t.FailNow()
   126  	}
   127  }