go-micro.dev/v5@v5.12.0/web/service.go (about)

     1  package web
     2  
     3  import (
     4  	"crypto/tls"
     5  	"net"
     6  	"net/http"
     7  	"os"
     8  	"os/signal"
     9  	"path/filepath"
    10  	"runtime"
    11  	"strings"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/urfave/cli/v2"
    16  	"go-micro.dev/v5"
    17  	log "go-micro.dev/v5/logger"
    18  	"go-micro.dev/v5/registry"
    19  	maddr "go-micro.dev/v5/util/addr"
    20  	"go-micro.dev/v5/util/backoff"
    21  	mhttp "go-micro.dev/v5/util/http"
    22  	mnet "go-micro.dev/v5/util/net"
    23  	signalutil "go-micro.dev/v5/util/signal"
    24  	mls "go-micro.dev/v5/util/tls"
    25  )
    26  
// service implements Service (it is what newService returns): an HTTP server
// built on net/http that describes itself to a go-micro registry.
type service struct {
	// mux is the default handler; Handle and HandleFunc register routes on it,
	// and start serves it unless opts.Handler is set.
	mux *http.ServeMux
	// srv is the registry description of this service. genSrv rebuilds it and
	// Handle/HandleFunc append endpoints to it.
	srv *registry.Service

	// exit is created by start; stop sends a reply channel on it and the
	// goroutine started by start closes the listener and replies with the
	// result of that Close.
	exit chan chan error
	// ex is closed by Stop/Run to end the run re-registration loop.
	ex   chan bool
	opts Options

	// RWMutex is taken by the methods that read or modify running, static,
	// srv and opts.
	sync.RWMutex
	// running is set by a successful start and cleared by stop.
	running bool
	// static lets start serve opts.StaticDir at "/"; Handle/HandleFunc clear
	// it when "/" is registered explicitly.
	static bool
}
    39  
    40  func newService(opts ...Option) Service {
    41  	options := newOptions(opts...)
    42  	s := &service{
    43  		opts:   options,
    44  		mux:    http.NewServeMux(),
    45  		static: true,
    46  		ex:     make(chan bool),
    47  	}
    48  	s.srv = s.genSrv()
    49  
    50  	return s
    51  }
    52  
    53  func (s *service) genSrv() *registry.Service {
    54  	var (
    55  		host string
    56  		port string
    57  		err  error
    58  	)
    59  
    60  	logger := s.opts.Logger
    61  
    62  	// default host:port
    63  	if len(s.opts.Address) > 0 {
    64  		host, port, err = net.SplitHostPort(s.opts.Address)
    65  		if err != nil {
    66  			logger.Log(log.FatalLevel, err)
    67  		}
    68  	}
    69  
    70  	// check the advertise address first
    71  	// if it exists then use it, otherwise
    72  	// use the address
    73  	if len(s.opts.Advertise) > 0 {
    74  		host, port, err = net.SplitHostPort(s.opts.Advertise)
    75  		if err != nil {
    76  			logger.Log(log.FatalLevel, err)
    77  		}
    78  	}
    79  
    80  	addr, err := maddr.Extract(host)
    81  	if err != nil {
    82  		logger.Log(log.FatalLevel, err)
    83  	}
    84  
    85  	if strings.Count(addr, ":") > 0 {
    86  		addr = "[" + addr + "]"
    87  	}
    88  
    89  	return &registry.Service{
    90  		Name:    s.opts.Name,
    91  		Version: s.opts.Version,
    92  		Nodes: []*registry.Node{{
    93  			Id:       s.opts.Id,
    94  			Address:  net.JoinHostPort(addr, port),
    95  			Metadata: s.opts.Metadata,
    96  		}},
    97  	}
    98  }
    99  
   100  func (s *service) run() {
   101  	s.RLock()
   102  	if s.opts.RegisterInterval <= time.Duration(0) {
   103  		s.RUnlock()
   104  		return
   105  	}
   106  
   107  	t := time.NewTicker(s.opts.RegisterInterval)
   108  	s.RUnlock()
   109  
   110  	for {
   111  		select {
   112  		case <-t.C:
   113  			s.register()
   114  		case <-s.ex:
   115  			t.Stop()
   116  			return
   117  		}
   118  	}
   119  }
   120  
   121  func (s *service) register() error {
   122  	s.Lock()
   123  	defer s.Unlock()
   124  
   125  	if s.srv == nil {
   126  		return nil
   127  	}
   128  
   129  	logger := s.opts.Logger
   130  
   131  	// default to service registry
   132  	r := s.opts.Service.Client().Options().Registry
   133  	// switch to option if specified
   134  	if s.opts.Registry != nil {
   135  		r = s.opts.Registry
   136  	}
   137  
   138  	// service node need modify, node address maybe changed
   139  	srv := s.genSrv()
   140  	srv.Endpoints = s.srv.Endpoints
   141  	s.srv = srv
   142  
   143  	// use RegisterCheck func before register
   144  	if err := s.opts.RegisterCheck(s.opts.Context); err != nil {
   145  		logger.Logf(log.ErrorLevel, "Server %s-%s register check error: %s", s.opts.Name, s.opts.Id, err)
   146  		return err
   147  	}
   148  
   149  	var regErr error
   150  
   151  	// try three times if necessary
   152  	for i := 0; i < 3; i++ {
   153  		// attempt to register
   154  		if err := r.Register(s.srv, registry.RegisterTTL(s.opts.RegisterTTL)); err != nil {
   155  			// set the error
   156  			regErr = err
   157  			// backoff then retry
   158  			time.Sleep(backoff.Do(i + 1))
   159  
   160  			continue
   161  		}
   162  		// success so nil error
   163  		regErr = nil
   164  
   165  		break
   166  	}
   167  
   168  	return regErr
   169  }
   170  
   171  func (s *service) deregister() error {
   172  	s.Lock()
   173  	defer s.Unlock()
   174  
   175  	if s.srv == nil {
   176  		return nil
   177  	}
   178  	// default to service registry
   179  	r := s.opts.Service.Client().Options().Registry
   180  	// switch to option if specified
   181  	if s.opts.Registry != nil {
   182  		r = s.opts.Registry
   183  	}
   184  
   185  	return r.Deregister(s.srv)
   186  }
   187  
   188  func (s *service) start() error {
   189  	s.Lock()
   190  	defer s.Unlock()
   191  
   192  	if s.running {
   193  		return nil
   194  	}
   195  
   196  	for _, fn := range s.opts.BeforeStart {
   197  		if err := fn(); err != nil {
   198  			return err
   199  		}
   200  	}
   201  
   202  	listener, err := s.listen("tcp", s.opts.Address)
   203  	if err != nil {
   204  		return err
   205  	}
   206  
   207  	logger := s.opts.Logger
   208  
   209  	s.opts.Address = listener.Addr().String()
   210  	srv := s.genSrv()
   211  	srv.Endpoints = s.srv.Endpoints
   212  	s.srv = srv
   213  
   214  	var handler http.Handler
   215  
   216  	if s.opts.Handler != nil {
   217  		handler = s.opts.Handler
   218  	} else {
   219  		handler = s.mux
   220  		var r sync.Once
   221  
   222  		// register the html dir
   223  		r.Do(func() {
   224  			// static dir
   225  			static := s.opts.StaticDir
   226  			if s.opts.StaticDir[0] != '/' {
   227  				dir, _ := os.Getwd()
   228  				static = filepath.Join(dir, static)
   229  			}
   230  
   231  			// set static if no / handler is registered
   232  			if s.static {
   233  				_, err := os.Stat(static)
   234  				if err == nil {
   235  					logger.Logf(log.InfoLevel, "Enabling static file serving from %s", static)
   236  					s.mux.Handle("/", http.FileServer(http.Dir(static)))
   237  				}
   238  			}
   239  		})
   240  	}
   241  
   242  	var httpSrv *http.Server
   243  	if s.opts.Server != nil {
   244  		httpSrv = s.opts.Server
   245  	} else {
   246  		httpSrv = &http.Server{}
   247  	}
   248  
   249  	httpSrv.Handler = handler
   250  
   251  	go httpSrv.Serve(listener)
   252  
   253  	for _, fn := range s.opts.AfterStart {
   254  		if err := fn(); err != nil {
   255  			return err
   256  		}
   257  	}
   258  
   259  	s.exit = make(chan chan error, 1)
   260  	s.running = true
   261  
   262  	go func() {
   263  		ch := <-s.exit
   264  		ch <- listener.Close()
   265  	}()
   266  
   267  	logger.Logf(log.InfoLevel, "Listening on %v", listener.Addr().String())
   268  
   269  	return nil
   270  }
   271  
   272  func (s *service) stop() error {
   273  	s.Lock()
   274  	defer s.Unlock()
   275  
   276  	if !s.running {
   277  		return nil
   278  	}
   279  
   280  	for _, fn := range s.opts.BeforeStop {
   281  		if err := fn(); err != nil {
   282  			return err
   283  		}
   284  	}
   285  
   286  	ch := make(chan error, 1)
   287  	s.exit <- ch
   288  	s.running = false
   289  
   290  	s.opts.Logger.Log(log.InfoLevel, "Stopping")
   291  
   292  	for _, fn := range s.opts.AfterStop {
   293  		if err := fn(); err != nil {
   294  			if chErr := <-ch; chErr != nil {
   295  				return chErr
   296  			}
   297  
   298  			return err
   299  		}
   300  	}
   301  
   302  	return <-ch
   303  }
   304  
   305  func (s *service) Client() *http.Client {
   306  	rt := mhttp.NewRoundTripper(
   307  		mhttp.WithRegistry(s.opts.Registry),
   308  	)
   309  	return &http.Client{
   310  		Transport: rt,
   311  	}
   312  }
   313  
   314  func (s *service) Handle(pattern string, handler http.Handler) {
   315  	var seen bool
   316  	s.RLock()
   317  	for _, ep := range s.srv.Endpoints {
   318  		if ep.Name == pattern {
   319  			seen = true
   320  			break
   321  		}
   322  	}
   323  	s.RUnlock()
   324  
   325  	// if its unseen then add an endpoint
   326  	if !seen {
   327  		s.Lock()
   328  		s.srv.Endpoints = append(s.srv.Endpoints, &registry.Endpoint{
   329  			Name: pattern,
   330  		})
   331  		s.Unlock()
   332  	}
   333  
   334  	// disable static serving
   335  	if pattern == "/" {
   336  		s.Lock()
   337  		s.static = false
   338  		s.Unlock()
   339  	}
   340  
   341  	// register the handler
   342  	s.mux.Handle(pattern, handler)
   343  }
   344  
   345  func (s *service) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
   346  	var seen bool
   347  
   348  	s.RLock()
   349  	for _, ep := range s.srv.Endpoints {
   350  		if ep.Name == pattern {
   351  			seen = true
   352  			break
   353  		}
   354  	}
   355  	s.RUnlock()
   356  
   357  	if !seen {
   358  		s.Lock()
   359  		s.srv.Endpoints = append(s.srv.Endpoints, &registry.Endpoint{
   360  			Name: pattern,
   361  		})
   362  		s.Unlock()
   363  	}
   364  
   365  	// disable static serving
   366  	if pattern == "/" {
   367  		s.Lock()
   368  		s.static = false
   369  		s.Unlock()
   370  	}
   371  
   372  	s.mux.HandleFunc(pattern, handler)
   373  }
   374  
