github.com/cs3org/reva/v2@v2.27.7/pkg/rhttp/rhttp.go (about)

     1  // Copyright 2018-2021 CERN
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // In applying this license, CERN does not waive the privileges and immunities
    16  // granted to it by virtue of its status as an Intergovernmental Organization
    17  // or submit itself to any jurisdiction.
    18  
    19  package rhttp
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"net"
    25  	"net/http"
    26  	"path"
    27  	"sort"
    28  	"time"
    29  
    30  	"github.com/cs3org/reva/v2/internal/http/interceptors/appctx"
    31  	"github.com/cs3org/reva/v2/internal/http/interceptors/auth"
    32  	"github.com/cs3org/reva/v2/internal/http/interceptors/log"
    33  	"github.com/cs3org/reva/v2/internal/http/interceptors/providerauthorizer"
    34  	"github.com/cs3org/reva/v2/pkg/rhttp/global"
    35  	"github.com/cs3org/reva/v2/pkg/rhttp/router"
    36  	rtrace "github.com/cs3org/reva/v2/pkg/trace"
    37  	"github.com/mitchellh/mapstructure"
    38  	"github.com/pkg/errors"
    39  	"github.com/rs/zerolog"
    40  	"go.opentelemetry.io/otel/propagation"
    41  	"go.opentelemetry.io/otel/trace"
    42  )
    43  
// tracerName is the Tracer name used to identify this instrumentation library.
const tracerName = "rhttp"
    46  
    47  // New returns a new server
    48  func New(m interface{}, l zerolog.Logger, tp trace.TracerProvider) (*Server, error) {
    49  	conf := &config{}
    50  	if err := mapstructure.Decode(m, conf); err != nil {
    51  		return nil, err
    52  	}
    53  
    54  	conf.init()
    55  
    56  	httpServer := &http.Server{}
    57  	s := &Server{
    58  		httpServer:     httpServer,
    59  		conf:           conf,
    60  		svcs:           map[string]global.Service{},
    61  		unprotected:    []string{},
    62  		handlers:       map[string]http.Handler{},
    63  		log:            l,
    64  		tracerProvider: tp,
    65  	}
    66  	return s, nil
    67  }
    68  
// Server contains the server info.
type Server struct {
	httpServer     *http.Server              // underlying net/http server; its Handler is set in Start
	conf           *config                   // decoded configuration with defaults applied
	listener       net.Listener              // listener handed to Start
	svcs           map[string]global.Service // map key is svc Prefix
	unprotected    []string                  // prefixed paths collected from the services' Unprotected lists
	handlers       map[string]http.Handler   // traced service handlers, keyed by svc Prefix
	middlewares    []*middlewareTriple       // configured middlewares, filled by registerMiddlewares
	log            zerolog.Logger            // logger used by the server itself and handed to services
	tracerProvider trace.TracerProvider      // provider used to create the spans around handlers and middlewares
}
    81  
// config is the server configuration as decoded by mapstructure in New.
type config struct {
	Network     string                            `mapstructure:"network"`     // defaults to "tcp" when empty
	Address     string                            `mapstructure:"address"`     // defaults to "0.0.0.0:19001" when empty
	Services    map[string]map[string]interface{} `mapstructure:"services"`    // per-service configuration, keyed by service name
	Middlewares map[string]map[string]interface{} `mapstructure:"middlewares"` // per-middleware configuration, keyed by middleware name
	CertFile    string                            `mapstructure:"certfile"`    // TLS is served only when both CertFile and KeyFile are set
	KeyFile     string                            `mapstructure:"keyfile"`     // see CertFile
}
    90  
    91  func (c *config) init() {
    92  	// apply defaults
    93  	if c.Network == "" {
    94  		c.Network = "tcp"
    95  	}
    96  
    97  	if c.Address == "" {
    98  		c.Address = "0.0.0.0:19001"
    99  	}
   100  }
   101  
   102  // Start starts the server
   103  func (s *Server) Start(ln net.Listener) error {
   104  	if err := s.registerServices(); err != nil {
   105  		return err
   106  	}
   107  
   108  	if err := s.registerMiddlewares(); err != nil {
   109  		return err
   110  	}
   111  
   112  	handler, err := s.getHandler()
   113  	if err != nil {
   114  		return errors.Wrap(err, "rhttp: error creating http handler")
   115  	}
   116  
   117  	s.httpServer.Handler = handler
   118  	s.listener = ln
   119  
   120  	if (s.conf.CertFile != "") && (s.conf.KeyFile != "") {
   121  		s.log.Info().Msgf("https server listening at https://%s '%s' '%s'", s.conf.Address, s.conf.CertFile, s.conf.KeyFile)
   122  		err = s.httpServer.ServeTLS(s.listener, s.conf.CertFile, s.conf.KeyFile)
   123  	} else {
   124  		s.log.Info().Msgf("http server listening at http://%s '%s' '%s'", s.conf.Address, s.conf.CertFile, s.conf.KeyFile)
   125  		err = s.httpServer.Serve(s.listener)
   126  	}
   127  	if err == nil || err == http.ErrServerClosed {
   128  		return nil
   129  	}
   130  	return err
   131  }
   132  
   133  // Stop stops the server.
   134  func (s *Server) Stop() error {
   135  	s.closeServices()
   136  	// TODO(labkode): set ctx deadline to zero
   137  	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
   138  	defer cancel()
   139  	return s.httpServer.Shutdown(ctx)
   140  }
   141  
   142  // TODO(labkode): we can't stop the server shutdown because a service cannot be shutdown.
   143  // What do we do in case a service cannot be properly closed? Now we just log the error.
   144  // TODO(labkode): the close should be given a deadline using context.Context.
   145  func (s *Server) closeServices() {
   146  	for _, svc := range s.svcs {
   147  		if err := svc.Close(); err != nil {
   148  			s.log.Error().Err(err).Msgf("error closing service %q", svc.Prefix())
   149  		} else {
   150  			s.log.Info().Msgf("service %q correctly closed", svc.Prefix())
   151  		}
   152  	}
   153  }
   154  
