github.com/go-email-validator/go-email-validator@v0.0.0-20230409163946-b8b9e6a0552e/pkg/ev/evsmtp/proxy_test.go (about)

     1  package evsmtp
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"github.com/go-email-validator/go-email-validator/pkg/ev/evsmtp/smtpclient"
     7  	"github.com/go-email-validator/go-email-validator/pkg/ev/evtests"
     8  	mockevsmtp "github.com/go-email-validator/go-email-validator/test/mock/ev/evsmtp"
     9  	mocknet "github.com/go-email-validator/go-email-validator/test/mock/net"
    10  	"github.com/golang/mock/gomock"
    11  	"github.com/stretchr/testify/require"
    12  	"h12.io/socks"
    13  	"net"
    14  	"net/smtp"
    15  	"os"
    16  	"reflect"
    17  	"sync"
    18  	"syscall"
    19  	"testing"
    20  	"time"
    21  )
    22  
    23  const (
    24  	proxyURL = "socks5://username:password@127.0.0.1:1080"
    25  )
    26  
    27  var (
    28  	errMissingPort    = &net.AddrError{Err: "missing port in address", Addr: localhost}
    29  	ctxBackground     = context.Background()
    30  	ctxBackgroundFunc = func() context.Context { return ctxBackground }
    31  )
    32  
    33  func TestDirectDial(t *testing.T) {
    34  	type fields struct {
    35  		server []string
    36  	}
    37  	type args struct {
    38  		ctx      func() context.Context
    39  		addr     string
    40  		proxyURL string
    41  	}
    42  	tests := []struct {
    43  		name       string
    44  		fields     fields
    45  		args       args
    46  		wantClient bool
    47  		wantErr    error
    48  	}{
    49  		{
    50  			name: "success",
    51  			fields: fields{
    52  				server: []string{
    53  					"220 hello world",
    54  				},
    55  			},
    56  			args: args{
    57  				ctx:      ctxBackgroundFunc,
    58  				proxyURL: "",
    59  			},
    60  			wantClient: true,
    61  			wantErr:    nil,
    62  		},
    63  		{
    64  			name: "fail port",
    65  			args: args{
    66  				ctx:      ctxBackgroundFunc,
    67  				addr:     localhost,
    68  				proxyURL: "",
    69  			},
    70  			wantClient: false,
    71  			wantErr:    &net.OpError{Op: "dial", Net: "tcp", Err: errMissingPort},
    72  		},
    73  		{
    74  			name: "fail",
    75  			args: args{
    76  				ctx:      ctxBackgroundFunc,
    77  				addr:     localhost + ":25",
    78  				proxyURL: "",
    79  			},
    80  			wantClient: false,
    81  			wantErr: &net.OpError{
    82  				Op:  "dial",
    83  				Net: "tcp",
    84  				Addr: &net.TCPAddr{
    85  					IP:   net.IPv4(127, 0, 0, 1),
    86  					Port: 25,
    87  					Zone: "",
    88  				},
    89  				Err: &os.SyscallError{
    90  					Syscall: "connect",
    91  					Err:     syscall.ECONNREFUSED,
    92  				},
    93  			},
    94  		},
    95  		{
    96  			name: "expired timeout",
    97  			args: args{
    98  				ctx: func() context.Context {
    99  					ctx, _ := context.WithTimeout(ctxBackground, 0)
   100  					return ctx
   101  				},
   102  				addr: localhost + ":25",
   103  			},
   104  			wantClient: false,
   105  			wantErr: &net.OpError{
   106  				Op:  "dial",
   107  				Net: "tcp",
   108  				Addr: &net.TCPAddr{
   109  					IP:   net.IPv4(127, 0, 0, 1),
   110  					Port: 25,
   111  					Zone: "",
   112  				},
   113  				Err: errors.New("i/o timeout"),
   114  			},
   115  		},
   116  	}
   117  	for _, tt := range tests {
   118  		t.Run(tt.name, func(t *testing.T) {
   119  			var done chan string
   120  			addr := tt.args.addr
   121  			if len(tt.fields.server) > 0 {
   122  				addr, done = mockevsmtp.Server(t, tt.fields.server, time.Second, "", false)
   123  			}
   124  
   125  			gotClient, err := DirectDial(tt.args.ctx(), addr, tt.args.proxyURL)
   126  			if len(tt.fields.server) > 0 {
   127  				<-done
   128  				if gotClient != nil {
   129  					gotClient.Quit()
   130  				}
   131  			}
   132  
   133  			var errStr string
   134  			if errOp, ok := err.(*net.OpError); ok && errOp.Err != nil {
   135  				errStr = errOp.Err.Error()
   136  				errOp.Err = nil
   137  			}
   138  			var wantErrStr string
   139  			wantErrOp, ok := tt.wantErr.(*net.OpError)
   140  			if ok && wantErrOp.Err != nil {
   141  				wantErrStr = wantErrOp.Err.Error()
   142  				wantErrOp.Err = nil
   143  			}
   144  			if !reflect.DeepEqual(err, tt.wantErr) && errStr != wantErrStr {
   145  				t.Errorf("DirectDial() error = %v, wantErr %v", err, tt.wantErr)
   146  				return
   147  			}
   148  			if (gotClient == nil) == tt.wantClient {
   149  				t.Errorf("DirectDial() got = %v, want %v", gotClient, tt.wantClient)
   150  			}
   151  		})
   152  	}
   153  }
   154  
   155  func TestH12IODial(t *testing.T) {
   156  	evtests.FunctionalSkip(t)
   157  
   158  	ctrl := gomock.NewController(t)
   159  	defer ctrl.Finish()
   160  
   161  	defer func() {
   162  		smtpNewClient = smtp.NewClient
   163  		h12ioDial = socks.Dial
   164  	}()
   165  	var cancel context.CancelFunc
   166  	var wg sync.WaitGroup
   167  
   168  	type fields struct {
   169  		server        []string
   170  		dial          func(proxyURI string) func(string, string) (net.Conn, error)
   171  		smtpNewClient func(conn net.Conn, host string) (*smtp.Client, error)
   172  	}
   173  	type args struct {
   174  		ctx      func() context.Context
   175  		addr     string
   176  		proxyURL string
   177  	}
   178  	tests := []struct {
   179  		name       string
   180  		fields     fields
   181  		args       args
   182  		wantClient bool
   183  		wantErr    error
   184  	}{
   185  		{
   186  			name: "success",
   187  			fields: fields{
   188  				server: []string{
   189  					"220 hello world",
   190  				},
   191  				smtpNewClient: smtp.NewClient,
   192  			},
   193  			args: args{
   194  				ctx:      ctxBackgroundFunc,
   195  				proxyURL: proxyURL,
   196  			},
   197  			wantClient: true,
   198  			wantErr:    nil,
   199  		},
   200  		{
   201  			name: "faild proxy connection",
   202  			fields: fields{
   203  				smtpNewClient: smtp.NewClient,
   204  			},
   205  			args: args{
   206  				ctx:      ctxBackgroundFunc,
   207  				proxyURL: "asd",
   208  			},
   209  			wantClient: false,
   210  			wantErr:    errors.New("unknown SOCKS protocol "),
   211  		},
   212  		{
   213  			name: "expired timeout in connection",
   214  			fields: fields{
   215  				smtpNewClient: smtp.NewClient,
   216  			},
   217  			args: args{
   218  				ctx: func() context.Context {
   219  					ctx, _ := context.WithTimeout(ctxBackground, 0)
   220  					return ctx
   221  				},
   222  				proxyURL: "asd",
   223  			},
   224  			wantClient: false,
   225  			wantErr:    context.DeadlineExceeded,
   226  		},
   227  		{
   228  			name: "expired timeout smtp connection",
   229  			fields: fields{
   230  				dial: func(proxyURI string) func(string, string) (net.Conn, error) {
   231  					wg.Add(1)
   232  
   233  					return func(s string, s2 string) (net.Conn, error) {
   234  						mock := mocknet.NewMockConn(ctrl)
   235  						mock.EXPECT().Close().Do(func() {
   236  							wg.Done()
   237  						}).Times(1)
   238  
   239  						return mock, nil
   240  					}
   241  				},
   242  				smtpNewClient: func(conn net.Conn, host string) (*smtp.Client, error) {
   243  					cancel()
   244  					time.Sleep(1 * time.Millisecond)
   245  					return &smtp.Client{}, nil
   246  				},
   247  			},
   248  			args: args{
   249  				ctx: func() context.Context {
   250  					var ctx context.Context
   251  					ctx, cancel = context.WithTimeout(ctxBackground, 1*time.Second)
   252  					return ctx
   253  				},
   254  				addr:     localhost + ":25",
   255  				proxyURL: proxyURL,
   256  			},
   257  			wantClient: false,
   258  			wantErr:    context.Canceled,
   259  		},
   260  	}
   261  	for _, tt := range tests {
   262  		t.Run(tt.name, func(t *testing.T) {
   263  			smtpNewClient = tt.fields.smtpNewClient
   264  			if tt.fields.dial != nil {
   265  				h12ioDial = tt.fields.dial
   266  			}
   267  			var done chan string
   268  			addr := tt.args.addr
   269  			if len(tt.fields.server) > 0 {
   270  				addr, done = mockevsmtp.Server(t, tt.fields.server, 1*time.Second, "", false)
   271  				addr = localIP() + addr[4:]
   272  			}
   273  
   274  			gotClient, err := H12IODial(tt.args.ctx(), addr, tt.args.proxyURL)
   275  			if len(tt.fields.server) > 0 {
   276  				<-done
   277  				if gotClient != nil {
   278  					gotClient.Quit()
   279  				}
   280  			}
   281  			if !reflect.DeepEqual(err, tt.wantErr) {
   282  				t.Errorf("H12IODial() error = %v, wantErr %v", err, tt.wantErr)
   283  				return
   284  			}
   285  			if (gotClient == nil) == tt.wantClient {
   286  				t.Errorf("H12IODial() got = %v, want %v", gotClient, tt.wantClient)
   287  			}
   288  			wg.Wait()
   289  		})
   290  	}
   291  }
   292  
   293  func TestH12IODial_Direct(t *testing.T) {
   294  	wantAddr := localhost
   295  	wantProxyURL := ""
   296  	var wantErr error = nil
   297  	wantCtx := context.Background()
   298  	directDial = func(ctx context.Context, addr, proxyURL string) (smtpclient.SMTPClient, error) {
   299  		require.Equal(t, wantCtx, ctx)
   300  		require.Equal(t, wantAddr, addr)
   301  		require.Equal(t, wantProxyURL, proxyURL)
   302  
   303  		return nil, nil
   304  	}
   305  	got, err := H12IODial(wantCtx, wantAddr, wantProxyURL)
   306  	directDial = DirectDial
   307  
   308  	if !reflect.DeepEqual(err, wantErr) {
   309  		t.Errorf("H12IODial() error = %v, wantErr %v", err, wantErr)
   310  		return
   311  	}
   312  
   313  	if got != nil {
   314  		t.Errorf("H12IODial() should not be null")
   315  	}
   316  }