github.com/cs3org/reva/v2@v2.27.7/pkg/rgrpc/rgrpc.go (about)

     1  // Copyright 2018-2021 CERN
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // In applying this license, CERN does not waive the privileges and immunities
    16  // granted to it by virtue of its status as an Intergovernmental Organization
    17  // or submit itself to any jurisdiction.
    18  
    19  package rgrpc
    20  
    21  import (
    22  	"crypto/tls"
    23  	"fmt"
    24  	"io"
    25  	"net"
    26  	"sort"
    27  
    28  	"github.com/cs3org/reva/v2/internal/grpc/interceptors/appctx"
    29  	"github.com/cs3org/reva/v2/internal/grpc/interceptors/auth"
    30  	"github.com/cs3org/reva/v2/internal/grpc/interceptors/log"
    31  	"github.com/cs3org/reva/v2/internal/grpc/interceptors/recovery"
    32  	"github.com/cs3org/reva/v2/internal/grpc/interceptors/token"
    33  	"github.com/cs3org/reva/v2/internal/grpc/interceptors/useragent"
    34  	"github.com/cs3org/reva/v2/pkg/sharedconf"
    35  	rtrace "github.com/cs3org/reva/v2/pkg/trace"
    36  	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
    37  	"github.com/mitchellh/mapstructure"
    38  	"github.com/pkg/errors"
    39  	"github.com/rs/zerolog"
    40  	mtls "go-micro.dev/v4/util/tls"
    41  	"go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc"
    42  	"go.opentelemetry.io/otel/trace"
    43  	"google.golang.org/grpc"
    44  	"google.golang.org/grpc/credentials"
    45  	"google.golang.org/grpc/keepalive"
    46  	"google.golang.org/grpc/reflection"
    47  )
    48  
// UnaryInterceptors is a map of registered unary grpc interceptors, keyed by
// the name they were registered under via RegisterUnaryInterceptor.
// The map is a plain package-level variable; no locking is performed around it.
var UnaryInterceptors = map[string]NewUnaryInterceptor{}
    51  
// StreamInterceptors is a map of registered streaming grpc interceptors, keyed
// by the name they were registered under via RegisterStreamInterceptor.
// The map is a plain package-level variable; no locking is performed around it.
var StreamInterceptors = map[string]NewStreamInterceptor{}
    54  
// NewUnaryInterceptor is the type that unary interceptors need to register.
// It receives the interceptor's configuration map and returns the interceptor,
// its priority and an error. When the server builds its chain, registered
// interceptors are ordered by ascending priority.
type NewUnaryInterceptor func(m map[string]interface{}) (grpc.UnaryServerInterceptor, int, error)
    57  
// NewStreamInterceptor is the type that stream interceptors need to register.
// It receives the interceptor's configuration map and returns the interceptor,
// its priority and an error. When the server builds its chain, registered
// interceptors are ordered by ascending priority.
type NewStreamInterceptor func(m map[string]interface{}) (grpc.StreamServerInterceptor, int, error)
    60  
// RegisterUnaryInterceptor registers a new unary interceptor under name.
// A later registration with the same name replaces the earlier one.
// NOTE(review): presumably called from package init functions; the map write
// is not synchronized — confirm no concurrent callers exist.
func RegisterUnaryInterceptor(name string, newFunc NewUnaryInterceptor) {
	UnaryInterceptors[name] = newFunc
}
    65  
// RegisterStreamInterceptor registers a new stream interceptor under name.
// A later registration with the same name replaces the earlier one.
// NOTE(review): presumably called from package init functions; the map write
// is not synchronized — confirm no concurrent callers exist.
func RegisterStreamInterceptor(name string, newFunc NewStreamInterceptor) {
	StreamInterceptors[name] = newFunc
}
    70  
// Services is a map of service name and its new function. Entries are added
// through Register and looked up when a server registers its configured services.
var Services = map[string]NewService{}
    73  
// Register registers a new gRPC service with name and new function.
// A later registration with the same name replaces the earlier one.
// NOTE(review): presumably called from package init functions; the map write
// is not synchronized — confirm no concurrent callers exist.
func Register(name string, newFunc NewService) {
	Services[name] = newFunc
}
    78  
// NewService is the function that gRPC services need to register at init time.
// It receives the service's configuration map, a gRPC server and a logger, and
// returns the constructed Service (which can be closed and which reports the
// endpoints that need to be unprotected) or an error.
// NOTE(review): in this file the constructor is invoked before the grpc.Server
// is created, so ss appears to be nil on first start — confirm implementations
// do not dereference it and rely on Service.Register instead.
type NewService func(conf map[string]interface{}, ss *grpc.Server, log *zerolog.Logger) (Service, error)
    82  
