github.com/micro/go-micro/v2@v2.9.1/web/service.go (about)

     1  package web
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"net"
     7  	"net/http"
     8  	"os"
     9  	"os/signal"
    10  	"path/filepath"
    11  	"strings"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/micro/cli/v2"
    16  	"github.com/micro/go-micro/v2"
    17  	"github.com/micro/go-micro/v2/logger"
    18  	"github.com/micro/go-micro/v2/registry"
    19  	maddr "github.com/micro/go-micro/v2/util/addr"
    20  	authutil "github.com/micro/go-micro/v2/util/auth"
    21  	"github.com/micro/go-micro/v2/util/backoff"
    22  	mhttp "github.com/micro/go-micro/v2/util/http"
    23  	mnet "github.com/micro/go-micro/v2/util/net"
    24  	signalutil "github.com/micro/go-micro/v2/util/signal"
    25  	mls "github.com/micro/go-micro/v2/util/tls"
    26  )
    27  
    28  type service struct {
    29  	opts Options
    30  
    31  	mux *http.ServeMux
    32  	srv *registry.Service
    33  
    34  	sync.RWMutex
    35  	running bool
    36  	static  bool
    37  	exit    chan chan error
    38  }
    39  
    40  func newService(opts ...Option) Service {
    41  	options := newOptions(opts...)
    42  	s := &service{
    43  		opts:   options,
    44  		mux:    http.NewServeMux(),
    45  		static: true,
    46  	}
    47  	s.srv = s.genSrv()
    48  	return s
    49  }
    50  
    51  func (s *service) genSrv() *registry.Service {
    52  	var host string
    53  	var port string
    54  	var err error
    55  
    56  	// default host:port
    57  	if len(s.opts.Address) > 0 {
    58  		host, port, err = net.SplitHostPort(s.opts.Address)
    59  		if err != nil {
    60  			logger.Fatal(err)
    61  		}
    62  	}
    63  
    64  	// check the advertise address first
    65  	// if it exists then use it, otherwise
    66  	// use the address
    67  	if len(s.opts.Advertise) > 0 {
    68  		host, port, err = net.SplitHostPort(s.opts.Advertise)
    69  		if err != nil {
    70  			logger.Fatal(err)
    71  		}
    72  	}
    73  
    74  	addr, err := maddr.Extract(host)
    75  	if err != nil {
    76  		logger.Fatal(err)
    77  	}
    78  
    79  	if strings.Count(addr, ":") > 0 {
    80  		addr = "[" + addr + "]"
    81  	}
    82  
    83  	return &registry.Service{
    84  		Name:    s.opts.Name,
    85  		Version: s.opts.Version,
    86  		Nodes: []*registry.Node{{
    87  			Id:       s.opts.Id,
    88  			Address:  fmt.Sprintf("%s:%s", addr, port),
    89  			Metadata: s.opts.Metadata,
    90  		}},
    91  	}
    92  }
    93  
    94  func (s *service) run(exit chan bool) {
    95  	s.RLock()
    96  	if s.opts.RegisterInterval <= time.Duration(0) {
    97  		s.RUnlock()
    98  		return
    99  	}
   100  
   101  	t := time.NewTicker(s.opts.RegisterInterval)
   102  	s.RUnlock()
   103  
   104  	for {
   105  		select {
   106  		case <-t.C:
   107  			s.register()
   108  		case <-exit:
   109  			t.Stop()
   110  			return
   111  		}
   112  	}
   113  }
   114  
   115  func (s *service) register() error {
   116  	s.Lock()
   117  	defer s.Unlock()
   118  
   119  	if s.srv == nil {
   120  		return nil
   121  	}
   122  	// default to service registry
   123  	r := s.opts.Service.Client().Options().Registry
   124  	// switch to option if specified
   125  	if s.opts.Registry != nil {
   126  		r = s.opts.Registry
   127  	}
   128  
   129  	// service node need modify, node address maybe changed
   130  	srv := s.genSrv()
   131  	srv.Endpoints = s.srv.Endpoints
   132  	s.srv = srv
   133  
   134  	// use RegisterCheck func before register
   135  	if err := s.opts.RegisterCheck(s.opts.Context); err != nil {
   136  		if logger.V(logger.ErrorLevel, logger.DefaultLogger) {
   137  			logger.Errorf("Server %s-%s register check error: %s", s.opts.Name, s.opts.Id, err)
   138  		}
   139  		return err
   140  	}
   141  
   142  	var regErr error
   143  
   144  	// try three times if necessary
   145  	for i := 0; i < 3; i++ {
   146  		// attempt to register
   147  		if err := r.Register(s.srv, registry.RegisterTTL(s.opts.RegisterTTL)); err != nil {
   148  			// set the error
   149  			regErr = err
   150  			// backoff then retry
   151  			time.Sleep(backoff.Do(i + 1))
   152  			continue
   153  		}
   154  		// success so nil error
   155  		regErr = nil
   156  		break
   157  	}
   158  
   159  	return regErr
   160  }
   161  
   162  func (s *service) deregister() error {
   163  	s.Lock()
   164  	defer s.Unlock()
   165  
   166  	if s.srv == nil {
   167  		return nil
   168  	}
   169  	// default to service registry
   170  	r := s.opts.Service.Client().Options().Registry
   171  	// switch to option if specified
   172  	if s.opts.Registry != nil {
   173  		r = s.opts.Registry
   174  	}
   175  	return r.Deregister(s.srv)
   176  }
   177  
   178  func (s *service) start() error {
   179  	s.Lock()
   180  	defer s.Unlock()
   181  
   182  	if s.running {
   183  		return nil
   184  	}
   185  
   186  	for _, fn := range s.opts.BeforeStart {
   187  		if err := fn(); err != nil {
   188  			return err
   189  		}
   190  	}
   191  
   192  	l, err := s.listen("tcp", s.opts.Address)
   193  	if err != nil {
   194  		return err
   195  	}
   196  
   197  	s.opts.Address = l.Addr().String()
   198  	srv := s.genSrv()
   199  	srv.Endpoints = s.srv.Endpoints
   200  	s.srv = srv
   201  
   202  	var h http.Handler
   203  
   204  	if s.opts.Handler != nil {
   205  		h = s.opts.Handler
   206  	} else {
   207  		h = s.mux
   208  		var r sync.Once
   209  
   210  		// register the html dir
   211  		r.Do(func() {
   212  			// static dir
   213  			static := s.opts.StaticDir
   214  			if s.opts.StaticDir[0] != '/' {
   215  				dir, _ := os.Getwd()
   216  				static = filepath.Join(dir, static)
   217  			}
   218  
   219  			// set static if no / handler is registered
   220  			if s.static {
   221  				_, err := os.Stat(static)
   222  				if err == nil {
   223  					if logger.V(logger.InfoLevel, logger.DefaultLogger) {
   224  						logger.Infof("Enabling static file serving from %s", static)
   225  					}
   226  					s.mux.Handle("/", http.FileServer(http.Dir(static)))
   227  				}
   228  			}
   229  		})
   230  	}
   231  
   232  	var httpSrv *http.Server
   233  	if s.opts.Server != nil {
   234  		httpSrv = s.opts.Server
   235  	} else {
   236  		httpSrv = &http.Server{}
   237  	}
   238  
   239  	httpSrv.Handler = h
   240  
   241  	go httpSrv.Serve(l)
   242  
   243  	for _, fn := range s.opts.AfterStart {
   244  		if err := fn(); err != nil {
   245  			return err
   246  		}
   247  	}
   248  
   249  	s.exit = make(chan chan error, 1)
   250  	s.running = true
   251  
   252  	go func() {
   253  		ch := <-s.exit
   254  		ch <- l.Close()
   255  	}()
   256  
   257  	if logger.V(logger.InfoLevel, logger.DefaultLogger) {
   258  		logger.Infof("Listening on %v", l.Addr().String())
   259  	}
   260  	return nil
   261  }
   262  
   263  func (s *service) stop() error {
   264  	s.Lock()
   265  	defer s.Unlock()
   266  
   267  	if !s.running {
   268  		return nil
   269  	}
   270  
   271  	for _, fn := range s.opts.BeforeStop {
   272  		if err := fn(); err != nil {
   273  			return err
   274  		}
   275  	}
   276  
   277  	ch := make(chan error, 1)
   278  	s.exit <- ch
   279  	s.running = false
   280  
   281  	if logger.V(logger.InfoLevel, logger.DefaultLogger) {
   282  		logger.Info("Stopping")
   283  	}
   284  
   285  	for _, fn := range s.opts.AfterStop {
   286  		if err := fn(); err != nil {
   287  			if chErr := <-ch; chErr != nil {
   288  				return chErr
   289  			}
   290  			return err
   291  		}
   292  	}
   293  
   294  	return <-ch
   295  }
   296  
   297  func (s *service) Client() *http.Client {
   298  	rt := mhttp.NewRoundTripper(
   299  		mhttp.WithRegistry(s.opts.Registry),
   300  	)
   301  	return &http.Client{
   302  		Transport: rt,
   303  	}
   304  }
   305  
   306  func (s *service) Handle(pattern string, handler http.Handler) {
   307  	var seen bool
   308  	s.RLock()
   309  	for _, ep := range s.srv.Endpoints {
   310  		if ep.Name == pattern {
   311  			seen = true
   312  			break
   313  		}
   314  	}
   315  	s.RUnlock()
   316  
   317  	// if its unseen then add an endpoint
   318  	if !seen {
   319  		s.Lock()
   320  		s.srv.Endpoints = append(s.srv.Endpoints, &registry.Endpoint{
   321  			Name: pattern,
   322  		})
   323  		s.Unlock()
   324  	}
   325  
   326  	// disable static serving
   327  	if pattern == "/" {
   328  		s.Lock()
   329  		s.static = false
   330  		s.Unlock()
   331  	}
   332  
   333  	// register the handler
   334  	s.mux.Handle(pattern, handler)
   335  }
   336  
   337  func (s *service) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
   338  
   339  	var seen bool
   340  	s.RLock()
   341  	for _, ep := range s.srv.Endpoints {
   342  		if ep.Name == pattern {
   343  			seen = true
   344  			break
   345  		}
   346  	}
   347  	s.RUnlock()
   348  
   349  	if !seen {
   350  		s.Lock()
   351  		s.srv.Endpoints = append(s.srv.Endpoints, &registry.Endpoint{
   352  			Name: pattern,
   353  		})
   354  		s.Unlock()
   355  	}
   356  
   357  	// disable static serving
   358  	if pattern == "/" {
   359  		s.Lock()
   360  		s.static = false
   361  		s.Unlock()
   362  	}
   363  
   364  	s.mux.HandleFunc(pattern, handler)
   365  }
   366  
   367  func (s *service) Init(opts ...Option) error {
   368  	s.Lock()
   369  
   370  	for _, o := range opts {
   371  		o(&s.opts)
   372  	}
   373  
   374  	serviceOpts := []micro.Option{}
   375  
   376  	if len(s.opts.Flags) > 0 {
   377  		serviceOpts = append(serviceOpts, micro.Flags(s.opts.Flags...))
   378  	}
   379  
   380  	if s.opts.Registry != nil {
   381  		serviceOpts = append(serviceOpts, micro.Registry(s.opts.Registry))
   382  	}
   383  
   384  	s.Unlock()
   385  
   386  	serviceOpts = append(serviceOpts, micro.Action(func(ctx *cli.Context) error {
   387  		s.Lock()
   388  		defer s.Unlock()
   389  
   390  		if ttl := ctx.Int("register_ttl"); ttl > 0 {
   391  			s.opts.RegisterTTL = time.Duration(ttl) * time.Second
   392  		}
   393  
   394  		if interval := ctx.Int("register_interval"); interval > 0 {
   395  			s.opts.RegisterInterval = time.Duration(interval) * time.Second
   396  		}
   397  
   398  		if name := ctx.String("server_name"); len(name) > 0 {
   399  			s.opts.Name = name
   400  		}
   401  
   402  		if ver := ctx.String("server_version"); len(ver) > 0 {
   403  			s.opts.Version = ver
   404  		}
   405  
   406  		if id := ctx.String("server_id"); len(id) > 0 {
   407  			s.opts.Id = id
   408  		}
   409  
   410  		if addr := ctx.String("server_address"); len(addr) > 0 {
   411  			s.opts.Address = addr
   412  		}
   413  
   414  		if adv := ctx.String("server_advertise"); len(adv) > 0 {
   415  			s.opts.Advertise = adv
   416  		}
   417  
   418  		if s.opts.Action != nil {
   419  			s.opts.Action(ctx)
   420  		}
   421  
   422  		return nil
   423  	}))
   424  
   425  	s.RLock()
   426  	// pass in own name and version
   427  	if s.opts.Service.Name() == "" {
   428  		serviceOpts = append(serviceOpts, micro.Name(s.opts.Name))
   429  	}
   430  	serviceOpts = append(serviceOpts, micro.Version(s.opts.Version))
   431  	s.RUnlock()
   432  
   433  	s.opts.Service.Init(serviceOpts...)
   434  
   435  	s.Lock()
   436  	srv := s.genSrv()
   437  	srv.Endpoints = s.srv.Endpoints
   438  	s.srv = srv
   439  	s.Unlock()
   440  
   441  	return nil
   442  }
   443  
   444  func (s *service) Run() error {
   445  	// generate an auth account
   446  	srvID := s.opts.Service.Server().Options().Id
   447  	srvName := s.Options().Name
   448  	if err := authutil.Generate(srvID, srvName, s.opts.Service.Options().Auth); err != nil {
   449  		return err
   450  	}
   451  
   452  	if err := s.start(); err != nil {
   453  		return err
   454  	}
   455  
   456  	if err := s.register(); err != nil {
   457  		return err
   458  	}
   459  
   460  	// start reg loop
   461  	ex := make(chan bool)
   462  	go s.run(ex)
   463  
   464  	ch := make(chan os.Signal, 1)
   465  	if s.opts.Signal {
   466  		signal.Notify(ch, signalutil.Shutdown()...)
   467  	}
   468  
   469  	select {
   470  	// wait on kill signal
   471  	case sig := <-ch:
   472  		if logger.V(logger.InfoLevel, logger.DefaultLogger) {
   473  			logger.Infof("Received signal %s", sig)
   474  		}
   475  	// wait on context cancel
   476  	case <-s.opts.Context.Done():
   477  		if logger.V(logger.InfoLevel, logger.DefaultLogger) {
   478  			logger.Info("Received context shutdown")
   479  		}
   480  	}
   481  
   482  	// exit reg loop
   483  	close(ex)
   484  
   485  	if err := s.deregister(); err != nil {
   486  		return err
   487  	}
   488  
   489  	return s.stop()
   490  }
   491  
   492  // Options returns the options for the given service
   493  func (s *service) Options() Options {
   494  	return s.opts
   495  }
   496  
   497  func (s *service) listen(network, addr string) (net.Listener, error) {
   498  	var l net.Listener
   499  	var err error
   500  
   501  	// TODO: support use of listen options
   502  	if s.opts.Secure || s.opts.TLSConfig != nil {
   503  		config := s.opts.TLSConfig
   504  
   505  		fn := func(addr string) (net.Listener, error) {
   506  			if config == nil {
   507  				hosts := []string{addr}
   508  
   509  				// check if its a valid host:port
   510  				if host, _, err := net.SplitHostPort(addr); err == nil {
   511  					if len(host) == 0 {
   512  						hosts = maddr.IPs()
   513  					} else {
   514  						hosts = []string{host}
   515  					}
   516  				}
   517  
   518  				// generate a certificate
   519  				cert, err := mls.Certificate(hosts...)
   520  				if err != nil {
   521  					return nil, err
   522  				}
   523  				config = &tls.Config{Certificates: []tls.Certificate{cert}}
   524  			}
   525  			return tls.Listen(network, addr, config)
   526  		}
   527  
   528  		l, err = mnet.Listen(addr, fn)
   529  	} else {
   530  		fn := func(addr string) (net.Listener, error) {
   531  			return net.Listen(network, addr)
   532  		}
   533  
   534  		l, err = mnet.Listen(addr, fn)
   535  	}
   536  
   537  	if err != nil {
   538  		return nil, err
   539  	}
   540  
   541  	return l, nil
   542  }