github.com/safing/portbase@v0.19.5/modules/modules_test.go (about)

     1  package modules
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"sync"
     7  	"testing"
     8  )
     9  
    10  var (
    11  	changeHistoryLock sync.Mutex
    12  	changeHistory     string
    13  )
    14  
    15  func registerTestModule(t *testing.T, name string, dependencies ...string) {
    16  	t.Helper()
    17  
    18  	Register(
    19  		name,
    20  		func() error {
    21  			t.Logf("prep %s\n", name)
    22  			return nil
    23  		},
    24  		func() error {
    25  			changeHistoryLock.Lock()
    26  			defer changeHistoryLock.Unlock()
    27  			t.Logf("start %s\n", name)
    28  			changeHistory = fmt.Sprintf("%s on:%s", changeHistory, name)
    29  			return nil
    30  		},
    31  		func() error {
    32  			changeHistoryLock.Lock()
    33  			defer changeHistoryLock.Unlock()
    34  			t.Logf("stop %s\n", name)
    35  			changeHistory = fmt.Sprintf("%s off:%s", changeHistory, name)
    36  			return nil
    37  		},
    38  		dependencies...,
    39  	)
    40  }
    41  
    42  func testFail() error {
    43  	return errors.New("test error")
    44  }
    45  
    46  func testCleanExit() error {
    47  	return ErrCleanExit
    48  }
    49  
    50  func TestModules(t *testing.T) { //nolint:tparallel // Too much interference expected.
    51  	t.Parallel() // Not really, just a workaround for running these tests last.
    52  
    53  	t.Run("TestModuleOrder", testModuleOrder)   //nolint:paralleltest // Too much interference expected.
    54  	t.Run("TestModuleMgmt", testModuleMgmt)     //nolint:paralleltest // Too much interference expected.
    55  	t.Run("TestModuleErrors", testModuleErrors) //nolint:paralleltest // Too much interference expected.
    56  }
    57  
    58  func testModuleOrder(t *testing.T) {
    59  	registerTestModule(t, "database")
    60  	registerTestModule(t, "stats", "database")
    61  	registerTestModule(t, "service", "database")
    62  	registerTestModule(t, "analytics", "stats", "database")
    63  
    64  	err := Start()
    65  	if err != nil {
    66  		t.Error(err)
    67  	}
    68  
    69  	if changeHistory != " on:database on:service on:stats on:analytics" &&
    70  		changeHistory != " on:database on:stats on:service on:analytics" &&
    71  		changeHistory != " on:database on:stats on:analytics on:service" {
    72  		t.Errorf("start order mismatch, was %s", changeHistory)
    73  	}
    74  	changeHistory = ""
    75  
    76  	err = Shutdown()
    77  	if err != nil {
    78  		t.Error(err)
    79  	}
    80  
    81  	if changeHistory != " off:analytics off:service off:stats off:database" &&
    82  		changeHistory != " off:analytics off:stats off:service off:database" &&
    83  		changeHistory != " off:service off:analytics off:stats off:database" {
    84  		t.Errorf("shutdown order mismatch, was %s", changeHistory)
    85  	}
    86  	changeHistory = ""
    87  
    88  	resetTestEnvironment()
    89  }
    90  
    91  func testModuleErrors(t *testing.T) {
    92  	// test prep error
    93  	Register("prepfail", testFail, nil, nil)
    94  	err := Start()
    95  	if err == nil {
    96  		t.Error("should fail")
    97  	}
    98  
    99  	resetTestEnvironment()
   100  
   101  	// test prep clean exit
   102  	Register("prepcleanexit", testCleanExit, nil, nil)
   103  	err = Start()
   104  	if !errors.Is(err, ErrCleanExit) {
   105  		t.Error("should fail with clean exit")
   106  	}
   107  
   108  	resetTestEnvironment()
   109  
   110  	// test invalid dependency
   111  	Register("database", nil, nil, nil, "invalid")
   112  	err = Start()
   113  	if err == nil {
   114  		t.Error("should fail")
   115  	}
   116  
   117  	resetTestEnvironment()
   118  
   119  	// test dependency loop
   120  	registerTestModule(t, "database", "helper")
   121  	registerTestModule(t, "helper", "database")
   122  	err = Start()
   123  	if err == nil {
   124  		t.Error("should fail")
   125  	}
   126  
   127  	resetTestEnvironment()
   128  
   129  	// test failing module start
   130  	Register("startfail", nil, testFail, nil)
   131  	err = Start()
   132  	if err == nil {
   133  		t.Error("should fail")
   134  	}
   135  
   136  	resetTestEnvironment()
   137  
   138  	// test failing module stop
   139  	Register("stopfail", nil, nil, testFail)
   140  	err = Start()
   141  	if err != nil {
   142  		t.Error("should not fail")
   143  	}
   144  	err = Shutdown()
   145  	if err == nil {
   146  		t.Error("should fail")
   147  	}
   148  
   149  	resetTestEnvironment()
   150  
   151  	// test help flag
   152  	HelpFlag = true
   153  	err = Start()
   154  	if err == nil {
   155  		t.Error("should fail")
   156  	}
   157  	HelpFlag = false
   158  
   159  	resetTestEnvironment()
   160  }
   161  
   162  func printModules() { //nolint:unused,deadcode
   163  	fmt.Printf("All %d modules:\n", len(modules))
   164  	for _, m := range modules {
   165  		fmt.Printf("module %s: %+v\n", m.Name, m)
   166  	}
   167  }
   168  
   169  func resetTestEnvironment() {
   170  	modules = make(map[string]*Module)
   171  	shutdownSignal = make(chan struct{})
   172  	shutdownCompleteSignal = make(chan struct{})
   173  	shutdownFlag.UnSet()
   174  	modulesLocked.UnSet()
   175  }