// Service represents a grpc service.
type Service interface {
	// Register is called with the newly created gRPC server; implementations
	// are expected to register their handlers on it.
	Register(ss *grpc.Server)
	// Close is called for every service when the server is stopped.
	io.Closer
	// UnprotectedEndpoints returns the endpoints that are handed to the auth
	// interceptors as the list of unprotected endpoints.
	UnprotectedEndpoints() []string
}
    89  
// unaryInterceptorTriple bundles a constructed unary interceptor with the name
// it was registered under and the priority used to order it in the chain.
type unaryInterceptorTriple struct {
	Name        string
	Priority    int
	Interceptor grpc.UnaryServerInterceptor
}
    95  
// streamInterceptorTriple bundles a constructed stream interceptor with the
// name it was registered under and the priority used to order it in the chain.
type streamInterceptorTriple struct {
	Name        string
	Priority    int
	Interceptor grpc.StreamServerInterceptor
}
   101  
// tlsSettings holds the TLS part of the server configuration.
//
// Enabled switches TLS on. CertificateFile and KeyFile are the paths passed to
// tls.LoadX509KeyPair; when both are empty NewServer generates a temporary
// self-signed certificate instead. tlsConfig is not decoded from configuration:
// NewServer fills it in when TLS is enabled.
type tlsSettings struct {
	Enabled         bool   `mapstructure:"enabled"`
	CertificateFile string `mapstructure:"certificate"`
	KeyFile         string `mapstructure:"key"`
	tlsConfig       *tls.Config
}
   108  
// config is the decoded gRPC server configuration.
//
// Network and Address describe where the server listens; init fills in
// defaults for both when they are empty. TLSSettings configures transport
// security. Services maps a service name to that service's configuration and
// Interceptors maps an interceptor name to its configuration; the presence of
// a key in Interceptors is what enables a registered interceptor.
// EnableReflection turns on gRPC server reflection.
// NOTE(review): ShutdownDeadline is not read anywhere in the code visible
// here — confirm whether it is still used.
type config struct {
	Network          string                            `mapstructure:"network"`
	Address          string                            `mapstructure:"address"`
	TLSSettings      tlsSettings                       `mapstructure:"tls_settings"`
	ShutdownDeadline int                               `mapstructure:"shutdown_deadline"`
	Services         map[string]map[string]interface{} `mapstructure:"services"`
	Interceptors     map[string]map[string]interface{} `mapstructure:"interceptors"`
	EnableReflection bool                              `mapstructure:"enable_reflection"`
}
   118  
   119  func (c *config) init() {
   120  	if c.Network == "" {
   121  		c.Network = "tcp"
   122  	}
   123  
   124  	if c.Address == "" {
   125  		c.Address = sharedconf.GetGatewaySVC("0.0.0.0:19000")
   126  	}
   127  }
   128  
