github.com/dzsibi/gophish@v0.7.1-0.20190719042945-1f16c7237d0d/mailer/mailer_test.go (about)

     1  package mailer
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"io"
     8  	"net/textproto"
     9  	"reflect"
    10  	"testing"
    11  
    12  	"github.com/stretchr/testify/suite"
    13  )
    14  
    15  type MailerSuite struct {
    16  	suite.Suite
    17  }
    18  
    19  func generateMessages(dialer Dialer) []Mail {
    20  	to := []string{"to@example.com"}
    21  
    22  	messageContents := []io.WriterTo{
    23  		bytes.NewBuffer([]byte("First email")),
    24  		bytes.NewBuffer([]byte("Second email")),
    25  	}
    26  
    27  	m1 := newMockMessage("first@example.com", to, messageContents[0])
    28  	m2 := newMockMessage("second@example.com", to, messageContents[1])
    29  
    30  	m1.setDialer(func() (Dialer, error) { return dialer, nil })
    31  
    32  	messages := []Mail{m1, m2}
    33  	return messages
    34  }
    35  
    36  func newMockErrorSender(err error) *mockSender {
    37  	sender := newMockSender()
    38  	// The sending function will send a temporary error to emulate
    39  	// a backoff.
    40  	sender.setSend(func(mm *mockMessage) error {
    41  		if len(sender.messages) == 1 {
    42  			return err
    43  		}
    44  		sender.messageChan <- mm
    45  		return nil
    46  	})
    47  	return sender
    48  }
    49  
    50  func (ms *MailerSuite) TestDialHost() {
    51  	ctx, cancel := context.WithCancel(context.Background())
    52  	defer cancel()
    53  	md := newMockDialer()
    54  	md.setDial(md.unreachableDial)
    55  	_, err := dialHost(ctx, md)
    56  	if _, ok := err.(*ErrMaxConnectAttempts); !ok {
    57  		ms.T().Fatalf("Didn't receive expected ErrMaxConnectAttempts. Got: %s", err)
    58  	}
    59  	e := err.(*ErrMaxConnectAttempts)
    60  	if e.underlyingError != errHostUnreachable {
    61  		ms.T().Fatalf("Got invalid underlying error. Expected %s Got %s\n", e.underlyingError, errHostUnreachable)
    62  	}
    63  	if md.dialCount != MaxReconnectAttempts {
    64  		ms.T().Fatalf("Unexpected number of reconnect attempts. Expected %d, Got %d", MaxReconnectAttempts, md.dialCount)
    65  	}
    66  	md.setDial(md.defaultDial)
    67  	_, err = dialHost(ctx, md)
    68  	if err != nil {
    69  		ms.T().Fatalf("Unexpected error when dialing the mock host: %s", err)
    70  	}
    71  }
    72  
    73  func (ms *MailerSuite) TestMailWorkerStart() {
    74  	ctx, cancel := context.WithCancel(context.Background())
    75  	defer cancel()
    76  
    77  	mw := NewMailWorker()
    78  	go func(ctx context.Context) {
    79  		mw.Start(ctx)
    80  	}(ctx)
    81  
    82  	sender := newMockSender()
    83  	dialer := newMockDialer()
    84  	dialer.setDial(func() (Sender, error) {
    85  		return sender, nil
    86  	})
    87  
    88  	messages := generateMessages(dialer)
    89  
    90  	// Send the campaign
    91  	mw.Queue(messages)
    92  
    93  	got := []*mockMessage{}
    94  
    95  	idx := 0
    96  	for message := range sender.messageChan {
    97  		got = append(got, message)
    98  		original := messages[idx].(*mockMessage)
    99  		if original.from != message.from {
   100  			ms.T().Fatalf("Invalid message received. Expected %s, Got %s", original.from, message.from)
   101  		}
   102  		idx++
   103  	}
   104  	if len(got) != len(messages) {
   105  		ms.T().Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), len(messages))
   106  	}
   107  }
   108  
   109  func (ms *MailerSuite) TestBackoff() {
   110  	ctx, cancel := context.WithCancel(context.Background())
   111  	defer cancel()
   112  
   113  	mw := NewMailWorker()
   114  	go func(ctx context.Context) {
   115  		mw.Start(ctx)
   116  	}(ctx)
   117  
   118  	expectedError := &textproto.Error{
   119  		Code: 400,
   120  		Msg:  "Temporary error",
   121  	}
   122  
   123  	sender := newMockErrorSender(expectedError)
   124  	dialer := newMockDialer()
   125  	dialer.setDial(func() (Sender, error) {
   126  		return sender, nil
   127  	})
   128  
   129  	messages := generateMessages(dialer)
   130  
   131  	// Send the campaign
   132  	mw.Queue(messages)
   133  
   134  	got := []*mockMessage{}
   135  
   136  	for message := range sender.messageChan {
   137  		got = append(got, message)
   138  	}
   139  	// Check that we only sent one message
   140  	expectedCount := 1
   141  	if len(got) != expectedCount {
   142  		ms.T().Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount)
   143  	}
   144  
   145  	// Check that it's the correct message
   146  	originalFrom := messages[1].(*mockMessage).from
   147  	if got[0].from != originalFrom {
   148  		ms.T().Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from)
   149  	}
   150  
   151  	// Check that the first message performed a backoff
   152  	backoffCount := messages[0].(*mockMessage).backoffCount
   153  	if backoffCount != expectedCount {
   154  		ms.T().Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedCount)
   155  	}
   156  
   157  	// Check that there was a reset performed on the sender
   158  	if sender.resetCount != expectedCount {
   159  		ms.T().Fatalf("Did not receive expected reset. Got resetCount %d, expected %d", sender.resetCount, expectedCount)
   160  	}
   161  }
   162  
   163  func (ms *MailerSuite) TestPermError() {
   164  	ctx, cancel := context.WithCancel(context.Background())
   165  	defer cancel()
   166  
   167  	mw := NewMailWorker()
   168  	go func(ctx context.Context) {
   169  		mw.Start(ctx)
   170  	}(ctx)
   171  
   172  	expectedError := &textproto.Error{
   173  		Code: 500,
   174  		Msg:  "Permanent error",
   175  	}
   176  
   177  	sender := newMockErrorSender(expectedError)
   178  	dialer := newMockDialer()
   179  	dialer.setDial(func() (Sender, error) {
   180  		return sender, nil
   181  	})
   182  
   183  	messages := generateMessages(dialer)
   184  
   185  	// Send the campaign
   186  	mw.Queue(messages)
   187  
   188  	got := []*mockMessage{}
   189  
   190  	for message := range sender.messageChan {
   191  		got = append(got, message)
   192  	}
   193  	// Check that we only sent one message
   194  	expectedCount := 1
   195  	if len(got) != expectedCount {
   196  		ms.T().Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount)
   197  	}
   198  
   199  	// Check that it's the correct message
   200  	originalFrom := messages[1].(*mockMessage).from
   201  	if got[0].from != originalFrom {
   202  		ms.T().Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from)
   203  	}
   204  
   205  	message := messages[0].(*mockMessage)
   206  
   207  	// Check that the first message did not perform a backoff
   208  	expectedBackoffCount := 0
   209  	backoffCount := message.backoffCount
   210  	if backoffCount != expectedBackoffCount {
   211  		ms.T().Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedCount)
   212  	}
   213  
   214  	// Check that there was a reset performed on the sender
   215  	if sender.resetCount != expectedCount {
   216  		ms.T().Fatalf("Did not receive expected reset. Got resetCount %d, expected %d", sender.resetCount, expectedCount)
   217  	}
   218  
   219  	// Check that the email errored out appropriately
   220  	if !reflect.DeepEqual(message.err, expectedError) {
   221  		ms.T().Fatalf("Did not received expected error. Got %#v\nExpected %#v", message.err, expectedError)
   222  	}
   223  }
   224  
   225  func (ms *MailerSuite) TestUnknownError() {
   226  	ctx, cancel := context.WithCancel(context.Background())
   227  	defer cancel()
   228  
   229  	mw := NewMailWorker()
   230  	go func(ctx context.Context) {
   231  		mw.Start(ctx)
   232  	}(ctx)
   233  
   234  	expectedError := errors.New("Unexpected error")
   235  
   236  	sender := newMockErrorSender(expectedError)
   237  	dialer := newMockDialer()
   238  	dialer.setDial(func() (Sender, error) {
   239  		return sender, nil
   240  	})
   241  
   242  	messages := generateMessages(dialer)
   243  
   244  	// Send the campaign
   245  	mw.Queue(messages)
   246  
   247  	got := []*mockMessage{}
   248  
   249  	for message := range sender.messageChan {
   250  		got = append(got, message)
   251  	}
   252  	// Check that we only sent one message
   253  	expectedCount := 1
   254  	if len(got) != expectedCount {
   255  		ms.T().Fatalf("Unexpected number of messages received. Expected %d Got %d", len(got), expectedCount)
   256  	}
   257  
   258  	// Check that it's the correct message
   259  	originalFrom := messages[1].(*mockMessage).from
   260  	if got[0].from != originalFrom {
   261  		ms.T().Fatalf("Invalid message received. Expected %s, Got %s", originalFrom, got[0].from)
   262  	}
   263  
   264  	message := messages[0].(*mockMessage)
   265  
   266  	// If we get an unexpected error, this means that it's likely the
   267  	// underlying connection dropped. When this happens, we expect the
   268  	// connection to be re-established (see #997).
   269  	// In this case, we're successfully reestablishing the connection
   270  	// so we expect the backoff to occur.
   271  	expectedBackoffCount := 1
   272  	backoffCount := message.backoffCount
   273  	if backoffCount != expectedBackoffCount {
   274  		ms.T().Fatalf("Did not receive expected backoff. Got backoffCount %d, Expected %d", backoffCount, expectedBackoffCount)
   275  	}
   276  
   277  	// Check that the underlying connection was reestablished
   278  	expectedDialCount := 2
   279  	if dialer.dialCount != expectedDialCount {
   280  		ms.T().Fatalf("Did not receive expected dial count. Got %d expected %d", dialer.dialCount, expectedDialCount)
   281  	}
   282  
   283  	// Check that the email errored out appropriately
   284  	if !reflect.DeepEqual(message.err, expectedError) {
   285  		ms.T().Fatalf("Did not received expected error. Got %#v\nExpected %#v", message.err, expectedError)
   286  	}
   287  }
   288  
   289  func TestMailerSuite(t *testing.T) {
   290  	suite.Run(t, new(MailerSuite))
   291  }