github.com/cs3org/reva/v2@v2.27.7/pkg/rhttp/rhttp.go (about) 1 // Copyright 2018-2021 CERN 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 // 15 // In applying this license, CERN does not waive the privileges and immunities 16 // granted to it by virtue of its status as an Intergovernmental Organization 17 // or submit itself to any jurisdiction. 18 19 package rhttp 20 21 import ( 22 "context" 23 "fmt" 24 "net" 25 "net/http" 26 "path" 27 "sort" 28 "time" 29 30 "github.com/cs3org/reva/v2/internal/http/interceptors/appctx" 31 "github.com/cs3org/reva/v2/internal/http/interceptors/auth" 32 "github.com/cs3org/reva/v2/internal/http/interceptors/log" 33 "github.com/cs3org/reva/v2/internal/http/interceptors/providerauthorizer" 34 "github.com/cs3org/reva/v2/pkg/rhttp/global" 35 "github.com/cs3org/reva/v2/pkg/rhttp/router" 36 rtrace "github.com/cs3org/reva/v2/pkg/trace" 37 "github.com/mitchellh/mapstructure" 38 "github.com/pkg/errors" 39 "github.com/rs/zerolog" 40 "go.opentelemetry.io/otel/propagation" 41 "go.opentelemetry.io/otel/trace" 42 ) 43 44 // name is the Tracer name used to identify this instrumentation library. 45 const tracerName = "rhttp" 46 47 // New returns a new server 48 func New(m interface{}, l zerolog.Logger, tp trace.TracerProvider) (*Server, error) { 49 conf := &config{} 50 if err := mapstructure.Decode(m, conf); err != nil { 51 return nil, err 52 } 53 54 conf.init() 55 56 httpServer := &http.Server{} 57 s := &Server{ 58 httpServer: httpServer, 59 conf: conf, 60 svcs: map[string]global.Service{}, 61 unprotected: []string{}, 62 handlers: map[string]http.Handler{}, 63 log: l, 64 tracerProvider: tp, 65 } 66 return s, nil 67 } 68 69 // Server contains the server info. 70 type Server struct { 71 httpServer *http.Server 72 conf *config 73 listener net.Listener 74 svcs map[string]global.Service // map key is svc Prefix 75 unprotected []string 76 handlers map[string]http.Handler 77 middlewares []*middlewareTriple 78 log zerolog.Logger 79 tracerProvider trace.TracerProvider 80 } 81 82 type config struct { 83 Network string `mapstructure:"network"` 84 Address string `mapstructure:"address"` 85 Services map[string]map[string]interface{} `mapstructure:"services"` 86 Middlewares map[string]map[string]interface{} `mapstructure:"middlewares"` 87 CertFile string `mapstructure:"certfile"` 88 KeyFile string `mapstructure:"keyfile"` 89 } 90 91 func (c *config) init() { 92 // apply defaults 93 if c.Network == "" { 94 c.Network = "tcp" 95 } 96 97 if c.Address == "" { 98 c.Address = "0.0.0.0:19001" 99 } 100 } 101 102 // Start starts the server 103 func (s *Server) Start(ln net.Listener) error { 104 if err := s.registerServices(); err != nil { 105 return err 106 } 107 108 if err := s.registerMiddlewares(); err != nil { 109 return err 110 } 111 112 handler, err := s.getHandler() 113 if err != nil { 114 return errors.Wrap(err, "rhttp: error creating http handler") 115 } 116 117 s.httpServer.Handler = handler 118 s.listener = ln 119 120 if (s.conf.CertFile != "") && (s.conf.KeyFile != "") { 121 s.log.Info().Msgf("https server listening at https://%s '%s' '%s'", s.conf.Address, s.conf.CertFile, s.conf.KeyFile) 122 err = s.httpServer.ServeTLS(s.listener, s.conf.CertFile, s.conf.KeyFile) 123 } else { 124 s.log.Info().Msgf("http server listening at http://%s '%s' '%s'", s.conf.Address, s.conf.CertFile, s.conf.KeyFile) 125 err = s.httpServer.Serve(s.listener) 126 } 127 if err == nil || err == http.ErrServerClosed { 128 return nil 129 } 130 return err 131 } 132 133 // Stop stops the server. 134 func (s *Server) Stop() error { 135 s.closeServices() 136 // TODO(labkode): set ctx deadline to zero 137 ctx, cancel := context.WithTimeout(context.Background(), time.Second) 138 defer cancel() 139 return s.httpServer.Shutdown(ctx) 140 } 141 142 // TODO(labkode): we can't stop the server shutdown because a service cannot be shutdown. 143 // What do we do in case a service cannot be properly closed? Now we just log the error. 144 // TODO(labkode): the close should be given a deadline using context.Context. 145 func (s *Server) closeServices() { 146 for _, svc := range s.svcs { 147 if err := svc.Close(); err != nil { 148 s.log.Error().Err(err).Msgf("error closing service %q", svc.Prefix()) 149 } else { 150 s.log.Info().Msgf("service %q correctly closed", svc.Prefix()) 151 } 152 } 153 } 154 155 // Network return the network type. 156 func (s *Server) Network() string { 157 return s.conf.Network 158 } 159 160 // Address returns the network address. 161 func (s *Server) Address() string { 162 return s.conf.Address 163 } 164 165 // GracefulStop gracefully stops the server. 166 func (s *Server) GracefulStop() error { 167 s.closeServices() 168 return s.httpServer.Shutdown(context.Background()) 169 } 170 171 // middlewareTriple represents a middleware with the 172 // priority to be chained. 173 type middlewareTriple struct { 174 Name string 175 Priority int 176 Middleware global.Middleware 177 } 178 179 func (s *Server) registerMiddlewares() error { 180 middlewares := []*middlewareTriple{} 181 for name, newFunc := range global.NewMiddlewares { 182 if s.isMiddlewareEnabled(name) { 183 m, prio, err := newFunc(s.conf.Middlewares[name]) 184 if err != nil { 185 err = errors.Wrapf(err, "error creating new middleware: %s,", name) 186 return err 187 } 188 middlewares = append(middlewares, &middlewareTriple{ 189 Name: name, 190 Priority: prio, 191 Middleware: m, 192 }) 193 s.log.Info().Msgf("http middleware enabled: %s", name) 194 } 195 } 196 s.middlewares = middlewares 197 return nil 198 } 199 200 func (s *Server) isMiddlewareEnabled(name string) bool { 201 _, ok := s.conf.Middlewares[name] 202 return ok 203 } 204 205 func (s *Server) registerServices() error { 206 for svcName := range s.conf.Services { 207 if s.isServiceEnabled(svcName) { 208 newFunc := global.Services[svcName] 209 svc, err := newFunc(s.conf.Services[svcName], &s.log) 210 if err != nil { 211 err = errors.Wrapf(err, "http service %s could not be started,", svcName) 212 return err 213 } 214 215 // instrument services with opencensus tracing. 216 h := traceHandler(svcName, svc.Handler(), s.tracerProvider) 217 s.handlers[svc.Prefix()] = h 218 s.svcs[svc.Prefix()] = svc 219 s.unprotected = append(s.unprotected, getUnprotected(svc.Prefix(), svc.Unprotected())...) 220 s.log.Info().Msgf("http service enabled: %s@/%s", svcName, svc.Prefix()) 221 } else { 222 message := fmt.Sprintf("http service %s does not exist", svcName) 223 return errors.New(message) 224 } 225 } 226 return nil 227 } 228 229 func (s *Server) isServiceEnabled(svcName string) bool { 230 _, ok := global.Services[svcName] 231 return ok 232 } 233 234 // TODO(labkode): if the http server is exposed under a basename we need to prepend 235 // to prefix. 236 func getUnprotected(prefix string, unprotected []string) []string { 237 for i := range unprotected { 238 unprotected[i] = path.Join("/", prefix, unprotected[i]) 239 } 240 return unprotected 241 } 242 243 func (s *Server) getHandler() (http.Handler, error) { 244 h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 245 head, tail := router.ShiftPath(r.URL.Path) 246 if h, ok := s.handlers[head]; ok { 247 r.URL.Path = tail 248 s.log.Debug().Msgf("http routing: head=%s tail=%s svc=%s", head, r.URL.Path, head) 249 h.ServeHTTP(w, r) 250 return 251 } 252 253 // when a service is exposed at the root. 254 if h, ok := s.handlers[""]; ok { 255 r.URL.Path = "/" + head + tail 256 s.log.Debug().Msgf("http routing: head= tail=%s svc=root", r.URL.Path) 257 h.ServeHTTP(w, r) 258 return 259 } 260 261 s.log.Debug().Msgf("http routing: head=%s tail=%s svc=not-found", head, tail) 262 w.WriteHeader(http.StatusNotFound) 263 }) 264 265 // sort middlewares by priority. 266 sort.SliceStable(s.middlewares, func(i, j int) bool { 267 return s.middlewares[i].Priority > s.middlewares[j].Priority 268 }) 269 270 handler := http.Handler(h) 271 272 for _, triple := range s.middlewares { 273 s.log.Info().Msgf("chaining http middleware %s with priority %d", triple.Name, triple.Priority) 274 handler = triple.Middleware(traceHandler(triple.Name, handler, s.tracerProvider)) 275 } 276 277 for _, v := range s.unprotected { 278 s.log.Info().Msgf("unprotected URL: %s", v) 279 } 280 authMiddle, err := auth.New(s.conf.Middlewares["auth"], s.unprotected, s.tracerProvider) 281 if err != nil { 282 return nil, errors.Wrap(err, "rhttp: error creating auth middleware") 283 } 284 285 // add always the logctx middleware as most priority, this middleware is internal 286 // and cannot be configured from the configuration. 287 coreMiddlewares := []*middlewareTriple{} 288 289 providerAuthMiddle, err := addProviderAuthMiddleware(s.conf, s.unprotected) 290 if err != nil { 291 return nil, errors.Wrap(err, "rhttp: error creating providerauthorizer middleware") 292 } 293 if providerAuthMiddle != nil { 294 coreMiddlewares = append(coreMiddlewares, &middlewareTriple{Middleware: providerAuthMiddle, Name: "providerauthorizer"}) 295 } 296 297 coreMiddlewares = append(coreMiddlewares, &middlewareTriple{Middleware: authMiddle, Name: "auth"}) 298 coreMiddlewares = append(coreMiddlewares, &middlewareTriple{Middleware: log.New(), Name: "log"}) 299 coreMiddlewares = append(coreMiddlewares, &middlewareTriple{Middleware: appctx.New(s.log, s.tracerProvider), Name: "appctx"}) 300 301 for _, triple := range coreMiddlewares { 302 handler = triple.Middleware(traceHandler(triple.Name, handler, s.tracerProvider)) 303 } 304 305 return handler, nil 306 } 307 308 func traceHandler(name string, h http.Handler, tp trace.TracerProvider) http.Handler { 309 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 310 ctx := rtrace.Propagator.Extract(r.Context(), propagation.HeaderCarrier(r.Header)) 311 t := tp.Tracer(tracerName) 312 ctx, span := t.Start(ctx, name) 313 defer span.End() 314 315 rtrace.Propagator.Inject(ctx, propagation.HeaderCarrier(r.Header)) 316 h.ServeHTTP(w, r.WithContext(ctx)) 317 }) 318 } 319 320 func addProviderAuthMiddleware(conf *config, unprotected []string) (global.Middleware, error) { 321 _, ocmdRegistered := global.Services["ocmd"] 322 _, ocmdEnabled := conf.Services["ocmd"] 323 ocmdPrefix, _ := conf.Services["ocmd"]["prefix"].(string) 324 if ocmdRegistered && ocmdEnabled { 325 return providerauthorizer.New(conf.Middlewares["providerauthorizer"], unprotected, ocmdPrefix) 326 } 327 return nil, nil 328 }