// Server is a gRPC server.
//
// s is the underlying grpc.Server; it is only assigned once services have been
// registered during Start and is nil before that. conf is the decoded
// configuration, listener is the listener handed to Start, log and
// tracerProvider are passed on to the interceptors, and services holds the
// constructed services keyed by their configured name.
type Server struct {
	s              *grpc.Server
	conf           *config
	listener       net.Listener
	log            zerolog.Logger
	tracerProvider trace.TracerProvider
	services       map[string]Service
}
   138  
   139  // NewServer returns a new Server.
   140  func NewServer(m interface{}, log zerolog.Logger, tp trace.TracerProvider) (*Server, error) {
   141  	var err error
   142  	conf := &config{}
   143  	if err := mapstructure.Decode(m, conf); err != nil {
   144  		return nil, err
   145  	}
   146  
   147  	conf.init()
   148  
   149  	if conf.TLSSettings.Enabled {
   150  		var cert tls.Certificate
   151  		switch {
   152  		case conf.TLSSettings.CertificateFile == "" && conf.TLSSettings.KeyFile == "":
   153  			// Generate a self-signed server certificate on the fly. This requires the clients
   154  			// to connect with InsecureSkipVerify.
   155  			subj := []string{conf.Address}
   156  			if host, _, err := net.SplitHostPort(conf.Address); err == nil && host != "" {
   157  				subj = []string{host}
   158  			}
   159  
   160  			log.Warn().Str("address", conf.Address).Str("network", conf.Network).
   161  				Msg("No server certificate configured. Generating a temporary self-signed certificate")
   162  
   163  			cert, err = mtls.Certificate(subj...)
   164  			if err != nil {
   165  				return nil, err
   166  			}
   167  		default:
   168  			cert, err = tls.LoadX509KeyPair(
   169  				conf.TLSSettings.CertificateFile,
   170  				conf.TLSSettings.KeyFile,
   171  			)
   172  			if err != nil {
   173  				return nil, err
   174  			}
   175  		}
   176  		conf.TLSSettings.tlsConfig = &tls.Config{Certificates: []tls.Certificate{cert}}
   177  	}
   178  
   179  	server := &Server{conf: conf, log: log, tracerProvider: tp, services: map[string]Service{}}
   180  
   181  	return server, nil
   182  }
   183  
   184  // Start starts the server.
   185  func (s *Server) Start(ln net.Listener) error {
   186  	if err := s.registerServices(); err != nil {
   187  		err = errors.Wrap(err, "unable to register services")
   188  		return err
   189  	}
   190  
   191  	s.listener = ln
   192  	s.log.Info().Msgf("grpc server listening at %s:%s", s.Network(), s.Address())
   193  	err := s.s.Serve(s.listener)
   194  	if err != nil {
   195  		err = errors.Wrap(err, "serve failed")
   196  		return err
   197  	}
   198  	return nil
   199  }
   200  
   201  func (s *Server) isInterceptorEnabled(name string) bool {
   202  	for k := range s.conf.Interceptors {
   203  		if k == name {
   204  			return true
   205  		}
   206  	}
   207  	return false
   208  }
   209  
   210  func (s *Server) isServiceEnabled(svcName string) bool {
   211  	for key := range Services {
   212  		if key == svcName {
   213  			return true
   214  		}
   215  	}
   216  	return false
   217  }
   218  
   219  func (s *Server) registerServices() error {
   220  	for svcName := range s.conf.Services {
   221  		if s.isServiceEnabled(svcName) {
   222  			newFunc := Services[svcName]
   223  			svc, err := newFunc(s.conf.Services[svcName], s.s, &s.log)
   224  			if err != nil {
   225  				return errors.Wrapf(err, "rgrpc: grpc service %s could not be started,", svcName)
   226  			}
   227  			s.services[svcName] = svc
   228  			s.log.Info().Msgf("rgrpc: grpc service enabled: %s", svcName)
   229  		} else {
   230  			message := fmt.Sprintf("rgrpc: grpc service %s does not exist", svcName)
   231  			return errors.New(message)
   232  		}
   233  	}
   234  
   235  	// obtain list of unprotected endpoints
   236  	unprotected := []string{}
   237  	for _, svc := range s.services {
   238  		unprotected = append(unprotected, svc.UnprotectedEndpoints()...)
   239  	}
   240  
   241  	opts, err := s.getInterceptors(unprotected)
   242  	if err != nil {
   243  		return err
   244  	}
   245  
   246  	if s.conf.TLSSettings.tlsConfig != nil {
   247  		opts = append(opts, grpc.Creds(credentials.NewTLS(s.conf.TLSSettings.tlsConfig)))
   248  	}
   249  	opts = append(opts, grpc.KeepaliveParams(keepalive.ServerParameters{
   250  		MaxConnectionAge: GetMaxConnectionAge(), // this forces clients to reconnect after 30 seconds, triggering a new DNS lookup to pick up new IPs
   251  	}))
   252  
   253  	grpcServer := grpc.NewServer(opts...)
   254  
   255  	for _, svc := range s.services {
   256  		svc.Register(grpcServer)
   257  	}
   258  
   259  	if s.conf.EnableReflection {
   260  		s.log.Info().Msg("rgrpc: grpc server reflection enabled")
   261  		reflection.Register(grpcServer)
   262  	}
   263  
   264  	s.s = grpcServer
   265  
   266  	return nil
   267  }
   268  
   269  // TODO(labkode): make closing with deadline.
   270  func (s *Server) cleanupServices() {
   271  	for name, svc := range s.services {
   272  		if err := svc.Close(); err != nil {
   273  			s.log.Error().Err(err).Msgf("error closing service %q", name)
   274  		} else {
   275  			s.log.Info().Msgf("service %q correctly closed", name)
   276  		}
   277  	}
   278  }
   279  
   280  // Stop stops the server.
   281  func (s *Server) Stop() error {
   282  	s.cleanupServices()
   283  	s.s.Stop()
   284  	return nil
   285  }
   286  
   287  // GracefulStop gracefully stops the server.
   288  func (s *Server) GracefulStop() error {
   289  	s.cleanupServices()
   290  	s.s.GracefulStop()
   291  	return nil
   292  }
   293  