// Network returns the configured network type ("tcp" unless overridden in
// the configuration).
func (s *Server) Network() string {
	return s.conf.Network
}
   159  
// Address returns the configured network address ("0.0.0.0:19001" unless
// overridden in the configuration).
func (s *Server) Address() string {
	return s.conf.Address
}
   164  
   165  // GracefulStop gracefully stops the server.
   166  func (s *Server) GracefulStop() error {
   167  	s.closeServices()
   168  	return s.httpServer.Shutdown(context.Background())
   169  }
   170  
// middlewareTriple represents a middleware with the
// priority to be chained.
//
// getHandler sorts the configured triples by descending Priority and wraps
// the handler in that order, so a triple with a smaller Priority ends up
// further out in the chain.
type middlewareTriple struct {
	Name       string            // middleware name; also used as the span name of its traced handler
	Priority   int               // ordering key used when chaining
	Middleware global.Middleware // the middleware itself
}
   178  
   179  func (s *Server) registerMiddlewares() error {
   180  	middlewares := []*middlewareTriple{}
   181  	for name, newFunc := range global.NewMiddlewares {
   182  		if s.isMiddlewareEnabled(name) {
   183  			m, prio, err := newFunc(s.conf.Middlewares[name])
   184  			if err != nil {
   185  				err = errors.Wrapf(err, "error creating new middleware: %s,", name)
   186  				return err
   187  			}
   188  			middlewares = append(middlewares, &middlewareTriple{
   189  				Name:       name,
   190  				Priority:   prio,
   191  				Middleware: m,
   192  			})
   193  			s.log.Info().Msgf("http middleware enabled: %s", name)
   194  		}
   195  	}
   196  	s.middlewares = middlewares
   197  	return nil
   198  }
   199  
   200  func (s *Server) isMiddlewareEnabled(name string) bool {
   201  	_, ok := s.conf.Middlewares[name]
   202  	return ok
   203  }
   204  
   205  func (s *Server) registerServices() error {
   206  	for svcName := range s.conf.Services {
   207  		if s.isServiceEnabled(svcName) {
   208  			newFunc := global.Services[svcName]
   209  			svc, err := newFunc(s.conf.Services[svcName], &s.log)
   210  			if err != nil {
   211  				err = errors.Wrapf(err, "http service %s could not be started,", svcName)
   212  				return err
   213  			}
   214  
   215  			// instrument services with opencensus tracing.
   216  			h := traceHandler(svcName, svc.Handler(), s.tracerProvider)
   217  			s.handlers[svc.Prefix()] = h
   218  			s.svcs[svc.Prefix()] = svc
   219  			s.unprotected = append(s.unprotected, getUnprotected(svc.Prefix(), svc.Unprotected())...)
   220  			s.log.Info().Msgf("http service enabled: %s@/%s", svcName, svc.Prefix())
   221  		} else {
   222  			message := fmt.Sprintf("http service %s does not exist", svcName)
   223  			return errors.New(message)
   224  		}
   225  	}
   226  	return nil
   227  }
   228  
   229  func (s *Server) isServiceEnabled(svcName string) bool {
   230  	_, ok := global.Services[svcName]
   231  	return ok
   232  }
   233  
   234  // TODO(labkode): if the http server is exposed under a basename we need to prepend
   235  // to prefix.
   236  func getUnprotected(prefix string, unprotected []string) []string {
   237  	for i := range unprotected {
   238  		unprotected[i] = path.Join("/", prefix, unprotected[i])
   239  	}
   240  	return unprotected
   241  }
   242  
   243  func (s *Server) getHandler() (http.Handler, error) {
   244  	h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   245  		head, tail := router.ShiftPath(r.URL.Path)
   246  		if h, ok := s.handlers[head]; ok {
   247  			r.URL.Path = tail
   248  			s.log.Debug().Msgf("http routing: head=%s tail=%s svc=%s", head, r.URL.Path, head)
   249  			h.ServeHTTP(w, r)
   250  			return
   251  		}
   252  
   253  		// when a service is exposed at the root.
   254  		if h, ok := s.handlers[""]; ok {
   255  			r.URL.Path = "/" + head + tail
   256  			s.log.Debug().Msgf("http routing: head= tail=%s svc=root", r.URL.Path)
   257  			h.ServeHTTP(w, r)
   258  			return
   259  		}
   260  
   261  		s.log.Debug().Msgf("http routing: head=%s tail=%s svc=not-found", head, tail)
   262  		w.WriteHeader(http.StatusNotFound)
   263  	})
   264  
   265  	// sort middlewares by priority.
   266  	sort.SliceStable(s.middlewares, func(i, j int) bool {
   267  		return s.middlewares[i].Priority > s.middlewares[j].Priority
   268  	})
   269  
   270  	handler := http.Handler(h)
   271  
   272  	for _, triple := range s.middlewares {
   273  		s.log.Info().Msgf("chaining http middleware %s with priority  %d", triple.Name, triple.Priority)
   274  		handler = triple.Middleware(traceHandler(triple.Name, handler, s.tracerProvider))
   275  	}
   276  
   277  	for _, v := range s.unprotected {
   278  		s.log.Info().Msgf("unprotected URL: %s", v)
   279  	}
   280  	authMiddle, err := auth.New(s.conf.Middlewares["auth"], s.unprotected, s.tracerProvider)
   281  	if err != nil {
   282  		return nil, errors.Wrap(err, "rhttp: error creating auth middleware")
   283  	}
   284  
   285  	// add always the logctx middleware as most priority, this middleware is internal
   286  	// and cannot be configured from the configuration.
   287  	coreMiddlewares := []*middlewareTriple{}
   288  
   289  	providerAuthMiddle, err := addProviderAuthMiddleware(s.conf, s.unprotected)
   290  	if err != nil {
   291  		return nil, errors.Wrap(err, "rhttp: error creating providerauthorizer middleware")
   292  	}
   293  	if providerAuthMiddle != nil {
   294  		coreMiddlewares = append(coreMiddlewares, &middlewareTriple{Middleware: providerAuthMiddle, Name: "providerauthorizer"})
   295  	}
   296  
   297  	coreMiddlewares = append(coreMiddlewares, &middlewareTriple{Middleware: authMiddle, Name: "auth"})
   298  	coreMiddlewares = append(coreMiddlewares, &middlewareTriple{Middleware: log.New(), Name: "log"})
   299  	coreMiddlewares = append(coreMiddlewares, &middlewareTriple{Middleware: appctx.New(s.log, s.tracerProvider), Name: "appctx"})
   300  
   301  	for _, triple := range coreMiddlewares {
   302  		handler = triple.Middleware(traceHandler(triple.Name, handler, s.tracerProvider))
   303  	}
   304  
   305  	return handler, nil
   306  }
   307  
   308  func traceHandler(name string, h http.Handler, tp trace.TracerProvider) http.Handler {
   309  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   310  		ctx := rtrace.Propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header))
   311  		t := tp.Tracer(tracerName)
   312  		ctx, span := t.Start(ctx, name)
   313  		defer span.End()
   314  
   315  		rtrace.Propagator.Inject(ctx, propagation.HeaderCarrier(r.Header))
   316  		h.ServeHTTP(w, r.WithContext(ctx))
   317  	})
   318  }
   319  
   320  func addProviderAuthMiddleware(conf *config, unprotected []string) (global.Middleware, error) {
   321  	_, ocmdRegistered := global.Services["ocmd"]
   322  	_, ocmdEnabled := conf.Services["ocmd"]
   323  	ocmdPrefix, _ := conf.Services["ocmd"]["prefix"].(string)
   324  	if ocmdRegistered && ocmdEnabled {
   325  		return providerauthorizer.New(conf.Middlewares["providerauthorizer"], unprotected, ocmdPrefix)
   326  	}
   327  	return nil, nil
   328  }