github.com/zhongdalu/gf@v1.0.0/g/net/ghttp/ghttp_server_graceful.go (about)

     1  // Copyright 2018 gf Author(https://github.com/zhongdalu/gf). All Rights Reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the MIT License.
     4  // If a copy of the MIT was not distributed with this file,
     5  // You can obtain one at https://github.com/zhongdalu/gf.
     6  
     7  package ghttp
     8  
     9  import (
    10  	"context"
    11  	"crypto/tls"
    12  	"errors"
    13  	"fmt"
    14  	"github.com/zhongdalu/gf/g/os/glog"
    15  	"github.com/zhongdalu/gf/g/os/gproc"
    16  	"net"
    17  	"net/http"
    18  	"os"
    19  	"time"
    20  )
    21  
    22  // 优雅的Web Server对象封装
    23  type gracefulServer struct {
    24  	fd          uintptr      // 热重启时传递的socket监听文件句柄
    25  	addr        string       // 监听地址信息
    26  	httpServer  *http.Server // 底层http.Server
    27  	rawListener net.Listener // 原始listener
    28  	listener    net.Listener // 接口化封装的listener
    29  	isHttps     bool         // 是否HTTPS
    30  	status      int          // 当前Server状态(关闭/运行)
    31  }
    32  
    33  // 创建一个优雅的Http Server
    34  func (s *Server) newGracefulServer(addr string, fd ...int) *gracefulServer {
    35  	gs := &gracefulServer{
    36  		addr:       addr,
    37  		httpServer: s.newHttpServer(addr),
    38  	}
    39  	// 是否有继承的文件描述符
    40  	if len(fd) > 0 && fd[0] > 0 {
    41  		gs.fd = uintptr(fd[0])
    42  	}
    43  	return gs
    44  }
    45  
    46  // 生成一个底层的Web Server对象
    47  func (s *Server) newHttpServer(addr string) *http.Server {
    48  	server := &http.Server{
    49  		Addr:           addr,
    50  		Handler:        s.config.Handler,
    51  		ReadTimeout:    s.config.ReadTimeout,
    52  		WriteTimeout:   s.config.WriteTimeout,
    53  		IdleTimeout:    s.config.IdleTimeout,
    54  		MaxHeaderBytes: s.config.MaxHeaderBytes,
    55  	}
    56  	server.SetKeepAlivesEnabled(s.config.KeepAlive)
    57  	return server
    58  }
    59  
    60  // 执行HTTP监听
    61  func (s *gracefulServer) ListenAndServe() error {
    62  	addr := s.httpServer.Addr
    63  	ln, err := s.getNetListener(addr)
    64  	if err != nil {
    65  		return err
    66  	}
    67  	s.listener = ln
    68  	s.rawListener = ln
    69  	return s.doServe()
    70  }
    71  
    72  // 获得文件描述符
    73  func (s *gracefulServer) Fd() uintptr {
    74  	if s.rawListener != nil {
    75  		file, err := s.rawListener.(*net.TCPListener).File()
    76  		if err == nil {
    77  			return file.Fd()
    78  		}
    79  	}
    80  	return 0
    81  }
    82  
    83  // 设置自定义fd
    84  func (s *gracefulServer) setFd(fd int) {
    85  	s.fd = uintptr(fd)
    86  }
    87  
    88  // 执行HTTPS监听
    89  func (s *gracefulServer) ListenAndServeTLS(certFile, keyFile string, tlsConfig ...*tls.Config) error {
    90  	addr := s.httpServer.Addr
    91  	config := (*tls.Config)(nil)
    92  	if len(tlsConfig) > 0 {
    93  		config = tlsConfig[0]
    94  	} else if s.httpServer.TLSConfig != nil {
    95  		*config = *s.httpServer.TLSConfig
    96  	}
    97  	if config.NextProtos == nil {
    98  		config.NextProtos = []string{"http/1.1"}
    99  	}
   100  	err := error(nil)
   101  	if len(config.Certificates) == 0 {
   102  		config.Certificates = make([]tls.Certificate, 1)
   103  		config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
   104  	}
   105  	if err != nil {
   106  		return errors.New(fmt.Sprintf(`open cert file "%s","%s" failed: %s`, certFile, keyFile, err.Error()))
   107  	}
   108  	ln, err := s.getNetListener(addr)
   109  	if err != nil {
   110  		return err
   111  	}
   112  
   113  	s.listener = tls.NewListener(ln, config)
   114  	s.rawListener = ln
   115  	return s.doServe()
   116  }
   117  
   118  // 获取服务协议字符串
   119  func (s *gracefulServer) getProto() string {
   120  	proto := "http"
   121  	if s.isHttps {
   122  		proto = "https"
   123  	}
   124  	return proto
   125  }
   126  
   127  // 开始执行Web Server服务处理
   128  func (s *gracefulServer) doServe() error {
   129  	action := "started"
   130  	if s.fd != 0 {
   131  		action = "reloaded"
   132  	}
   133  	glog.Printf("%d: %s server %s listening on [%s]", gproc.Pid(), s.getProto(), action, s.addr)
   134  	s.status = SERVER_STATUS_RUNNING
   135  	err := s.httpServer.Serve(s.listener)
   136  	s.status = SERVER_STATUS_STOPPED
   137  	return err
   138  }
   139  
   140  // 自定义的net.Listener
   141  func (s *gracefulServer) getNetListener(addr string) (net.Listener, error) {
   142  	var ln net.Listener
   143  	var err error
   144  	if s.fd > 0 {
   145  		f := os.NewFile(s.fd, "")
   146  		ln, err = net.FileListener(f)
   147  		if err != nil {
   148  			err = fmt.Errorf("%d: net.FileListener error: %v", gproc.Pid(), err)
   149  			return nil, err
   150  		}
   151  	} else {
   152  		// 如果监听失败,1秒后重试,最多重试3次
   153  		for i := 0; i < 3; i++ {
   154  			ln, err = net.Listen("tcp", addr)
   155  			if err != nil {
   156  				err = fmt.Errorf("%d: net.Listen error: %v", gproc.Pid(), err)
   157  				time.Sleep(time.Second)
   158  			} else {
   159  				err = nil
   160  				break
   161  			}
   162  		}
   163  		if err != nil {
   164  			return nil, err
   165  		}
   166  	}
   167  	return ln, nil
   168  }
   169  
   170  // 执行请求优雅关闭
   171  func (s *gracefulServer) shutdown() {
   172  	if s.status == SERVER_STATUS_STOPPED {
   173  		return
   174  	}
   175  	if err := s.httpServer.Shutdown(context.Background()); err != nil {
   176  		glog.Errorf("%d: %s server [%s] shutdown error: %v", gproc.Pid(), s.getProto(), s.addr, err)
   177  	}
   178  }
   179  
   180  // 执行请求强制关闭
   181  func (s *gracefulServer) close() {
   182  	if s.status == SERVER_STATUS_STOPPED {
   183  		return
   184  	}
   185  	if err := s.httpServer.Close(); err != nil {
   186  		glog.Errorf("%d: %s server [%s] closed error: %v", gproc.Pid(), s.getProto(), s.addr, err)
   187  	}
   188  }