// Init applies opts to the service and then initialises the wrapped micro
// service (s.opts.Service). It forwards Flags and Registry to the micro
// service when they are set, installs a micro Action that copies selected
// command-line flag values into s.opts, passes the service name (only when
// the micro service has none) and the version, and finally rebuilds the
// registry node from the possibly-updated options while keeping the known
// endpoints. It always returns nil.
func (s *service) Init(opts ...Option) error {
	s.Lock()

	for _, o := range opts {
		o(&s.opts)
	}

	serviceOpts := []micro.Option{}

	if len(s.opts.Flags) > 0 {
		serviceOpts = append(serviceOpts, micro.Flags(s.opts.Flags...))
	}

	if s.opts.Registry != nil {
		serviceOpts = append(serviceOpts, micro.Registry(s.opts.Registry))
	}

	// The lock must not be held across s.opts.Service.Init below: the Action
	// closure acquires s.Lock itself, so holding it would deadlock if Init
	// runs the action synchronously (presumably while parsing flags — confirm
	// against micro.Service.Init).
	s.Unlock()

	serviceOpts = append(serviceOpts, micro.Action(func(ctx *cli.Context) error {
		s.Lock()
		defer s.Unlock()

		// Only non-zero / non-empty flag values override the current options;
		// the integer flags are interpreted as seconds.
		if ttl := ctx.Int("register_ttl"); ttl > 0 {
			s.opts.RegisterTTL = time.Duration(ttl) * time.Second
		}

		if interval := ctx.Int("register_interval"); interval > 0 {
			s.opts.RegisterInterval = time.Duration(interval) * time.Second
		}

		if name := ctx.String("server_name"); len(name) > 0 {
			s.opts.Name = name
		}

		if ver := ctx.String("server_version"); len(ver) > 0 {
			s.opts.Version = ver
		}

		if id := ctx.String("server_id"); len(id) > 0 {
			s.opts.Id = id
		}

		if addr := ctx.String("server_address"); len(addr) > 0 {
			s.opts.Address = addr
		}

		if adv := ctx.String("server_advertise"); len(adv) > 0 {
			s.opts.Advertise = adv
		}

		// The user-supplied Action runs with the service lock held.
		if s.opts.Action != nil {
			s.opts.Action(ctx)
		}

		return nil
	}))

	s.RLock()
	// pass in own name and version
	if s.opts.Service.Name() == "" {
		serviceOpts = append(serviceOpts, micro.Name(s.opts.Name))
	}

	serviceOpts = append(serviceOpts, micro.Version(s.opts.Version))

	s.RUnlock()

	s.opts.Service.Init(serviceOpts...)

	// The Action may have changed name, version, id or addresses, so rebuild
	// the registry node while keeping the endpoints registered so far.
	s.Lock()
	srv := s.genSrv()
	srv.Endpoints = s.srv.Endpoints
	s.srv = srv
	s.Unlock()

	return nil
}
   453  
   454  func (s *service) Start() error {
   455  	if err := s.start(); err != nil {
   456  		return err
   457  	}
   458  
   459  	if err := s.register(); err != nil {
   460  		return err
   461  	}
   462  
   463  	// start reg loop
   464  	go s.run()
   465  
   466  	return nil
   467  }
   468  
   469  func (s *service) Stop() error {
   470  	// exit reg loop
   471  	close(s.ex)
   472  
   473  	if err := s.deregister(); err != nil {
   474  		return err
   475  	}
   476  
   477  	return s.stop()
   478  }
   479  
   480  func (s *service) Run() error {
   481  	if err := s.start(); err != nil {
   482  		return err
   483  	}
   484  
   485  	logger := s.opts.Logger
   486  	// start the profiler
   487  	if s.opts.Service.Options().Profile != nil {
   488  		// to view mutex contention
   489  		runtime.SetMutexProfileFraction(5)
   490  		// to view blocking profile
   491  		runtime.SetBlockProfileRate(1)
   492  
   493  		if err := s.opts.Service.Options().Profile.Start(); err != nil {
   494  			return err
   495  		}
   496  
   497  		defer func() {
   498  			if err := s.opts.Service.Options().Profile.Stop(); err != nil {
   499  				logger.Log(log.ErrorLevel, err)
   500  			}
   501  		}()
   502  	}
   503  
   504  	if err := s.register(); err != nil {
   505  		return err
   506  	}
   507  
   508  	// start reg loop
   509  	go s.run()
   510  
   511  	ch := make(chan os.Signal, 1)
   512  	if s.opts.Signal {
   513  		signal.Notify(ch, signalutil.Shutdown()...)
   514  	}
   515  
   516  	select {
   517  	// wait on kill signal
   518  	case sig := <-ch:
   519  		logger.Logf(log.InfoLevel, "Received signal %s", sig)
   520  	// wait on context cancel
   521  	case <-s.opts.Context.Done():
   522  		logger.Log(log.InfoLevel, "Received context shutdown")
   523  	}
   524  
   525  	// exit reg loop
   526  	close(s.ex)
   527  
   528  	if err := s.deregister(); err != nil {
   529  		return err
   530  	}
   531  
   532  	return s.stop()
   533  }
   534  
   535  // Options returns the options for the given service.
   536  func (s *service) Options() Options {
   537  	return s.opts
   538  }
   539  
   540  func (s *service) listen(network, addr string) (net.Listener, error) {
   541  	var (
   542  		listener net.Listener
   543  		err      error
   544  	)
   545  
   546  	// TODO: support use of listen options
   547  	if s.opts.Secure || s.opts.TLSConfig != nil {
   548  		config := s.opts.TLSConfig
   549  
   550  		fn := func(addr string) (net.Listener, error) {
   551  			if config == nil {
   552  				hosts := []string{addr}
   553  
   554  				// check if its a valid host:port
   555  				if host, _, err := net.SplitHostPort(addr); err == nil {
   556  					if len(host) == 0 {
   557  						hosts = maddr.IPs()
   558  					} else {
   559  						hosts = []string{host}
   560  					}
   561  				}
   562  
   563  				// generate a certificate
   564  				cert, err := mls.Certificate(hosts...)
   565  				if err != nil {
   566  					return nil, err
   567  				}
   568  				config = &tls.Config{Certificates: []tls.Certificate{cert}}
   569  			}
   570  
   571  			return tls.Listen(network, addr, config)
   572  		}
   573  
   574  		listener, err = mnet.Listen(addr, fn)
   575  	} else {
   576  		fn := func(addr string) (net.Listener, error) {
   577  			return net.Listen(network, addr)
   578  		}
   579  
   580  		listener, err = mnet.Listen(addr, fn)
   581  	}
   582  
   583  	if err != nil {
   584  		return nil, err
   585  	}
   586  
   587  	return listener, nil
   588  }