gitee.com/liuxuezhan/go-micro-v1.18.0@v1.0.0/web/service.go (about)

     1  package web
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"net"
     7  	"net/http"
     8  	"os"
     9  	"os/signal"
    10  	"path/filepath"
    11  	"strconv"
    12  	"strings"
    13  	"sync"
    14  	"syscall"
    15  	"time"
    16  
    17  	"github.com/micro/cli"
    18  	"gitee.com/liuxuezhan/go-micro-v1.18.0"
    19  	"gitee.com/liuxuezhan/go-micro-v1.18.0/registry"
    20  	maddr "gitee.com/liuxuezhan/go-micro-v1.18.0/util/addr"
    21  	mhttp "gitee.com/liuxuezhan/go-micro-v1.18.0/util/http"
    22  	"gitee.com/liuxuezhan/go-micro-v1.18.0/util/log"
    23  	mnet "gitee.com/liuxuezhan/go-micro-v1.18.0/util/net"
    24  	mls "gitee.com/liuxuezhan/go-micro-v1.18.0/util/tls"
    25  )
    26  
    27  type service struct {
    28  	opts Options
    29  
    30  	mux *http.ServeMux
    31  	srv *registry.Service
    32  
    33  	sync.Mutex
    34  	running bool
    35  	static  bool
    36  	exit    chan chan error
    37  }
    38  
    39  func newService(opts ...Option) Service {
    40  	options := newOptions(opts...)
    41  	s := &service{
    42  		opts:   options,
    43  		mux:    http.NewServeMux(),
    44  		static: true,
    45  	}
    46  	s.srv = s.genSrv()
    47  	return s
    48  }
    49  
    50  func (s *service) genSrv() *registry.Service {
    51  	// default host:port
    52  	parts := strings.Split(s.opts.Address, ":")
    53  	host := strings.Join(parts[:len(parts)-1], ":")
    54  	port, _ := strconv.Atoi(parts[len(parts)-1])
    55  
    56  	// check the advertise address first
    57  	// if it exists then use it, otherwise
    58  	// use the address
    59  	if len(s.opts.Advertise) > 0 {
    60  		parts = strings.Split(s.opts.Advertise, ":")
    61  
    62  		// we have host:port
    63  		if len(parts) > 1 {
    64  			// set the host
    65  			host = strings.Join(parts[:len(parts)-1], ":")
    66  
    67  			// get the port
    68  			if aport, _ := strconv.Atoi(parts[len(parts)-1]); aport > 0 {
    69  				port = aport
    70  			}
    71  		} else {
    72  			host = parts[0]
    73  		}
    74  	}
    75  
    76  	addr, err := maddr.Extract(host)
    77  	if err != nil {
    78  		// best effort localhost
    79  		addr = "127.0.0.1"
    80  	}
    81  
    82  	return &registry.Service{
    83  		Name:    s.opts.Name,
    84  		Version: s.opts.Version,
    85  		Nodes: []*registry.Node{{
    86  			Id:       s.opts.Id,
    87  			Address:  fmt.Sprintf("%s:%d", addr, port),
    88  			Metadata: s.opts.Metadata,
    89  		}},
    90  	}
    91  }
    92  
    93  func (s *service) run(exit chan bool) {
    94  	if s.opts.RegisterInterval <= time.Duration(0) {
    95  		return
    96  	}
    97  
    98  	t := time.NewTicker(s.opts.RegisterInterval)
    99  
   100  	for {
   101  		select {
   102  		case <-t.C:
   103  			s.register()
   104  		case <-exit:
   105  			t.Stop()
   106  			return
   107  		}
   108  	}
   109  }
   110  
   111  func (s *service) register() error {
   112  	if s.srv == nil {
   113  		return nil
   114  	}
   115  	// default to service registry
   116  	r := s.opts.Service.Client().Options().Registry
   117  	// switch to option if specified
   118  	if s.opts.Registry != nil {
   119  		r = s.opts.Registry
   120  	}
   121  
   122  	// service node need modify, node address maybe changed
   123  	srv := s.genSrv()
   124  	srv.Endpoints = s.srv.Endpoints
   125  	s.srv = srv
   126  	return r.Register(s.srv, registry.RegisterTTL(s.opts.RegisterTTL))
   127  }
   128  
   129  func (s *service) deregister() error {
   130  	if s.srv == nil {
   131  		return nil
   132  	}
   133  	// default to service registry
   134  	r := s.opts.Service.Client().Options().Registry
   135  	// switch to option if specified
   136  	if s.opts.Registry != nil {
   137  		r = s.opts.Registry
   138  	}
   139  	return r.Deregister(s.srv)
   140  }
   141  
   142  func (s *service) start() error {
   143  	s.Lock()
   144  	defer s.Unlock()
   145  
   146  	if s.running {
   147  		return nil
   148  	}
   149  
   150  	l, err := s.listen("tcp", s.opts.Address)
   151  	if err != nil {
   152  		return err
   153  	}
   154  
   155  	s.opts.Address = l.Addr().String()
   156  	srv := s.genSrv()
   157  	srv.Endpoints = s.srv.Endpoints
   158  	s.srv = srv
   159  
   160  	var h http.Handler
   161  
   162  	if s.opts.Handler != nil {
   163  		h = s.opts.Handler
   164  	} else {
   165  		h = s.mux
   166  		var r sync.Once
   167  
   168  		// register the html dir
   169  		r.Do(func() {
   170  			// static dir
   171  			static := s.opts.StaticDir
   172  			if s.opts.StaticDir[0] != '/' {
   173  				dir, _ := os.Getwd()
   174  				static = filepath.Join(dir, static)
   175  			}
   176  
   177  			// set static if no / handler is registered
   178  			if s.static {
   179  				_, err := os.Stat(static)
   180  				if err == nil {
   181  					log.Logf("Enabling static file serving from %s", static)
   182  					s.mux.Handle("/", http.FileServer(http.Dir(static)))
   183  				}
   184  			}
   185  		})
   186  	}
   187  
   188  	for _, fn := range s.opts.BeforeStart {
   189  		if err := fn(); err != nil {
   190  			return err
   191  		}
   192  	}
   193  
   194  	var httpSrv *http.Server
   195  	if s.opts.Server != nil {
   196  		httpSrv = s.opts.Server
   197  	} else {
   198  		httpSrv = &http.Server{}
   199  	}
   200  
   201  	httpSrv.Handler = h
   202  
   203  	go httpSrv.Serve(l)
   204  
   205  	for _, fn := range s.opts.AfterStart {
   206  		if err := fn(); err != nil {
   207  			return err
   208  		}
   209  	}
   210  
   211  	s.exit = make(chan chan error, 1)
   212  	s.running = true
   213  
   214  	go func() {
   215  		ch := <-s.exit
   216  		ch <- l.Close()
   217  	}()
   218  
   219  	log.Logf("Listening on %v", l.Addr().String())
   220  	return nil
   221  }
   222  
   223  func (s *service) stop() error {
   224  	s.Lock()
   225  	defer s.Unlock()
   226  
   227  	if !s.running {
   228  		return nil
   229  	}
   230  
   231  	for _, fn := range s.opts.BeforeStop {
   232  		if err := fn(); err != nil {
   233  			return err
   234  		}
   235  	}
   236  
   237  	ch := make(chan error, 1)
   238  	s.exit <- ch
   239  	s.running = false
   240  
   241  	log.Log("Stopping")
   242  
   243  	for _, fn := range s.opts.AfterStop {
   244  		if err := fn(); err != nil {
   245  			if chErr := <-ch; chErr != nil {
   246  				return chErr
   247  			}
   248  			return err
   249  		}
   250  	}
   251  
   252  	return <-ch
   253  }
   254  
   255  func (s *service) Client() *http.Client {
   256  	rt := mhttp.NewRoundTripper(
   257  		mhttp.WithRegistry(registry.DefaultRegistry),
   258  	)
   259  	return &http.Client{
   260  		Transport: rt,
   261  	}
   262  }
   263  
   264  func (s *service) Handle(pattern string, handler http.Handler) {
   265  	var seen bool
   266  	for _, ep := range s.srv.Endpoints {
   267  		if ep.Name == pattern {
   268  			seen = true
   269  			break
   270  		}
   271  	}
   272  
   273  	// if its unseen then add an endpoint
   274  	if !seen {
   275  		s.srv.Endpoints = append(s.srv.Endpoints, &registry.Endpoint{
   276  			Name: pattern,
   277  		})
   278  	}
   279  
   280  	// disable static serving
   281  	if pattern == "/" {
   282  		s.Lock()
   283  		s.static = false
   284  		s.Unlock()
   285  	}
   286  
   287  	// register the handler
   288  	s.mux.Handle(pattern, handler)
   289  }
   290  
   291  func (s *service) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
   292  	var seen bool
   293  	for _, ep := range s.srv.Endpoints {
   294  		if ep.Name == pattern {
   295  			seen = true
   296  			break
   297  		}
   298  	}
   299  	if !seen {
   300  		s.srv.Endpoints = append(s.srv.Endpoints, &registry.Endpoint{
   301  			Name: pattern,
   302  		})
   303  	}
   304  
   305  	s.mux.HandleFunc(pattern, handler)
   306  }
   307  
   308  func (s *service) Init(opts ...Option) error {
   309  	for _, o := range opts {
   310  		o(&s.opts)
   311  	}
   312  
   313  	serviceOpts := []micro.Option{}
   314  
   315  	if len(s.opts.Flags) > 0 {
   316  		serviceOpts = append(serviceOpts, micro.Flags(s.opts.Flags...))
   317  	}
   318  
   319  	if s.opts.Registry != nil {
   320  		serviceOpts = append(serviceOpts, micro.Registry(s.opts.Registry))
   321  	}
   322  
   323  	serviceOpts = append(serviceOpts, micro.Action(func(ctx *cli.Context) {
   324  		if ttl := ctx.Int("register_ttl"); ttl > 0 {
   325  			s.opts.RegisterTTL = time.Duration(ttl) * time.Second
   326  		}
   327  
   328  		if interval := ctx.Int("register_interval"); interval > 0 {
   329  			s.opts.RegisterInterval = time.Duration(interval) * time.Second
   330  		}
   331  
   332  		if name := ctx.String("server_name"); len(name) > 0 {
   333  			s.opts.Name = name
   334  		}
   335  
   336  		if ver := ctx.String("server_version"); len(ver) > 0 {
   337  			s.opts.Version = ver
   338  		}
   339  
   340  		if id := ctx.String("server_id"); len(id) > 0 {
   341  			s.opts.Id = id
   342  		}
   343  
   344  		if addr := ctx.String("server_address"); len(addr) > 0 {
   345  			s.opts.Address = addr
   346  		}
   347  
   348  		if adv := ctx.String("server_advertise"); len(adv) > 0 {
   349  			s.opts.Advertise = adv
   350  		}
   351  
   352  		if s.opts.Action != nil {
   353  			s.opts.Action(ctx)
   354  		}
   355  	}))
   356  
   357  	s.opts.Service.Init(serviceOpts...)
   358  	srv := s.genSrv()
   359  	srv.Endpoints = s.srv.Endpoints
   360  	s.srv = srv
   361  
   362  	return nil
   363  }
   364  
   365  func (s *service) Run() error {
   366  	if err := s.start(); err != nil {
   367  		return err
   368  	}
   369  
   370  	if err := s.register(); err != nil {
   371  		return err
   372  	}
   373  
   374  	// start reg loop
   375  	ex := make(chan bool)
   376  	go s.run(ex)
   377  
   378  	ch := make(chan os.Signal, 1)
   379  	signal.Notify(ch, syscall.SIGTERM, syscall.SIGINT)
   380  
   381  	select {
   382  	// wait on kill signal
   383  	case sig := <-ch:
   384  		log.Logf("Received signal %s", sig)
   385  	// wait on context cancel
   386  	case <-s.opts.Context.Done():
   387  		log.Logf("Received context shutdown")
   388  	}
   389  
   390  	// exit reg loop
   391  	close(ex)
   392  
   393  	if err := s.deregister(); err != nil {
   394  		return err
   395  	}
   396  
   397  	return s.stop()
   398  }
   399  
   400  // Options returns the options for the given service
   401  func (s *service) Options() Options {
   402  	return s.opts
   403  }
   404  
   405  func (s *service) listen(network, addr string) (net.Listener, error) {
   406  	var l net.Listener
   407  	var err error
   408  
   409  	// TODO: support use of listen options
   410  	if s.opts.Secure || s.opts.TLSConfig != nil {
   411  		config := s.opts.TLSConfig
   412  
   413  		fn := func(addr string) (net.Listener, error) {
   414  			if config == nil {
   415  				hosts := []string{addr}
   416  
   417  				// check if its a valid host:port
   418  				if host, _, err := net.SplitHostPort(addr); err == nil {
   419  					if len(host) == 0 {
   420  						hosts = maddr.IPs()
   421  					} else {
   422  						hosts = []string{host}
   423  					}
   424  				}
   425  
   426  				// generate a certificate
   427  				cert, err := mls.Certificate(hosts...)
   428  				if err != nil {
   429  					return nil, err
   430  				}
   431  				config = &tls.Config{Certificates: []tls.Certificate{cert}}
   432  			}
   433  			return tls.Listen(network, addr, config)
   434  		}
   435  
   436  		l, err = mnet.Listen(addr, fn)
   437  	} else {
   438  		fn := func(addr string) (net.Listener, error) {
   439  			return net.Listen(network, addr)
   440  		}
   441  
   442  		l, err = mnet.Listen(addr, fn)
   443  	}
   444  
   445  	if err != nil {
   446  		return nil, err
   447  	}
   448  
   449  	return l, nil
   450  }