trpc.group/trpc-go/trpc-go@v1.0.3/server/server.go (about)

     1  //
     2  //
     3  // Tencent is pleased to support the open source community by making tRPC available.
     4  //
     5  // Copyright (C) 2023 THL A29 Limited, a Tencent company.
     6  // All rights reserved.
     7  //
     8  // If you have downloaded a copy of the tRPC source code from Tencent,
     9  // please note that tRPC source code is licensed under the  Apache 2.0 License,
    10  // A copy of the Apache 2.0 License is included in this file.
    11  //
    12  //
    13  
    14  // Package server provides a framework for managing multiple services within a single process.
    15  // A server process may listen on multiple ports, providing different services on different ports.
    16  package server
    17  
    18  import (
    19  	"context"
    20  	"errors"
    21  	"os"
    22  	"sync"
    23  	"time"
    24  )
    25  
// Server is a tRPC server.
// One process, one server. A server may offer one or more services, each
// stored under its own name (see AddService).
type Server struct {
	// MaxCloseWaitTime is the requested time budget for closing services on
	// shutdown. It is only honored when it is larger than the package-level
	// MaxCloseWaitTime: the shutdown sequence uses the larger of the two, so a
	// zero or smaller value falls back to the package default.
	MaxCloseWaitTime time.Duration

	services map[string]Service // k=serviceName, v=Service; created lazily by AddService

	mux sync.Mutex // guards onShutdownHooks
	// onShutdownHooks are hook functions that are executed when the server is
	// shutting down (before closing all services of the server).
	onShutdownHooks []func()

	// failedServices holds entries keyed by service name; services found here
	// are skipped when closing services on shutdown.
	// NOTE(review): populated outside this file (presumably when a service's
	// Serve fails) — confirm.
	failedServices sync.Map
	// signalCh is not used in this file.
	// NOTE(review): presumably receives OS signals in Serve — confirm.
	signalCh chan os.Signal
	// closeCh, when non-nil, is closed by Close to wake whoever waits on it.
	closeCh chan struct{}
	// closeOnce makes the shutdown sequence (hooks, then closing services)
	// run at most once.
	closeOnce sync.Once
}
    43  
    44  // AddService adds a service for the server.
    45  // The param serviceName refers to the name used for Naming Services and
    46  // configured by config file (typically trpc_go.yaml).
    47  // When trpc.NewServer() is called, it will traverse service configuration from config file,
    48  // and call AddService to add a service implementation to the server's map[string]Service (serviceName as key).
    49  func (s *Server) AddService(serviceName string, service Service) {
    50  	if s.services == nil {
    51  		s.services = make(map[string]Service)
    52  	}
    53  	s.services[serviceName] = service
    54  }
    55  
    56  // Service returns a service by service name.
    57  func (s *Server) Service(serviceName string) Service {
    58  	if s.services == nil {
    59  		return nil
    60  	}
    61  	return s.services[serviceName]
    62  }
    63  
    64  // Register implements Service interface, registering a proto service.
    65  // Normally, a server contains only one service, so the registration is straightforward.
    66  // When it comes to server with multiple services, remember to use Service("servicename") to specify
    67  // which service this proto service is registered for.
    68  // Otherwise, this proto service will be registered for all services of the server.
    69  func (s *Server) Register(serviceDesc interface{}, serviceImpl interface{}) error {
    70  	desc, ok := serviceDesc.(*ServiceDesc)
    71  	if !ok {
    72  		return errors.New("service desc type invalid")
    73  	}
    74  
    75  	for _, srv := range s.services {
    76  		if err := srv.Register(desc, serviceImpl); err != nil {
    77  			return err
    78  		}
    79  	}
    80  	return nil
    81  }
    82  
    83  // Close implements Service interface, notifying all services of server shutdown.
    84  // Would wait no more than 10s.
    85  func (s *Server) Close(ch chan struct{}) error {
    86  	if s.closeCh != nil {
    87  		close(s.closeCh)
    88  	}
    89  
    90  	s.tryClose()
    91  
    92  	if ch != nil {
    93  		ch <- struct{}{}
    94  	}
    95  	return nil
    96  }
    97  
    98  func (s *Server) tryClose() {
    99  	fn := func() {
   100  		// execute shutdown hook functions before closing services.
   101  		s.mux.Lock()
   102  		for _, f := range s.onShutdownHooks {
   103  			f()
   104  		}
   105  		s.mux.Unlock()
   106  
   107  		// close all Services
   108  		closeWaitTime := s.MaxCloseWaitTime
   109  		if closeWaitTime < MaxCloseWaitTime {
   110  			closeWaitTime = MaxCloseWaitTime
   111  		}
   112  		ctx, cancel := context.WithTimeout(context.Background(), closeWaitTime)
   113  		defer cancel()
   114  
   115  		var wg sync.WaitGroup
   116  		for name, service := range s.services {
   117  			if _, ok := s.failedServices.Load(name); ok {
   118  				continue
   119  			}
   120  
   121  			wg.Add(1)
   122  			go func(srv Service) {
   123  				defer wg.Done()
   124  
   125  				c := make(chan struct{}, 1)
   126  				go srv.Close(c)
   127  
   128  				select {
   129  				case <-c:
   130  				case <-ctx.Done():
   131  				}
   132  			}(service)
   133  		}
   134  		wg.Wait()
   135  	}
   136  	s.closeOnce.Do(fn)
   137  }
   138  
   139  // RegisterOnShutdown registers a hook function that would be executed when server is shutting down.
   140  func (s *Server) RegisterOnShutdown(fn func()) {
   141  	if fn == nil {
   142  		return
   143  	}
   144  	s.mux.Lock()
   145  	s.onShutdownHooks = append(s.onShutdownHooks, fn)
   146  	s.mux.Unlock()
   147  }