github.com/emersion/go-smtp@v0.20.2/server.go (about)

     1  package smtp
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"io"
     8  	"log"
     9  	"net"
    10  	"os"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/emersion/go-sasl"
    15  )
    16  
    17  var (
    18  	ErrServerClosed = errors.New("smtp: server already closed")
    19  )
    20  
    21  // A function that creates SASL servers.
    22  type SaslServerFactory func(conn *Conn) sasl.Server
    23  
    24  // Logger interface is used by Server to report unexpected internal errors.
    25  type Logger interface {
    26  	Printf(format string, v ...interface{})
    27  	Println(v ...interface{})
    28  }
    29  
// Server is an SMTP server, or an LMTP server when LMTP is set.
//
// Create it with NewServer: the unexported fields (done, auths, caps, conns)
// are initialized there. A zero Server is not usable; for example,
// handleConn writes to the conns map and Close closes the done channel, and
// both panic when those are nil.
type Server struct {
	// The type of network, "tcp" or "unix".
	// When empty, ListenAndServe and ListenAndServeTLS use "unix" if LMTP is
	// set and "tcp" otherwise.
	Network string
	// TCP or Unix address to listen on.
	// When empty and LMTP is disabled, ListenAndServe uses ":smtp" and
	// ListenAndServeTLS uses ":smtps".
	Addr string
	// The server TLS configuration.
	// ListenAndServeTLS passes it to tls.Listen.
	TLSConfig *tls.Config
	// Enable LMTP mode, as defined in RFC 2033.
	// It also makes "unix" the default network and disables the default Addr.
	LMTP bool

	// Domain is not read in this file; presumably the host name the server
	// announces to clients — TODO confirm in conn.go.
	Domain            string
	// MaxRecipients and MaxMessageBytes are not read in this file;
	// presumably per-message limits on recipients and on message size in
	// bytes — NOTE(review): confirm in conn.go, including whether 0 means
	// "no limit".
	MaxRecipients     int
	MaxMessageBytes   int64
	// MaxLineLength presumably bounds the length of each input line;
	// handleConn replies 500 and closes the connection when readLine returns
	// ErrTooLongLine. NewServer sets it to 2000.
	MaxLineLength     int
	// AllowInsecureAuth presumably permits AUTH on connections without TLS.
	// AuthDisabled takes precedence over it.
	AllowInsecureAuth bool
	// Debug, when non-nil, presumably receives a copy of the protocol
	// traffic — TODO confirm in conn.go.
	Debug             io.Writer
	// ErrorLog receives accept errors and connection handler errors from
	// Serve. NewServer sets it to a logger writing to os.Stderr.
	ErrorLog          Logger
	// ReadTimeout and WriteTimeout, when non-zero, are used as deadlines for
	// the handshake of connections that are already TLS (see handleConn);
	// presumably they also bound ordinary reads and writes elsewhere — TODO
	// confirm.
	ReadTimeout       time.Duration
	WriteTimeout      time.Duration

	// Advertise SMTPUTF8 (RFC 6531) capability.
	// Should be used only if backend supports it.
	EnableSMTPUTF8 bool

	// Advertise REQUIRETLS (RFC 8689) capability.
	// Should be used only if backend supports it.
	EnableREQUIRETLS bool

	// Advertise BINARYMIME (RFC 3030) capability.
	// Should be used only if backend supports it.
	EnableBINARYMIME bool

	// Advertise DSN (RFC 3461) capability.
	// Should be used only if backend supports it.
	EnableDSN bool

	// If set, the AUTH command will not be advertised and authentication
	// attempts will be rejected. This setting overrides AllowInsecureAuth.
	AuthDisabled bool

	// The server backend.
	Backend Backend

	// wg counts the per-connection goroutines started by Serve; Shutdown
	// waits for it to reach zero.
	wg sync.WaitGroup

	// caps is the capability keyword list that NewServer seeds with
	// PIPELINING, 8BITMIME, ENHANCEDSTATUSCODES and CHUNKING (presumably
	// advertised to clients — confirm in conn.go).
	caps  []string
	// auths maps a SASL mechanism name to its factory; NewServer registers
	// PLAIN and EnableAuth adds or replaces entries.
	auths map[string]SaslServerFactory
	// done is closed by Close or Shutdown; after that, Serve treats an
	// Accept error as a normal stop and returns nil.
	done  chan struct{}

	// locker guards listeners and conns.
	locker    sync.Mutex
	// listeners holds every listener passed to Serve; Close and Shutdown
	// close them all.
	listeners []net.Listener
	// conns holds the connections currently being served by handleConn;
	// Close closes them all.
	conns     map[*Conn]struct{}
}
    84  
    85  // New creates a new SMTP server.
    86  func NewServer(be Backend) *Server {
    87  	return &Server{
    88  		// Doubled maximum line length per RFC 5321 (Section 4.5.3.1.6)
    89  		MaxLineLength: 2000,
    90  
    91  		Backend:  be,
    92  		done:     make(chan struct{}, 1),
    93  		ErrorLog: log.New(os.Stderr, "smtp/server ", log.LstdFlags),
    94  		caps:     []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES", "CHUNKING"},
    95  		auths: map[string]SaslServerFactory{
    96  			sasl.Plain: func(conn *Conn) sasl.Server {
    97  				return sasl.NewPlainServer(func(identity, username, password string) error {
    98  					if identity != "" && identity != username {
    99  						return errors.New("identities not supported")
   100  					}
   101  
   102  					sess := conn.Session()
   103  					if sess == nil {
   104  						panic("No session when AUTH is called")
   105  					}
   106  
   107  					return sess.AuthPlain(username, password)
   108  				})
   109  			},
   110  		},
   111  		conns: make(map[*Conn]struct{}),
   112  	}
   113  }
   114  
   115  // Serve accepts incoming connections on the Listener l.
   116  func (s *Server) Serve(l net.Listener) error {
   117  	s.locker.Lock()
   118  	s.listeners = append(s.listeners, l)
   119  	s.locker.Unlock()
   120  
   121  	var tempDelay time.Duration // how long to sleep on accept failure
   122  
   123  	for {
   124  		c, err := l.Accept()
   125  		if err != nil {
   126  			select {
   127  			case <-s.done:
   128  				// we called Close()
   129  				return nil
   130  			default:
   131  			}
   132  			if ne, ok := err.(net.Error); ok && ne.Temporary() {
   133  				if tempDelay == 0 {
   134  					tempDelay = 5 * time.Millisecond
   135  				} else {
   136  					tempDelay *= 2
   137  				}
   138  				if max := 1 * time.Second; tempDelay > max {
   139  					tempDelay = max
   140  				}
   141  				s.ErrorLog.Printf("accept error: %s; retrying in %s", err, tempDelay)
   142  				time.Sleep(tempDelay)
   143  				continue
   144  			}
   145  			return err
   146  		}
   147  
   148  		s.wg.Add(1)
   149  		go func() {
   150  			defer s.wg.Done()
   151  
   152  			err := s.handleConn(newConn(c, s))
   153  			if err != nil {
   154  				s.ErrorLog.Printf("handler error: %s", err)
   155  			}
   156  		}()
   157  	}
   158  }
   159  
   160  func (s *Server) handleConn(c *Conn) error {
   161  	s.locker.Lock()
   162  	s.conns[c] = struct{}{}
   163  	s.locker.Unlock()
   164  
   165  	defer func() {
   166  		c.Close()
   167  
   168  		s.locker.Lock()
   169  		delete(s.conns, c)
   170  		s.locker.Unlock()
   171  	}()
   172  
   173  	if tlsConn, ok := c.conn.(*tls.Conn); ok {
   174  		if d := s.ReadTimeout; d != 0 {
   175  			c.conn.SetReadDeadline(time.Now().Add(d))
   176  		}
   177  		if d := s.WriteTimeout; d != 0 {
   178  			c.conn.SetWriteDeadline(time.Now().Add(d))
   179  		}
   180  		if err := tlsConn.Handshake(); err != nil {
   181  			return err
   182  		}
   183  	}
   184  
   185  	c.greet()
   186  
   187  	for {
   188  		line, err := c.readLine()
   189  		if err == nil {
   190  			cmd, arg, err := parseCmd(line)
   191  			if err != nil {
   192  				c.protocolError(501, EnhancedCode{5, 5, 2}, "Bad command")
   193  				continue
   194  			}
   195  
   196  			c.handle(cmd, arg)
   197  		} else {
   198  			if err == io.EOF || errors.Is(err, net.ErrClosed) {
   199  				return nil
   200  			}
   201  			if err == ErrTooLongLine {
   202  				c.writeResponse(500, EnhancedCode{5, 4, 0}, "Too long line, closing connection")
   203  				return nil
   204  			}
   205  
   206  			if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
   207  				c.writeResponse(421, EnhancedCode{4, 4, 2}, "Idle timeout, bye bye")
   208  				return nil
   209  			}
   210  
   211  			c.writeResponse(421, EnhancedCode{4, 4, 0}, "Connection error, sorry")
   212  			return err
   213  		}
   214  	}
   215  }
   216  
   217  func (s *Server) network() string {
   218  	if s.Network != "" {
   219  		return s.Network
   220  	}
   221  	if s.LMTP {
   222  		return "unix"
   223  	}
   224  	return "tcp"
   225  }
   226  
   227  // ListenAndServe listens on the network address s.Addr and then calls Serve
   228  // to handle requests on incoming connections.
   229  //
   230  // If s.Addr is blank and LMTP is disabled, ":smtp" is used.
   231  func (s *Server) ListenAndServe() error {
   232  	network := s.network()
   233  
   234  	addr := s.Addr
   235  	if !s.LMTP && addr == "" {
   236  		addr = ":smtp"
   237  	}
   238  
   239  	l, err := net.Listen(network, addr)
   240  	if err != nil {
   241  		return err
   242  	}
   243  
   244  	return s.Serve(l)
   245  }
   246  
   247  // ListenAndServeTLS listens on the TCP network address s.Addr and then calls
   248  // Serve to handle requests on incoming TLS connections.
   249  //
   250  // If s.Addr is blank and LMTP is disabled, ":smtps" is used.
   251  func (s *Server) ListenAndServeTLS() error {
   252  	network := s.network()
   253  
   254  	addr := s.Addr
   255  	if !s.LMTP && addr == "" {
   256  		addr = ":smtps"
   257  	}
   258  
   259  	l, err := tls.Listen(network, addr, s.TLSConfig)
   260  	if err != nil {
   261  		return err
   262  	}
   263  
   264  	return s.Serve(l)
   265  }
   266  
   267  // Close immediately closes all active listeners and connections.
   268  //
   269  // Close returns any error returned from closing the server's underlying
   270  // listener(s).
   271  func (s *Server) Close() error {
   272  	select {
   273  	case <-s.done:
   274  		return ErrServerClosed
   275  	default:
   276  		close(s.done)
   277  	}
   278  
   279  	var err error
   280  	s.locker.Lock()
   281  	for _, l := range s.listeners {
   282  		if lerr := l.Close(); lerr != nil && err == nil {
   283  			err = lerr
   284  		}
   285  	}
   286  
   287  	for conn := range s.conns {
   288  		conn.Close()
   289  	}
   290  	s.locker.Unlock()
   291  
   292  	return err
   293  }
   294  
   295  // Shutdown gracefully shuts down the server without interrupting any
   296  // active connections. Shutdown works by first closing all open
   297  // listeners and then waiting indefinitely for connections to return to
   298  // idle and then shut down.
   299  // If the provided context expires before the shutdown is complete,
   300  // Shutdown returns the context's error, otherwise it returns any
   301  // error returned from closing the Server's underlying Listener(s).
   302  func (s *Server) Shutdown(ctx context.Context) error {
   303  	select {
   304  	case <-s.done:
   305  		return ErrServerClosed
   306  	default:
   307  		close(s.done)
   308  	}
   309  
   310  	var err error
   311  	s.locker.Lock()
   312  	for _, l := range s.listeners {
   313  		if lerr := l.Close(); lerr != nil && err == nil {
   314  			err = lerr
   315  		}
   316  	}
   317  	s.locker.Unlock()
   318  
   319  	connDone := make(chan struct{})
   320  	go func() {
   321  		defer close(connDone)
   322  		s.wg.Wait()
   323  	}()
   324  
   325  	select {
   326  	case <-ctx.Done():
   327  		return ctx.Err()
   328  	case <-connDone:
   329  		return err
   330  	}
   331  }
   332  
   333  // EnableAuth enables an authentication mechanism on this server.
   334  //
   335  // This function should not be called directly, it must only be used by
   336  // libraries implementing extensions of the SMTP protocol.
   337  func (s *Server) EnableAuth(name string, f SaslServerFactory) {
   338  	s.auths[name] = f
   339  }