gitee.com/sasukebo/go-micro/v4@v4.7.1/web/service.go (about)

     1  package web
     2  
     3  import (
     4  	"crypto/tls"
     5  	"net"
     6  	"net/http"
     7  	"os"
     8  	"os/signal"
     9  	"path/filepath"
    10  	"runtime"
    11  	"strings"
    12  	"sync"
    13  	"time"
    14  
    15  	"gitee.com/sasukebo/go-micro/v4"
    16  	"gitee.com/sasukebo/go-micro/v4/logger"
    17  	"gitee.com/sasukebo/go-micro/v4/registry"
    18  	maddr "gitee.com/sasukebo/go-micro/v4/util/addr"
    19  	"gitee.com/sasukebo/go-micro/v4/util/backoff"
    20  	mhttp "gitee.com/sasukebo/go-micro/v4/util/http"
    21  	mnet "gitee.com/sasukebo/go-micro/v4/util/net"
    22  	signalutil "gitee.com/sasukebo/go-micro/v4/util/signal"
    23  	mls "gitee.com/sasukebo/go-micro/v4/util/tls"
    24  	"github.com/urfave/cli/v2"
    25  )
    26  
    27  type service struct {
    28  	opts Options
    29  
    30  	mux *http.ServeMux
    31  	srv *registry.Service
    32  
    33  	sync.RWMutex
    34  	running bool
    35  	static  bool
    36  	exit    chan chan error
    37  }
    38  
    39  func newService(opts ...Option) Service {
    40  	options := newOptions(opts...)
    41  	s := &service{
    42  		opts:   options,
    43  		mux:    http.NewServeMux(),
    44  		static: true,
    45  	}
    46  	s.srv = s.genSrv()
    47  	return s
    48  }
    49  
    50  func (s *service) genSrv() *registry.Service {
    51  	var host string
    52  	var port string
    53  	var err error
    54  
    55  	// default host:port
    56  	if len(s.opts.Address) > 0 {
    57  		host, port, err = net.SplitHostPort(s.opts.Address)
    58  		if err != nil {
    59  			logger.Fatal(err)
    60  		}
    61  	}
    62  
    63  	// check the advertise address first
    64  	// if it exists then use it, otherwise
    65  	// use the address
    66  	if len(s.opts.Advertise) > 0 {
    67  		host, port, err = net.SplitHostPort(s.opts.Advertise)
    68  		if err != nil {
    69  			logger.Fatal(err)
    70  		}
    71  	}
    72  
    73  	addr, err := maddr.Extract(host)
    74  	if err != nil {
    75  		logger.Fatal(err)
    76  	}
    77  
    78  	if strings.Count(addr, ":") > 0 {
    79  		addr = "[" + addr + "]"
    80  	}
    81  
    82  	return &registry.Service{
    83  		Name:    s.opts.Name,
    84  		Version: s.opts.Version,
    85  		Nodes: []*registry.Node{{
    86  			Id:       s.opts.Id,
    87  			Address:  net.JoinHostPort(addr, port),
    88  			Metadata: s.opts.Metadata,
    89  		}},
    90  	}
    91  }
    92  
    93  func (s *service) run(exit chan bool) {
    94  	s.RLock()
    95  	if s.opts.RegisterInterval <= time.Duration(0) {
    96  		s.RUnlock()
    97  		return
    98  	}
    99  
   100  	t := time.NewTicker(s.opts.RegisterInterval)
   101  	s.RUnlock()
   102  
   103  	for {
   104  		select {
   105  		case <-t.C:
   106  			s.register()
   107  		case <-exit:
   108  			t.Stop()
   109  			return
   110  		}
   111  	}
   112  }
   113  
   114  func (s *service) register() error {
   115  	s.Lock()
   116  	defer s.Unlock()
   117  
   118  	if s.srv == nil {
   119  		return nil
   120  	}
   121  	// default to service registry
   122  	r := s.opts.Service.Client().Options().Registry
   123  	// switch to option if specified
   124  	if s.opts.Registry != nil {
   125  		r = s.opts.Registry
   126  	}
   127  
   128  	// service node need modify, node address maybe changed
   129  	srv := s.genSrv()
   130  	srv.Endpoints = s.srv.Endpoints
   131  	s.srv = srv
   132  
   133  	// use RegisterCheck func before register
   134  	if err := s.opts.RegisterCheck(s.opts.Context); err != nil {
   135  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   136  			logger.Errorf("Server %s-%s register check error: %s", s.opts.Name, s.opts.Id, err)
   137  		}
   138  		return err
   139  	}
   140  
   141  	var regErr error
   142  
   143  	// try three times if necessary
   144  	for i := 0; i < 3; i++ {
   145  		// attempt to register
   146  		if err := r.Register(s.srv, registry.RegisterTTL(s.opts.RegisterTTL)); err != nil {
   147  			// set the error
   148  			regErr = err
   149  			// backoff then retry
   150  			time.Sleep(backoff.Do(i + 1))
   151  			continue
   152  		}
   153  		// success so nil error
   154  		regErr = nil
   155  		break
   156  	}
   157  
   158  	return regErr
   159  }
   160  
   161  func (s *service) deregister() error {
   162  	s.Lock()
   163  	defer s.Unlock()
   164  
   165  	if s.srv == nil {
   166  		return nil
   167  	}
   168  	// default to service registry
   169  	r := s.opts.Service.Client().Options().Registry
   170  	// switch to option if specified
   171  	if s.opts.Registry != nil {
   172  		r = s.opts.Registry
   173  	}
   174  	return r.Deregister(s.srv)
   175  }
   176  
   177  func (s *service) start() error {
   178  	s.Lock()
   179  	defer s.Unlock()
   180  
   181  	if s.running {
   182  		return nil
   183  	}
   184  
   185  	for _, fn := range s.opts.BeforeStart {
   186  		if err := fn(); err != nil {
   187  			return err
   188  		}
   189  	}
   190  
   191  	l, err := s.listen("tcp", s.opts.Address)
   192  	if err != nil {
   193  		return err
   194  	}
   195  
   196  	s.opts.Address = l.Addr().String()
   197  	srv := s.genSrv()
   198  	srv.Endpoints = s.srv.Endpoints
   199  	s.srv = srv
   200  
   201  	var h http.Handler
   202  
   203  	if s.opts.Handler != nil {
   204  		h = s.opts.Handler
   205  	} else {
   206  		h = s.mux
   207  		var r sync.Once
   208  
   209  		// register the html dir
   210  		r.Do(func() {
   211  			// static dir
   212  			static := s.opts.StaticDir
   213  			if s.opts.StaticDir[0] != '/' {
   214  				dir, _ := os.Getwd()
   215  				static = filepath.Join(dir, static)
   216  			}
   217  
   218  			// set static if no / handler is registered
   219  			if s.static {
   220  				_, err := os.Stat(static)
   221  				if err == nil {
   222  					if logger.V(logger.InfoLevel, logger.DefaultLogger) {
   223  						logger.Infof("Enabling static file serving from %s", static)
   224  					}
   225  					s.mux.Handle("/", http.FileServer(http.Dir(static)))
   226  				}
   227  			}
   228  		})
   229  	}
   230  
   231  	var httpSrv *http.Server
   232  	if s.opts.Server != nil {
   233  		httpSrv = s.opts.Server
   234  	} else {
   235  		httpSrv = &http.Server{}
   236  	}
   237  
   238  	httpSrv.Handler = h
   239  
   240  	go httpSrv.Serve(l)
   241  
   242  	for _, fn := range s.opts.AfterStart {
   243  		if err := fn(); err != nil {
   244  			return err
   245  		}
   246  	}
   247  
   248  	s.exit = make(chan chan error, 1)
   249  	s.running = true
   250  
   251  	go func() {
   252  		ch := <-s.exit
   253  		ch <- l.Close()
   254  	}()
   255  
   256  	if logger.V(logger.InfoLevel, logger.DefaultLogger) {
   257  		logger.Infof("Listening on %v", l.Addr().String())
   258  	}
   259  	return nil
   260  }
   261  
   262  func (s *service) stop() error {
   263  	s.Lock()
   264  	defer s.Unlock()
   265  
   266  	if !s.running {
   267  		return nil
   268  	}
   269  
   270  	for _, fn := range s.opts.BeforeStop {
   271  		if err := fn(); err != nil {
   272  			return err
   273  		}
   274  	}
   275  
   276  	ch := make(chan error, 1)
   277  	s.exit <- ch
   278  	s.running = false
   279  
   280  	if logger.V(logger.InfoLevel, logger.DefaultLogger) {
   281  		logger.Info("Stopping")
   282  	}
   283  
   284  	for _, fn := range s.opts.AfterStop {
   285  		if err := fn(); err != nil {
   286  			if chErr := <-ch; chErr != nil {
   287  				return chErr
   288  			}
   289  			return err
   290  		}
   291  	}
   292  
   293  	return <-ch
   294  }
   295  
   296  func (s *service) Client() *http.Client {
   297  	rt := mhttp.NewRoundTripper(
   298  		mhttp.WithRegistry(s.opts.Registry),
   299  	)
   300  	return &http.Client{
   301  		Transport: rt,
   302  	}
   303  }
   304  
   305  func (s *service) Handle(pattern string, handler http.Handler) {
   306  	var seen bool
   307  	s.RLock()
   308  	for _, ep := range s.srv.Endpoints {
   309  		if ep.Name == pattern {
   310  			seen = true
   311  			break
   312  		}
   313  	}
   314  	s.RUnlock()
   315  
   316  	// if its unseen then add an endpoint
   317  	if !seen {
   318  		s.Lock()
   319  		s.srv.Endpoints = append(s.srv.Endpoints, &registry.Endpoint{
   320  			Name: pattern,
   321  		})
   322  		s.Unlock()
   323  	}
   324  
   325  	// disable static serving
   326  	if pattern == "/" {
   327  		s.Lock()
   328  		s.static = false
   329  		s.Unlock()
   330  	}
   331  
   332  	// register the handler
   333  	s.mux.Handle(pattern, handler)
   334  }
   335  
   336  func (s *service) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
   337  
   338  	var seen bool
   339  	s.RLock()
   340  	for _, ep := range s.srv.Endpoints {
   341  		if ep.Name == pattern {
   342  			seen = true
   343  			break
   344  		}
   345  	}
   346  	s.RUnlock()
   347  
   348  	if !seen {
   349  		s.Lock()
   350  		s.srv.Endpoints = append(s.srv.Endpoints, &registry.Endpoint{
   351  			Name: pattern,
   352  		})
   353  		s.Unlock()
   354  	}
   355  
   356  	// disable static serving
   357  	if pattern == "/" {
   358  		s.Lock()
   359  		s.static = false
   360  		s.Unlock()
   361  	}
   362  
   363  	s.mux.HandleFunc(pattern, handler)
   364  }
   365  
   366  func (s *service) Init(opts ...Option) error {
   367  	s.Lock()
   368  
   369  	for _, o := range opts {
   370  		o(&s.opts)
   371  	}
   372  
   373  	serviceOpts := []micro.Option{}
   374  
   375  	if len(s.opts.Flags) > 0 {
   376  		serviceOpts = append(serviceOpts, micro.Flags(s.opts.Flags...))
   377  	}
   378  
   379  	if s.opts.Registry != nil {
   380  		serviceOpts = append(serviceOpts, micro.Registry(s.opts.Registry))
   381  	}
   382  
   383  	s.Unlock()
   384  
   385  	serviceOpts = append(serviceOpts, micro.Action(func(ctx *cli.Context) error {
   386  		s.Lock()
   387  		defer s.Unlock()
   388  
   389  		if ttl := ctx.Int("register_ttl"); ttl > 0 {
   390  			s.opts.RegisterTTL = time.Duration(ttl) * time.Second
   391  		}
   392  
   393  		if interval := ctx.Int("register_interval"); interval > 0 {
   394  			s.opts.RegisterInterval = time.Duration(interval) * time.Second
   395  		}
   396  
   397  		if name := ctx.String("server_name"); len(name) > 0 {
   398  			s.opts.Name = name
   399  		}
   400  
   401  		if ver := ctx.String("server_version"); len(ver) > 0 {
   402  			s.opts.Version = ver
   403  		}
   404  
   405  		if id := ctx.String("server_id"); len(id) > 0 {
   406  			s.opts.Id = id
   407  		}
   408  
   409  		if addr := ctx.String("server_address"); len(addr) > 0 {
   410  			s.opts.Address = addr
   411  		}
   412  
   413  		if adv := ctx.String("server_advertise"); len(adv) > 0 {
   414  			s.opts.Advertise = adv
   415  		}
   416  
   417  		if s.opts.Action != nil {
   418  			s.opts.Action(ctx)
   419  		}
   420  
   421  		return nil
   422  	}))
   423  
   424  	s.RLock()
   425  	// pass in own name and version
   426  	if s.opts.Service.Name() == "" {
   427  		serviceOpts = append(serviceOpts, micro.Name(s.opts.Name))
   428  	}
   429  	serviceOpts = append(serviceOpts, micro.Version(s.opts.Version))
   430  	s.RUnlock()
   431  
   432  	s.opts.Service.Init(serviceOpts...)
   433  
   434  	s.Lock()
   435  	srv := s.genSrv()
   436  	srv.Endpoints = s.srv.Endpoints
   437  	s.srv = srv
   438  	s.Unlock()
   439  
   440  	return nil
   441  }
   442  
   443  func (s *service) Run() error {
   444  	if err := s.start(); err != nil {
   445  		return err
   446  	}
   447  
   448  	// start the profiler
   449  	if s.opts.Service.Options().Profile != nil {
   450  		// to view mutex contention
   451  		runtime.SetMutexProfileFraction(5)
   452  		// to view blocking profile
   453  		runtime.SetBlockProfileRate(1)
   454  
   455  		if err := s.opts.Service.Options().Profile.Start(); err != nil {
   456  			return err
   457  		}
   458  		defer func() {
   459  			if err := s.opts.Service.Options().Profile.Stop(); err != nil {
   460  				logger.Error(err)
   461  			}
   462  		}()
   463  	}
   464  
   465  	if err := s.register(); err != nil {
   466  		return err
   467  	}
   468  
   469  	// start reg loop
   470  	ex := make(chan bool)
   471  	go s.run(ex)
   472  
   473  	ch := make(chan os.Signal, 1)
   474  	if s.opts.Signal {
   475  		signal.Notify(ch, signalutil.Shutdown()...)
   476  	}
   477  
   478  	select {
   479  	// wait on kill signal
   480  	case sig := <-ch:
   481  		if logger.V(logger.InfoLevel, logger.DefaultLogger) {
   482  			logger.Infof("Received signal %s", sig)
   483  		}
   484  	// wait on context cancel
   485  	case <-s.opts.Context.Done():
   486  		if logger.V(logger.InfoLevel, logger.DefaultLogger) {
   487  			logger.Info("Received context shutdown")
   488  		}
   489  	}
   490  
   491  	// exit reg loop
   492  	close(ex)
   493  
   494  	if err := s.deregister(); err != nil {
   495  		return err
   496  	}
   497  
   498  	return s.stop()
   499  }
   500  
   501  // Options returns the options for the given service
   502  func (s *service) Options() Options {
   503  	return s.opts
   504  }
   505  
   506  func (s *service) listen(network, addr string) (net.Listener, error) {
   507  	var l net.Listener
   508  	var err error
   509  
   510  	// TODO: support use of listen options
   511  	if s.opts.Secure || s.opts.TLSConfig != nil {
   512  		config := s.opts.TLSConfig
   513  
   514  		fn := func(addr string) (net.Listener, error) {
   515  			if config == nil {
   516  				hosts := []string{addr}
   517  
   518  				// check if its a valid host:port
   519  				if host, _, err := net.SplitHostPort(addr); err == nil {
   520  					if len(host) == 0 {
   521  						hosts = maddr.IPs()
   522  					} else {
   523  						hosts = []string{host}
   524  					}
   525  				}
   526  
   527  				// generate a certificate
   528  				cert, err := mls.Certificate(hosts...)
   529  				if err != nil {
   530  					return nil, err
   531  				}
   532  				config = &tls.Config{Certificates: []tls.Certificate{cert}}
   533  			}
   534  			return tls.Listen(network, addr, config)
   535  		}
   536  
   537  		l, err = mnet.Listen(addr, fn)
   538  	} else {
   539  		fn := func(addr string) (net.Listener, error) {
   540  			return net.Listen(network, addr)
   541  		}
   542  
   543  		l, err = mnet.Listen(addr, fn)
   544  	}
   545  
   546  	if err != nil {
   547  		return nil, err
   548  	}
   549  
   550  	return l, nil
   551  }