github.com/gophish/gophish@v0.12.2-0.20230915144530-8e7929441393/mailer/mailer_test.go (about)

     1  package mailer
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"io"
     8  	"net/textproto"
     9  	"reflect"
    10  	"testing"
    11  )
    12  
    13  func generateMessages(dialer Dialer) []Mail {
    14  	to := []string{"to@example.com"}
    15  
    16  	messageContents := []io.WriterTo{
    17  		bytes.NewBuffer([]byte("First email")),
    18  		bytes.NewBuffer([]byte("Second email")),
    19  	}
    20  
    21  	m1 := newMockMessage("first@example.com", to, messageContents[0])
    22  	m2 := newMockMessage("second@example.com", to, messageContents[1])
    23  
    24  	m1.setDialer(func() (Dialer, error) { return dialer, nil })
    25  
    26  	messages := []Mail{m1, m2}
    27  	return messages
    28  }
    29  
    30  func newMockErrorSender(err error) *mockSender {
    31  	sender := newMockSender()
    32  	// The sending function will send a temporary error to emulate
    33  	// a backoff.
    34  	sender.setSend(func(mm *mockMessage) error {
    35  		if len(sender.messages) == 1 {
    36  			return err
    37  		}
    38  		sender.messageChan <- mm
    39  		return nil
    40  	})
    41  	return sender
    42  }
    43  
    44  func TestDialHost(t *testing.T) {
    45  	ctx, cancel := context.WithCancel(context.Background())
    46  	defer cancel()
    47  	md := newMockDialer()
    48  	md.setDial(md.unreachableDial)
    49  	_, err := dialHost(ctx, md)
    50  	if _, ok := err.(*ErrMaxConnectAttempts); !ok {
    51  		t.Fatalf("Didn't receive expected ErrMaxConnectAttempts. Got: %s", err)
    52  	}
    53  	e := err.(*ErrMaxConnectAttempts)
    54  	if e.underlyingError != errHostUnreachable {
    55  		t.Fatalf("Got invalid underlying error. Expected %s Got %s\n", e.underlyingError, errHostUnreachable)
    56  	}
    57  	if md.dialCount != MaxReconnectAttempts {
    58  		t.Fatalf("Unexpected number of reconnect attempts. Expected %d, Got %d", MaxReconnectAttempts, md.dialCount)
    59  	}
    60  	md.setDial(md.defaultDial)
    61  	_, err = dialHost(ctx, md)
    62  	if err != nil {
    63  		t.Fatalf("Unexpected error when dialing the mock host: %s", err)
    64  	}
    65  }
    66  
    67  func TestMailWorkerStart(t *testing.T) {
    68  	ctx, cancel := context.WithCancel(context.Background())
    69  	defer cancel()
    70  
    71  	mw := NewMailWorker()
    72  	go func(ctx context.Context) {
    73  		mw.Start(ctx)
    74  	}(ctx)
    75  
    76  	sender := newMockSender()
    77  	dialer := newMockDialer()
    78  	dialer.setDial(func() (Sender, error) {
    79  		return sender, nil
    80  	})
    81  
    82  	messages := generateMessages(dialer)
    83  
    84  	// Send the campaign
    85  	mw.Queue(messages)
    86  
    87  	got := []*mockMessage{}
    88  
    89  	idx := 0
    90  	for message := range sender.messageChan {
    91  		got = append(got, message)
    92  		original := messages[idx].(*mockMessage)
    93  		if original.from != message.from {
    94  			t.Fatalf("Invalid message received. Expected %s, Got %s", original.from, message.from)
    95  		}
    96  		idx++
    97  	}
    98  	if len(got) != len(messages) {
    99  		t.Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), len(messages))
   100  	}
   101  }
   102  
   103  func TestBackoff(t *testing.T) {
   104  	ctx, cancel := context.WithCancel(context.Background())
   105  	defer cancel()
   106  
   107  	mw := NewMailWorker()
   108  	go func(ctx context.Context) {
   109  		mw.Start(ctx)
   110  	}(ctx)
   111  
   112  	expectedError := &textproto.Error{
   113  		Code: 400,
   114  		Msg:  "Temporary error",
   115  	}
   116  
   117  	sender := newMockErrorSender(expectedError)
   118  	dialer := newMockDialer()
   119  	dialer.setDial(func() (Sender, error) {
   120  		return sender, nil
   121  	})
   122  
   123  	messages := generateMessages(dialer)
   124  
   125  	// Send the campaign
   126  	mw.Queue(messages)
   127  
   128  	got := []*mockMessage{}
   129  
   130  	for message := range sender.messageChan {
   131  		got = append(got, message)
   132  	}
   133  	// Check that we only sent one message
   134  	expectedCount := 1
   135  	if len(got) != expectedCount {
   136  		t.Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount)
   137  	}
   138  
   139  	// Check that it's the correct message
   140  	originalFrom := messages[1].(*mockMessage).from
   141  	if got[0].from != originalFrom {
   142  		t.Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from)
   143  	}
   144  
   145  	// Check that the first message performed a backoff
   146  	backoffCount := messages[0].(*mockMessage).backoffCount
   147  	if backoffCount != expectedCount {
   148  		t.Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedCount)
   149  	}
   150  
   151  	// Check that there was a reset performed on the sender
   152  	if sender.resetCount != expectedCount {
   153  		t.Fatalf("Did not receive expected reset. Got resetCount %d, expected %d", sender.resetCount, expectedCount)
   154  	}
   155  }
   156  
   157  func TestPermError(t *testing.T) {
   158  	ctx, cancel := context.WithCancel(context.Background())
   159  	defer cancel()
   160  
   161  	mw := NewMailWorker()
   162  	go func(ctx context.Context) {
   163  		mw.Start(ctx)
   164  	}(ctx)
   165  
   166  	expectedError := &textproto.Error{
   167  		Code: 500,
   168  		Msg:  "Permanent error",
   169  	}
   170  
   171  	sender := newMockErrorSender(expectedError)
   172  	dialer := newMockDialer()
   173  	dialer.setDial(func() (Sender, error) {
   174  		return sender, nil
   175  	})
   176  
   177  	messages := generateMessages(dialer)
   178  
   179  	// Send the campaign
   180  	mw.Queue(messages)
   181  
   182  	got := []*mockMessage{}
   183  
   184  	for message := range sender.messageChan {
   185  		got = append(got, message)
   186  	}
   187  	// Check that we only sent one message
   188  	expectedCount := 1
   189  	if len(got) != expectedCount {
   190  		t.Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount)
   191  	}
   192  
   193  	// Check that it's the correct message
   194  	originalFrom := messages[1].(*mockMessage).from
   195  	if got[0].from != originalFrom {
   196  		t.Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from)
   197  	}
   198  
   199  	message := messages[0].(*mockMessage)
   200  
   201  	// Check that the first message did not perform a backoff
   202  	expectedBackoffCount := 0
   203  	backoffCount := message.backoffCount
   204  	if backoffCount != expectedBackoffCount {
   205  		t.Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedCount)
   206  	}
   207  
   208  	// Check that there was a reset performed on the sender
   209  	if sender.resetCount != expectedCount {
   210  		t.Fatalf("Did not receive expected reset. Got resetCount %d, expected %d", sender.resetCount, expectedCount)
   211  	}
   212  
   213  	// Check that the email errored out appropriately
   214  	if !reflect.DeepEqual(message.err, expectedError) {
   215  		t.Fatalf("Did not received expected error. Got %#v\nExpected %#v", message.err, expectedError)
   216  	}
   217  }
   218  
   219  func TestUnknownError(t *testing.T) {
   220  	ctx, cancel := context.WithCancel(context.Background())
   221  	defer cancel()
   222  
   223  	mw := NewMailWorker()
   224  	go func(ctx context.Context) {
   225  		mw.Start(ctx)
   226  	}(ctx)
   227  
   228  	expectedError := errors.New("Unexpected error")
   229  
   230  	sender := newMockErrorSender(expectedError)
   231  	dialer := newMockDialer()
   232  	dialer.setDial(func() (Sender, error) {
   233  		return sender, nil
   234  	})
   235  
   236  	messages := generateMessages(dialer)
   237  
   238  	// Send the campaign
   239  	mw.Queue(messages)
   240  
   241  	got := []*mockMessage{}
   242  
   243  	for message := range sender.messageChan {
   244  		got = append(got, message)
   245  	}
   246  	// Check that we only sent one message
   247  	expectedCount := 1
   248  	if len(got) != expectedCount {
   249  		t.Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount)
   250  	}
   251  
   252  	// Check that it's the correct message
   253  	originalFrom := messages[1].(*mockMessage).from
   254  	if got[0].from != originalFrom {
   255  		t.Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from)
   256  	}
   257  
   258  	message := messages[0].(*mockMessage)
   259  
   260  	// If we get an unexpected error, this means that it's likely the
   261  	// underlying connection dropped. When this happens, we expect the
   262  	// connection to be re-established (see #997).
   263  	// In this case, we're successfully reestablishing the connection
   264  	// so we expect the backoff to occur.
   265  	expectedBackoffCount := 1
   266  	backoffCount := message.backoffCount
   267  	if backoffCount != expectedBackoffCount {
   268  		t.Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedBackoffCount)
   269  	}
   270  
   271  	// Check that the underlying connection was reestablished
   272  	expectedDialCount := 2
   273  	if dialer.dialCount != expectedDialCount {
   274  		t.Fatalf("Did not receive expected dial count. Got %d expected %d", dialer.dialCount, expectedDialCount)
   275  	}
   276  
   277  	// Check that the email errored out appropriately
   278  	if !reflect.DeepEqual(message.err, expectedError) {
   279  		t.Fatalf("Did not received expected error. Got %#v\nExpected %#v", message.err, expectedError)
   280  	}
   281  }