github.com/gophish/gophish@v0.12.2-0.20230915144530-8e7929441393/mailer/mockmailer.go (about)

     1  package mailer
     2  
     3  import (
     4  	"bytes"
     5  	"errors"
     6  	"io"
     7  	"time"
     8  
     9  	"github.com/gophish/gomail"
    10  )
    11  
    12  // errHostUnreachable is a mock error to represent a host
    13  // being unreachable
    14  var errHostUnreachable = errors.New("host unreachable")
    15  
    16  // mockDialer keeps track of calls to Dial
    17  type mockDialer struct {
    18  	dialCount int
    19  	dial      func() (Sender, error)
    20  }
    21  
    22  // newMockDialer returns a new instance of the mockDialer with the default
    23  // dialer set.
    24  func newMockDialer() *mockDialer {
    25  	md := &mockDialer{}
    26  	md.dial = md.defaultDial
    27  	return md
    28  }
    29  
    30  // defaultDial simply returns a mockSender
    31  func (md *mockDialer) defaultDial() (Sender, error) {
    32  	return newMockSender(), nil
    33  }
    34  
    35  // unreachableDial is to simulate network error conditions in which
    36  // a host is unavailable.
    37  func (md *mockDialer) unreachableDial() (Sender, error) {
    38  	return nil, errHostUnreachable
    39  }
    40  
    41  // Dial increments the internal dial count. Otherwise, it's a no-op for the mock client.
    42  func (md *mockDialer) Dial() (Sender, error) {
    43  	md.dialCount++
    44  	return md.dial()
    45  }
    46  
    47  // setDial sets the Dial function for the mockDialer
    48  func (md *mockDialer) setDial(dial func() (Sender, error)) {
    49  	md.dial = dial
    50  }
    51  
    52  // mockSender is a mock gomail.Sender used for testing.
    53  type mockSender struct {
    54  	messages    []*mockMessage
    55  	status      string
    56  	send        func(*mockMessage) error
    57  	messageChan chan *mockMessage
    58  	resetCount  int
    59  }
    60  
    61  func newMockSender() *mockSender {
    62  	ms := &mockSender{
    63  		status:      "ehlo",
    64  		messageChan: make(chan *mockMessage),
    65  	}
    66  	ms.send = ms.defaultSend
    67  	return ms
    68  }
    69  
    70  func (ms *mockSender) setSend(send func(*mockMessage) error) {
    71  	ms.send = send
    72  }
    73  
    74  func (ms *mockSender) defaultSend(mm *mockMessage) error {
    75  	ms.messageChan <- mm
    76  	return nil
    77  }
    78  
    79  // Send just appends the provided message record to the internal slice
    80  func (ms *mockSender) Send(from string, to []string, msg io.WriterTo) error {
    81  	mm := newMockMessage(from, to, msg)
    82  	ms.messages = append(ms.messages, mm)
    83  	ms.status = "sent"
    84  	return ms.send(mm)
    85  }
    86  
    87  // Close is a noop for the mock client
    88  func (ms *mockSender) Close() error {
    89  	ms.status = "closed"
    90  	close(ms.messageChan)
    91  	return nil
    92  }
    93  
    94  // Reset sets the status to "Reset". In practice, this would reset the connection
    95  // to the same state as if the client had just sent an EHLO command.
    96  func (ms *mockSender) Reset() error {
    97  	ms.status = "reset"
    98  	ms.resetCount++
    99  	return nil
   100  }
   101  
   102  // mockMessage holds the information sent via a call to MockClient.Send()
   103  type mockMessage struct {
   104  	from         string
   105  	to           []string
   106  	message      []byte
   107  	sendAt       time.Time
   108  	backoffCount int
   109  	getdialer    func() (Dialer, error)
   110  	err          error
   111  	finished     bool
   112  }
   113  
   114  func newMockMessage(from string, to []string, msg io.WriterTo) *mockMessage {
   115  	buff := &bytes.Buffer{}
   116  	msg.WriteTo(buff)
   117  	mm := &mockMessage{
   118  		from:    from,
   119  		to:      to,
   120  		message: buff.Bytes(),
   121  		sendAt:  time.Now(),
   122  	}
   123  	mm.getdialer = mm.defaultDialer
   124  	return mm
   125  }
   126  
   127  func (mm *mockMessage) setDialer(dialer func() (Dialer, error)) {
   128  	mm.getdialer = dialer
   129  }
   130  
   131  func (mm *mockMessage) defaultDialer() (Dialer, error) {
   132  	return newMockDialer(), nil
   133  }
   134  
   135  func (mm *mockMessage) GetDialer() (Dialer, error) {
   136  	return mm.getdialer()
   137  }
   138  
   139  func (mm *mockMessage) Backoff(reason error) error {
   140  	mm.backoffCount++
   141  	mm.err = reason
   142  	return nil
   143  }
   144  
   145  func (mm *mockMessage) Error(err error) error {
   146  	mm.err = err
   147  	mm.finished = true
   148  	return nil
   149  }
   150  
   151  func (mm *mockMessage) Finish() error {
   152  	mm.finished = true
   153  	return nil
   154  }
   155  
   156  func (mm *mockMessage) Generate(message *gomail.Message) error {
   157  	message.SetHeaders(map[string][]string{
   158  		"From": {mm.from},
   159  		"To":   mm.to,
   160  	})
   161  	message.SetBody("text/html", string(mm.message))
   162  	return nil
   163  }
   164  
   165  func (mm *mockMessage) GetSmtpFrom() (string, error) {
   166  	return mm.from, nil
   167  }
   168  
   169  func (mm *mockMessage) Success() error {
   170  	mm.finished = true
   171  	return nil
   172  }