github.com/annwntech/go-micro/v2@v2.9.5/web/service.go (about) 1 package web 2 3 import ( 4 "crypto/tls" 5 "fmt" 6 "net" 7 "net/http" 8 "os" 9 "os/signal" 10 "path/filepath" 11 "strings" 12 "sync" 13 "time" 14 15 "github.com/micro/cli/v2" 16 "github.com/annwntech/go-micro/v2" 17 "github.com/annwntech/go-micro/v2/logger" 18 "github.com/annwntech/go-micro/v2/registry" 19 maddr "github.com/annwntech/go-micro/v2/util/addr" 20 authutil "github.com/annwntech/go-micro/v2/util/auth" 21 "github.com/annwntech/go-micro/v2/util/backoff" 22 mhttp "github.com/annwntech/go-micro/v2/util/http" 23 mnet "github.com/annwntech/go-micro/v2/util/net" 24 signalutil "github.com/annwntech/go-micro/v2/util/signal" 25 mls "github.com/annwntech/go-micro/v2/util/tls" 26 ) 27 28 type service struct { 29 opts Options 30 31 mux *http.ServeMux 32 srv *registry.Service 33 34 sync.RWMutex 35 running bool 36 static bool 37 exit chan chan error 38 } 39 40 func newService(opts ...Option) Service { 41 options := newOptions(opts...) 42 s := &service{ 43 opts: options, 44 mux: http.NewServeMux(), 45 static: true, 46 } 47 s.srv = s.genSrv() 48 return s 49 } 50 51 func (s *service) genSrv() *registry.Service { 52 var host string 53 var port string 54 var err error 55 56 // default host:port 57 if len(s.opts.Address) > 0 { 58 host, port, err = net.SplitHostPort(s.opts.Address) 59 if err != nil { 60 logger.Fatal(err) 61 } 62 } 63 64 // check the advertise address first 65 // if it exists then use it, otherwise 66 // use the address 67 if len(s.opts.Advertise) > 0 { 68 host, port, err = net.SplitHostPort(s.opts.Advertise) 69 if err != nil { 70 logger.Fatal(err) 71 } 72 } 73 74 addr, err := maddr.Extract(host) 75 if err != nil { 76 logger.Fatal(err) 77 } 78 79 if strings.Count(addr, ":") > 0 { 80 addr = "[" + addr + "]" 81 } 82 83 return ®istry.Service{ 84 Name: s.opts.Name, 85 Version: s.opts.Version, 86 Nodes: []*registry.Node{{ 87 Id: s.opts.Id, 88 Address: fmt.Sprintf("%s:%s", addr, port), 89 Metadata: s.opts.Metadata, 90 }}, 91 } 92 } 93 94 func (s *service) run(exit chan bool) { 95 s.RLock() 96 if s.opts.RegisterInterval <= time.Duration(0) { 97 s.RUnlock() 98 return 99 } 100 101 t := time.NewTicker(s.opts.RegisterInterval) 102 s.RUnlock() 103 104 for { 105 select { 106 case <-t.C: 107 s.register() 108 case <-exit: 109 t.Stop() 110 return 111 } 112 } 113 } 114 115 func (s *service) register() error { 116 s.Lock() 117 defer s.Unlock() 118 119 if s.srv == nil { 120 return nil 121 } 122 // default to service registry 123 r := s.opts.Service.Client().Options().Registry 124 // switch to option if specified 125 if s.opts.Registry != nil { 126 r = s.opts.Registry 127 } 128 129 // service node need modify, node address maybe changed 130 srv := s.genSrv() 131 srv.Endpoints = s.srv.Endpoints 132 s.srv = srv 133 134 // use RegisterCheck func before register 135 if err := s.opts.RegisterCheck(s.opts.Context); err != nil { 136 if logger.V(logger.ErrorLevel, logger.DefaultLogger) { 137 logger.Errorf("Server %s-%s register check error: %s", s.opts.Name, s.opts.Id, err) 138 } 139 return err 140 } 141 142 var regErr error 143 144 // try three times if necessary 145 for i := 0; i < 3; i++ { 146 // attempt to register 147 if err := r.Register(s.srv, registry.RegisterTTL(s.opts.RegisterTTL)); err != nil { 148 // set the error 149 regErr = err 150 // backoff then retry 151 time.Sleep(backoff.Do(i + 1)) 152 continue 153 } 154 // success so nil error 155 regErr = nil 156 break 157 } 158 159 return regErr 160 } 161 162 func (s *service) deregister() error { 163 s.Lock() 164 defer s.Unlock() 165 166 if s.srv == nil { 167 return nil 168 } 169 // default to service registry 170 r := s.opts.Service.Client().Options().Registry 171 // switch to option if specified 172 if s.opts.Registry != nil { 173 r = s.opts.Registry 174 } 175 return r.Deregister(s.srv) 176 } 177 178 func (s *service) start() error { 179 s.Lock() 180 defer s.Unlock() 181 182 if s.running { 183 return nil 184 } 185 186 for _, fn := range s.opts.BeforeStart { 187 if err := fn(); err != nil { 188 return err 189 } 190 } 191 192 l, err := s.listen("tcp", s.opts.Address) 193 if err != nil { 194 return err 195 } 196 197 s.opts.Address = l.Addr().String() 198 srv := s.genSrv() 199 srv.Endpoints = s.srv.Endpoints 200 s.srv = srv 201 202 var h http.Handler 203 204 if s.opts.Handler != nil { 205 h = s.opts.Handler 206 } else { 207 h = s.mux 208 var r sync.Once 209 210 // register the html dir 211 r.Do(func() { 212 // static dir 213 static := s.opts.StaticDir 214 if s.opts.StaticDir[0] != '/' { 215 dir, _ := os.Getwd() 216 static = filepath.Join(dir, static) 217 } 218 219 // set static if no / handler is registered 220 if s.static { 221 _, err := os.Stat(static) 222 if err == nil { 223 if logger.V(logger.InfoLevel, logger.DefaultLogger) { 224 logger.Infof("Enabling static file serving from %s", static) 225 } 226 s.mux.Handle("/", http.FileServer(http.Dir(static))) 227 } 228 } 229 }) 230 } 231 232 var httpSrv *http.Server 233 if s.opts.Server != nil { 234 httpSrv = s.opts.Server 235 } else { 236 httpSrv = &http.Server{} 237 } 238 239 httpSrv.Handler = h 240 241 go httpSrv.Serve(l) 242 243 for _, fn := range s.opts.AfterStart { 244 if err := fn(); err != nil { 245 return err 246 } 247 } 248 249 s.exit = make(chan chan error, 1) 250 s.running = true 251 252 go func() { 253 ch := <-s.exit 254 ch <- l.Close() 255 }() 256 257 if logger.V(logger.InfoLevel, logger.DefaultLogger) { 258 logger.Infof("Listening on %v", l.Addr().String()) 259 } 260 return nil 261 } 262 263 func (s *service) stop() error { 264 s.Lock() 265 defer s.Unlock() 266 267 if !s.running { 268 return nil 269 } 270 271 for _, fn := range s.opts.BeforeStop { 272 if err := fn(); err != nil { 273 return err 274 } 275 } 276 277 ch := make(chan error, 1) 278 s.exit <- ch 279 s.running = false 280 281 if logger.V(logger.InfoLevel, logger.DefaultLogger) { 282 logger.Info("Stopping") 283 } 284 285 for _, fn := range s.opts.AfterStop { 286 if err := fn(); err != nil { 287 if chErr := <-ch; chErr != nil { 288 return chErr 289 } 290 return err 291 } 292 } 293 294 return <-ch 295 } 296 297 func (s *service) Client() *http.Client { 298 rt := mhttp.NewRoundTripper( 299 mhttp.WithRegistry(s.opts.Registry), 300 ) 301 return &http.Client{ 302 Transport: rt, 303 } 304 } 305 306 func (s *service) Handle(pattern string, handler http.Handler) { 307 var seen bool 308 s.RLock() 309 for _, ep := range s.srv.Endpoints { 310 if ep.Name == pattern { 311 seen = true 312 break 313 } 314 } 315 s.RUnlock() 316 317 // if its unseen then add an endpoint 318 if !seen { 319 s.Lock() 320 s.srv.Endpoints = append(s.srv.Endpoints, ®istry.Endpoint{ 321 Name: pattern, 322 }) 323 s.Unlock() 324 } 325 326 // disable static serving 327 if pattern == "/" { 328 s.Lock() 329 s.static = false 330 s.Unlock() 331 } 332 333 // register the handler 334 s.mux.Handle(pattern, handler) 335 } 336 337 func (s *service) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) { 338 339 var seen bool 340 s.RLock() 341 for _, ep := range s.srv.Endpoints { 342 if ep.Name == pattern { 343 seen = true 344 break 345 } 346 } 347 s.RUnlock() 348 349 if !seen { 350 s.Lock() 351 s.srv.Endpoints = append(s.srv.Endpoints, ®istry.Endpoint{ 352 Name: pattern, 353 }) 354 s.Unlock() 355 } 356 357 // disable static serving 358 if pattern == "/" { 359 s.Lock() 360 s.static = false 361 s.Unlock() 362 } 363 364 s.mux.HandleFunc(pattern, handler) 365 } 366 367 func (s *service) Init(opts ...Option) error { 368 s.Lock() 369 370 for _, o := range opts { 371 o(&s.opts) 372 } 373 374 serviceOpts := []micro.Option{} 375 376 if len(s.opts.Flags) > 0 { 377 serviceOpts = append(serviceOpts, micro.Flags(s.opts.Flags...)) 378 } 379 380 if s.opts.Registry != nil { 381 serviceOpts = append(serviceOpts, micro.Registry(s.opts.Registry)) 382 } 383 384 s.Unlock() 385 386 serviceOpts = append(serviceOpts, micro.Action(func(ctx *cli.Context) error { 387 s.Lock() 388 defer s.Unlock() 389 390 if ttl := ctx.Int("register_ttl"); ttl > 0 { 391 s.opts.RegisterTTL = time.Duration(ttl) * time.Second 392 } 393 394 if interval := ctx.Int("register_interval"); interval > 0 { 395 s.opts.RegisterInterval = time.Duration(interval) * time.Second 396 } 397 398 if name := ctx.String("server_name"); len(name) > 0 { 399 s.opts.Name = name 400 } 401 402 if ver := ctx.String("server_version"); len(ver) > 0 { 403 s.opts.Version = ver 404 } 405 406 if id := ctx.String("server_id"); len(id) > 0 { 407 s.opts.Id = id 408 } 409 410 if addr := ctx.String("server_address"); len(addr) > 0 { 411 s.opts.Address = addr 412 } 413 414 if adv := ctx.String("server_advertise"); len(adv) > 0 { 415 s.opts.Advertise = adv 416 } 417 418 if s.opts.Action != nil { 419 s.opts.Action(ctx) 420 } 421 422 return nil 423 })) 424 425 s.RLock() 426 // pass in own name and version 427 if s.opts.Service.Name() == "" { 428 serviceOpts = append(serviceOpts, micro.Name(s.opts.Name)) 429 } 430 serviceOpts = append(serviceOpts, micro.Version(s.opts.Version)) 431 s.RUnlock() 432 433 s.opts.Service.Init(serviceOpts...) 434 435 s.Lock() 436 srv := s.genSrv() 437 srv.Endpoints = s.srv.Endpoints 438 s.srv = srv 439 s.Unlock() 440 441 return nil 442 } 443 444 func (s *service) Run() error { 445 // generate an auth account 446 srvID := s.opts.Service.Server().Options().Id 447 srvName := s.Options().Name 448 if err := authutil.Generate(srvID, srvName, s.opts.Service.Options().Auth); err != nil { 449 return err 450 } 451 452 if err := s.start(); err != nil { 453 return err 454 } 455 456 if err := s.register(); err != nil { 457 return err 458 } 459 460 // start reg loop 461 ex := make(chan bool) 462 go s.run(ex) 463 464 ch := make(chan os.Signal, 1) 465 if s.opts.Signal { 466 signal.Notify(ch, signalutil.Shutdown()...) 467 } 468 469 select { 470 // wait on kill signal 471 case sig := <-ch: 472 if logger.V(logger.InfoLevel, logger.DefaultLogger) { 473 logger.Infof("Received signal %s", sig) 474 } 475 // wait on context cancel 476 case <-s.opts.Context.Done(): 477 if logger.V(logger.InfoLevel, logger.DefaultLogger) { 478 logger.Info("Received context shutdown") 479 } 480 } 481 482 // exit reg loop 483 close(ex) 484 485 if err := s.deregister(); err != nil { 486 return err 487 } 488 489 return s.stop() 490 } 491 492 // Options returns the options for the given service 493 func (s *service) Options() Options { 494 return s.opts 495 } 496 497 func (s *service) listen(network, addr string) (net.Listener, error) { 498 var l net.Listener 499 var err error 500 501 // TODO: support use of listen options 502 if s.opts.Secure || s.opts.TLSConfig != nil { 503 config := s.opts.TLSConfig 504 505 fn := func(addr string) (net.Listener, error) { 506 if config == nil { 507 hosts := []string{addr} 508 509 // check if its a valid host:port 510 if host, _, err := net.SplitHostPort(addr); err == nil { 511 if len(host) == 0 { 512 hosts = maddr.IPs() 513 } else { 514 hosts = []string{host} 515 } 516 } 517 518 // generate a certificate 519 cert, err := mls.Certificate(hosts...) 520 if err != nil { 521 return nil, err 522 } 523 config = &tls.Config{Certificates: []tls.Certificate{cert}} 524 } 525 return tls.Listen(network, addr, config) 526 } 527 528 l, err = mnet.Listen(addr, fn) 529 } else { 530 fn := func(addr string) (net.Listener, error) { 531 return net.Listen(network, addr) 532 } 533 534 l, err = mnet.Listen(addr, fn) 535 } 536 537 if err != nil { 538 return nil, err 539 } 540 541 return l, nil 542 }