// Network returns the network type taken from the server configuration
// ("tcp" when none was configured).
func (s *Server) Network() string {
	return s.conf.Network
}
   298  
// Address returns the network address taken from the server configuration.
// It is the configured value, not the address of the listener passed to Start.
func (s *Server) Address() string {
	return s.conf.Address
}
   303  
   304  func (s *Server) getInterceptors(unprotected []string) ([]grpc.ServerOption, error) {
   305  	unaryTriples := []*unaryInterceptorTriple{}
   306  	for name, newFunc := range UnaryInterceptors {
   307  		if s.isInterceptorEnabled(name) {
   308  			inter, prio, err := newFunc(s.conf.Interceptors[name])
   309  			if err != nil {
   310  				err = errors.Wrapf(err, "rgrpc: error creating unary interceptor: %s,", name)
   311  				return nil, err
   312  			}
   313  			triple := &unaryInterceptorTriple{
   314  				Name:        name,
   315  				Priority:    prio,
   316  				Interceptor: inter,
   317  			}
   318  			unaryTriples = append(unaryTriples, triple)
   319  		}
   320  	}
   321  
   322  	// sort unary triples
   323  	sort.SliceStable(unaryTriples, func(i, j int) bool {
   324  		return unaryTriples[i].Priority < unaryTriples[j].Priority
   325  	})
   326  
   327  	authUnary, err := auth.NewUnary(s.conf.Interceptors["auth"], unprotected, s.tracerProvider)
   328  	if err != nil {
   329  		return nil, errors.Wrap(err, "rgrpc: error creating unary auth interceptor")
   330  	}
   331  
   332  	unaryInterceptors := []grpc.UnaryServerInterceptor{
   333  		appctx.NewUnary(s.log, s.tracerProvider),
   334  		token.NewUnary(),
   335  		useragent.NewUnary(),
   336  		log.NewUnary(),
   337  		recovery.NewUnary(),
   338  		authUnary,
   339  	}
   340  
   341  	for _, t := range unaryTriples {
   342  		unaryInterceptors = append(unaryInterceptors, t.Interceptor)
   343  		s.log.Info().Msgf("rgrpc: chaining grpc unary interceptor %s with priority %d", t.Name, t.Priority)
   344  	}
   345  
   346  	unaryChain := grpc_middleware.ChainUnaryServer(unaryInterceptors...)
   347  
   348  	streamTriples := []*streamInterceptorTriple{}
   349  	for name, newFunc := range StreamInterceptors {
   350  		if s.isInterceptorEnabled(name) {
   351  			inter, prio, err := newFunc(s.conf.Interceptors[name])
   352  			if err != nil {
   353  				err = errors.Wrapf(err, "rgrpc: error creating streaming interceptor: %s,", name)
   354  				return nil, err
   355  			}
   356  			triple := &streamInterceptorTriple{
   357  				Name:        name,
   358  				Priority:    prio,
   359  				Interceptor: inter,
   360  			}
   361  			streamTriples = append(streamTriples, triple)
   362  		}
   363  	}
   364  	// sort stream triples
   365  	sort.SliceStable(streamTriples, func(i, j int) bool {
   366  		return streamTriples[i].Priority < streamTriples[j].Priority
   367  	})
   368  
   369  	authStream, err := auth.NewStream(s.conf.Interceptors["auth"], unprotected, s.tracerProvider)
   370  	if err != nil {
   371  		return nil, errors.Wrap(err, "rgrpc: error creating stream auth interceptor")
   372  	}
   373  
   374  	streamInterceptors := []grpc.StreamServerInterceptor{
   375  		appctx.NewStream(s.log, s.tracerProvider),
   376  		token.NewStream(),
   377  		useragent.NewStream(),
   378  		log.NewStream(),
   379  		recovery.NewStream(),
   380  		authStream,
   381  	}
   382  
   383  	for _, t := range streamTriples {
   384  		streamInterceptors = append(streamInterceptors, t.Interceptor)
   385  		s.log.Info().Msgf("rgrpc: chaining grpc streaming interceptor %s with priority %d", t.Name, t.Priority)
   386  	}
   387  	streamChain := grpc_middleware.ChainStreamServer(streamInterceptors...)
   388  
   389  	opts := []grpc.ServerOption{
   390  		grpc.StatsHandler(otelgrpc.NewServerHandler(
   391  			otelgrpc.WithTracerProvider(s.tracerProvider),
   392  			otelgrpc.WithPropagators(rtrace.Propagator))),
   393  		grpc.UnaryInterceptor(unaryChain),
   394  		grpc.StreamInterceptor(streamChain),
   395  	}
   396  
   397  	return opts, nil
   398  }