github.com/go-email-validator/go-email-validator@v0.0.0-20230409163946-b8b9e6a0552e/pkg/ev/evsmtp/proxy_test.go (about) 1 package evsmtp 2 3 import ( 4 "context" 5 "errors" 6 "github.com/go-email-validator/go-email-validator/pkg/ev/evsmtp/smtpclient" 7 "github.com/go-email-validator/go-email-validator/pkg/ev/evtests" 8 mockevsmtp "github.com/go-email-validator/go-email-validator/test/mock/ev/evsmtp" 9 mocknet "github.com/go-email-validator/go-email-validator/test/mock/net" 10 "github.com/golang/mock/gomock" 11 "github.com/stretchr/testify/require" 12 "h12.io/socks" 13 "net" 14 "net/smtp" 15 "os" 16 "reflect" 17 "sync" 18 "syscall" 19 "testing" 20 "time" 21 ) 22 23 const ( 24 proxyURL = "socks5://username:password@127.0.0.1:1080" 25 ) 26 27 var ( 28 errMissingPort = &net.AddrError{Err: "missing port in address", Addr: localhost} 29 ctxBackground = context.Background() 30 ctxBackgroundFunc = func() context.Context { return ctxBackground } 31 ) 32 33 func TestDirectDial(t *testing.T) { 34 type fields struct { 35 server []string 36 } 37 type args struct { 38 ctx func() context.Context 39 addr string 40 proxyURL string 41 } 42 tests := []struct { 43 name string 44 fields fields 45 args args 46 wantClient bool 47 wantErr error 48 }{ 49 { 50 name: "success", 51 fields: fields{ 52 server: []string{ 53 "220 hello world", 54 }, 55 }, 56 args: args{ 57 ctx: ctxBackgroundFunc, 58 proxyURL: "", 59 }, 60 wantClient: true, 61 wantErr: nil, 62 }, 63 { 64 name: "fail port", 65 args: args{ 66 ctx: ctxBackgroundFunc, 67 addr: localhost, 68 proxyURL: "", 69 }, 70 wantClient: false, 71 wantErr: &net.OpError{Op: "dial", Net: "tcp", Err: errMissingPort}, 72 }, 73 { 74 name: "fail", 75 args: args{ 76 ctx: ctxBackgroundFunc, 77 addr: localhost + ":25", 78 proxyURL: "", 79 }, 80 wantClient: false, 81 wantErr: &net.OpError{ 82 Op: "dial", 83 Net: "tcp", 84 Addr: &net.TCPAddr{ 85 IP: net.IPv4(127, 0, 0, 1), 86 Port: 25, 87 Zone: "", 88 }, 89 Err: &os.SyscallError{ 90 Syscall: "connect", 91 Err: syscall.ECONNREFUSED, 92 }, 93 }, 94 }, 95 { 96 name: "expired timeout", 97 args: args{ 98 ctx: func() context.Context { 99 ctx, _ := context.WithTimeout(ctxBackground, 0) 100 return ctx 101 }, 102 addr: localhost + ":25", 103 }, 104 wantClient: false, 105 wantErr: &net.OpError{ 106 Op: "dial", 107 Net: "tcp", 108 Addr: &net.TCPAddr{ 109 IP: net.IPv4(127, 0, 0, 1), 110 Port: 25, 111 Zone: "", 112 }, 113 Err: errors.New("i/o timeout"), 114 }, 115 }, 116 } 117 for _, tt := range tests { 118 t.Run(tt.name, func(t *testing.T) { 119 var done chan string 120 addr := tt.args.addr 121 if len(tt.fields.server) > 0 { 122 addr, done = mockevsmtp.Server(t, tt.fields.server, time.Second, "", false) 123 } 124 125 gotClient, err := DirectDial(tt.args.ctx(), addr, tt.args.proxyURL) 126 if len(tt.fields.server) > 0 { 127 <-done 128 if gotClient != nil { 129 gotClient.Quit() 130 } 131 } 132 133 var errStr string 134 if errOp, ok := err.(*net.OpError); ok && errOp.Err != nil { 135 errStr = errOp.Err.Error() 136 errOp.Err = nil 137 } 138 var wantErrStr string 139 wantErrOp, ok := tt.wantErr.(*net.OpError) 140 if ok && wantErrOp.Err != nil { 141 wantErrStr = wantErrOp.Err.Error() 142 wantErrOp.Err = nil 143 } 144 if !reflect.DeepEqual(err, tt.wantErr) && errStr != wantErrStr { 145 t.Errorf("DirectDial() error = %v, wantErr %v", err, tt.wantErr) 146 return 147 } 148 if (gotClient == nil) == tt.wantClient { 149 t.Errorf("DirectDial() got = %v, want %v", gotClient, tt.wantClient) 150 } 151 }) 152 } 153 } 154 155 func TestH12IODial(t *testing.T) { 156 evtests.FunctionalSkip(t) 157 158 ctrl := gomock.NewController(t) 159 defer ctrl.Finish() 160 161 defer func() { 162 smtpNewClient = smtp.NewClient 163 h12ioDial = socks.Dial 164 }() 165 var cancel context.CancelFunc 166 var wg sync.WaitGroup 167 168 type fields struct { 169 server []string 170 dial func(proxyURI string) func(string, string) (net.Conn, error) 171 smtpNewClient func(conn net.Conn, host string) (*smtp.Client, error) 172 } 173 type args struct { 174 ctx func() context.Context 175 addr string 176 proxyURL string 177 } 178 tests := []struct { 179 name string 180 fields fields 181 args args 182 wantClient bool 183 wantErr error 184 }{ 185 { 186 name: "success", 187 fields: fields{ 188 server: []string{ 189 "220 hello world", 190 }, 191 smtpNewClient: smtp.NewClient, 192 }, 193 args: args{ 194 ctx: ctxBackgroundFunc, 195 proxyURL: proxyURL, 196 }, 197 wantClient: true, 198 wantErr: nil, 199 }, 200 { 201 name: "faild proxy connection", 202 fields: fields{ 203 smtpNewClient: smtp.NewClient, 204 }, 205 args: args{ 206 ctx: ctxBackgroundFunc, 207 proxyURL: "asd", 208 }, 209 wantClient: false, 210 wantErr: errors.New("unknown SOCKS protocol "), 211 }, 212 { 213 name: "expired timeout in connection", 214 fields: fields{ 215 smtpNewClient: smtp.NewClient, 216 }, 217 args: args{ 218 ctx: func() context.Context { 219 ctx, _ := context.WithTimeout(ctxBackground, 0) 220 return ctx 221 }, 222 proxyURL: "asd", 223 }, 224 wantClient: false, 225 wantErr: context.DeadlineExceeded, 226 }, 227 { 228 name: "expired timeout smtp connection", 229 fields: fields{ 230 dial: func(proxyURI string) func(string, string) (net.Conn, error) { 231 wg.Add(1) 232 233 return func(s string, s2 string) (net.Conn, error) { 234 mock := mocknet.NewMockConn(ctrl) 235 mock.EXPECT().Close().Do(func() { 236 wg.Done() 237 }).Times(1) 238 239 return mock, nil 240 } 241 }, 242 smtpNewClient: func(conn net.Conn, host string) (*smtp.Client, error) { 243 cancel() 244 time.Sleep(1 * time.Millisecond) 245 return &smtp.Client{}, nil 246 }, 247 }, 248 args: args{ 249 ctx: func() context.Context { 250 var ctx context.Context 251 ctx, cancel = context.WithTimeout(ctxBackground, 1*time.Second) 252 return ctx 253 }, 254 addr: localhost + ":25", 255 proxyURL: proxyURL, 256 }, 257 wantClient: false, 258 wantErr: context.Canceled, 259 }, 260 } 261 for _, tt := range tests { 262 t.Run(tt.name, func(t *testing.T) { 263 smtpNewClient = tt.fields.smtpNewClient 264 if tt.fields.dial != nil { 265 h12ioDial = tt.fields.dial 266 } 267 var done chan string 268 addr := tt.args.addr 269 if len(tt.fields.server) > 0 { 270 addr, done = mockevsmtp.Server(t, tt.fields.server, 1*time.Second, "", false) 271 addr = localIP() + addr[4:] 272 } 273 274 gotClient, err := H12IODial(tt.args.ctx(), addr, tt.args.proxyURL) 275 if len(tt.fields.server) > 0 { 276 <-done 277 if gotClient != nil { 278 gotClient.Quit() 279 } 280 } 281 if !reflect.DeepEqual(err, tt.wantErr) { 282 t.Errorf("H12IODial() error = %v, wantErr %v", err, tt.wantErr) 283 return 284 } 285 if (gotClient == nil) == tt.wantClient { 286 t.Errorf("H12IODial() got = %v, want %v", gotClient, tt.wantClient) 287 } 288 wg.Wait() 289 }) 290 } 291 } 292 293 func TestH12IODial_Direct(t *testing.T) { 294 wantAddr := localhost 295 wantProxyURL := "" 296 var wantErr error = nil 297 wantCtx := context.Background() 298 directDial = func(ctx context.Context, addr, proxyURL string) (smtpclient.SMTPClient, error) { 299 require.Equal(t, wantCtx, ctx) 300 require.Equal(t, wantAddr, addr) 301 require.Equal(t, wantProxyURL, proxyURL) 302 303 return nil, nil 304 } 305 got, err := H12IODial(wantCtx, wantAddr, wantProxyURL) 306 directDial = DirectDial 307 308 if !reflect.DeepEqual(err, wantErr) { 309 t.Errorf("H12IODial() error = %v, wantErr %v", err, wantErr) 310 return 311 } 312 313 if got != nil { 314 t.Errorf("H12IODial() should not be null") 315 } 316 }