github.com/shogo82148/goa-v1@v1.6.2/service.go (about)

     1  package goa
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"log"
     8  	"net"
     9  	"net/http"
    10  	"net/url"
    11  	"os"
    12  	"path/filepath"
    13  	"sort"
    14  	"strings"
    15  	"sync"
    16  
    17  	"github.com/dimfeld/httptreemux"
    18  )
    19  
    20  type (
    21  	// Service is the data structure supporting goa services.
    22  	// It provides methods for configuring a service and running it.
    23  	// At the basic level a service consists of a set of controllers, each implementing a given
    24  	// resource actions. goagen generates global functions - one per resource - that make it
    25  	// possible to mount the corresponding controller onto a service. A service contains the
    26  	// middleware, not found handler, encoders and muxes shared by all its controllers.
    27  	Service struct {
    28  		// Name of service used for logging, tracing etc.
    29  		Name string
    30  		// Mux is the service request mux
    31  		Mux ServeMux
    32  		// Server is the service HTTP server.
    33  		Server *http.Server
    34  		// Context is the root context from which all request contexts are derived.
    35  		// Set values in the root context prior to starting the server to make these values
    36  		// available to all request handlers.
    37  		Context context.Context
    38  		// Request body decoder
    39  		Decoder *HTTPDecoder
    40  		// Response body encoder
    41  		Encoder *HTTPEncoder
    42  
    43  		middleware []Middleware       // Middleware chain
    44  		cancel     context.CancelFunc // Service context cancel signal trigger
    45  	}
    46  
    47  	// Controller defines the common fields and behavior of generated controllers.
    48  	Controller struct {
    49  		// Controller resource name
    50  		Name string
    51  		// Service that exposes the controller
    52  		Service *Service
    53  		// Controller root context
    54  		Context context.Context
    55  		// MaxRequestBodyLength is the maximum length read from request bodies.
    56  		// Set to 0 to remove the limit altogether. Defaults to 1GB.
    57  		MaxRequestBodyLength int64
    58  		// FileSystem is used in FileHandler to open files. By default it returns
    59  		// http.Dir but you can override it with another one that implements http.FileSystem.
    60  		// For example using github.com/elazarl/go-bindata-assetfs is like below.
    61  		//
    62  		//	ctrl.FileSystem = func(dir string) http.FileSystem {
    63  		//		return &assetfs.AssetFS{
    64  		//			Asset: Asset,
    65  		//			AssetDir: AssetDir,
    66  		//			AssetInfo: AssetInfo,
    67  		//			Prefix: dir,
    68  		//		}
    69  		//	}
    70  		FileSystem func(string) http.FileSystem
    71  
    72  		middleware []Middleware // Controller specific middleware if any
    73  	}
    74  
    75  	// FileServer is the interface implemented by controllers that can serve static files.
    76  	FileServer interface {
    77  		// FileHandler returns a handler that serves files under the given request path.
    78  		FileHandler(path, filename string) Handler
    79  	}
    80  
    81  	// Handler defines the request handler signatures.
    82  	Handler func(context.Context, http.ResponseWriter, *http.Request) error
    83  
    84  	// Unmarshaler defines the request payload unmarshaler signatures.
    85  	Unmarshaler func(context.Context, *Service, *http.Request) error
    86  
    87  	// DecodeFunc is the function that initialize the unmarshaled payload from the request body.
    88  	DecodeFunc func(context.Context, io.ReadCloser, interface{}) error
    89  )
    90  
    91  // New instantiates a service with the given name.
    92  func New(name string) *Service {
    93  	var (
    94  		stdlog       = log.New(os.Stderr, "", log.LstdFlags)
    95  		ctx          = WithLogger(context.Background(), NewLogger(stdlog))
    96  		cctx, cancel = context.WithCancel(ctx)
    97  		mux          = NewMux()
    98  		service      = &Service{
    99  			Name:    name,
   100  			Context: cctx,
   101  			Mux:     mux,
   102  			Server: &http.Server{
   103  				Handler: mux,
   104  			},
   105  			Decoder: NewHTTPDecoder(),
   106  			Encoder: NewHTTPEncoder(),
   107  
   108  			cancel: cancel,
   109  		}
   110  		notFoundHandler         Handler
   111  		methodNotAllowedHandler Handler
   112  	)
   113  
   114  	// Setup default NotFound handler
   115  	mux.HandleNotFound(func(rw http.ResponseWriter, req *http.Request, params url.Values) {
   116  		if resp := ContextResponse(ctx); resp != nil && resp.Written() {
   117  			return
   118  		}
   119  		// Use closure to do lazy computation of middleware chain so all middlewares are
   120  		// registered.
   121  		if notFoundHandler == nil {
   122  			notFoundHandler = func(_ context.Context, _ http.ResponseWriter, req *http.Request) error {
   123  				return ErrNotFound(req.URL.Path)
   124  			}
   125  			chain := service.middleware
   126  			ml := len(chain)
   127  			for i := range chain {
   128  				notFoundHandler = chain[ml-i-1](notFoundHandler)
   129  			}
   130  		}
   131  		ctx := NewContext(service.Context, rw, req, params)
   132  		err := notFoundHandler(ctx, ContextResponse(ctx), req)
   133  		if !ContextResponse(ctx).Written() {
   134  			service.Send(ctx, 404, err)
   135  		}
   136  	})
   137  
   138  	// Setup default MethodNotAllowed handler
   139  	mux.HandleMethodNotAllowed(func(rw http.ResponseWriter, req *http.Request, params url.Values, methods map[string]httptreemux.HandlerFunc) {
   140  		if resp := ContextResponse(ctx); resp != nil && resp.Written() {
   141  			return
   142  		}
   143  		// Use closure to do lazy computation of middleware chain so all middlewares are
   144  		// registered.
   145  		if methodNotAllowedHandler == nil {
   146  			methodNotAllowedHandler = func(_ context.Context, rw http.ResponseWriter, req *http.Request) error {
   147  				allowedMethods := make([]string, len(methods))
   148  				i := 0
   149  				for k := range methods {
   150  					allowedMethods[i] = k
   151  					i++
   152  				}
   153  				rw.Header().Set("Allow", strings.Join(allowedMethods, ", "))
   154  				return MethodNotAllowedError(req.Method, allowedMethods)
   155  			}
   156  			chain := service.middleware
   157  			ml := len(chain)
   158  			for i := range chain {
   159  				methodNotAllowedHandler = chain[ml-i-1](methodNotAllowedHandler)
   160  			}
   161  		}
   162  		ctx := NewContext(service.Context, rw, req, params)
   163  		err := methodNotAllowedHandler(ctx, ContextResponse(ctx), req)
   164  		if !ContextResponse(ctx).Written() {
   165  			service.Send(ctx, 405, err)
   166  		}
   167  	})
   168  
   169  	return service
   170  }
   171  
   172  // CancelAll sends a cancel signals to all request handlers via the context.
   173  // See https://golang.org/pkg/context/ for details on how to handle the signal.
   174  func (service *Service) CancelAll() {
   175  	service.cancel()
   176  }
   177  
   178  // Use adds a middleware to the service wide middleware chain.
   179  // goa comes with a set of commonly used middleware, see the middleware package.
   180  // Controller specific middleware should be mounted using the Controller struct Use method instead.
   181  func (service *Service) Use(m Middleware) {
   182  	service.middleware = append(service.middleware, m)
   183  }
   184  
   185  // WithLogger sets the logger used internally by the service and by Log.
   186  func (service *Service) WithLogger(logger LogAdapter) {
   187  	service.Context = WithLogger(service.Context, logger)
   188  }
   189  
   190  // LogInfo logs the message and values at odd indexes using the keys at even indexes of the keyvals slice.
   191  func (service *Service) LogInfo(msg string, keyvals ...interface{}) {
   192  	ctx := service.Context
   193  
   194  	// this block should be synced with LogInfo.
   195  	// we want not to call LogInfo because it changes the call stack
   196  	// and makes the log adapter more complex to implement.
   197  	if l := ctx.Value(logKey); l != nil {
   198  		switch logger := l.(type) {
   199  		case ContextLogAdapter:
   200  			logger.InfoContext(ctx, msg, keyvals...)
   201  		case LogAdapter:
   202  			logger.Info(msg, keyvals...)
   203  		}
   204  	}
   205  }
   206  
   207  // LogError logs the error and values at odd indexes using the keys at even indexes of the keyvals slice.
   208  func (service *Service) LogError(msg string, keyvals ...interface{}) {
   209  	ctx := service.Context
   210  
   211  	// this block should be synced with LogError.
   212  	// we want not to call LogError because it changes the call stack
   213  	// and makes the log adapter more complex to implement.
   214  	if l := ctx.Value(logKey); l != nil {
   215  		switch logger := l.(type) {
   216  		case ContextLogAdapter:
   217  			logger.ErrorContext(ctx, msg, keyvals...)
   218  		case LogAdapter:
   219  			logger.Error(msg, keyvals...)
   220  		}
   221  	}
   222  }
   223  
   224  // ListenAndServe starts a HTTP server and sets up a listener on the given host/port.
   225  func (service *Service) ListenAndServe(addr string) error {
   226  	service.LogInfo("listen", "transport", "http", "addr", addr)
   227  	service.Server.Addr = addr
   228  	return service.Server.ListenAndServe()
   229  }
   230  
   231  // ListenAndServeTLS starts a HTTPS server and sets up a listener on the given host/port.
   232  func (service *Service) ListenAndServeTLS(addr, certFile, keyFile string) error {
   233  	service.LogInfo("listen", "transport", "https", "addr", addr)
   234  	service.Server.Addr = addr
   235  	return service.Server.ListenAndServeTLS(certFile, keyFile)
   236  }
   237  
   238  // Serve accepts incoming HTTP connections on the listener l, invoking the service mux handler for each.
   239  func (service *Service) Serve(l net.Listener) error {
   240  	return service.Server.Serve(l)
   241  }
   242  
   243  // NewController returns a controller for the given resource. This method is mainly intended for
   244  // use by the generated code. User code shouldn't have to call it directly.
   245  func (service *Service) NewController(name string) *Controller {
   246  	return &Controller{
   247  		Name:                 name,
   248  		Service:              service,
   249  		Context:              context.WithValue(service.Context, ctrlKey, name),
   250  		MaxRequestBodyLength: 1073741824, // 1 GB
   251  		FileSystem: func(dir string) http.FileSystem {
   252  			return http.Dir(dir)
   253  		},
   254  	}
   255  }
   256  
   257  // Send serializes the given body matching the request Accept header against the service
   258  // encoders. It uses the default service encoder if no match is found.
   259  func (service *Service) Send(ctx context.Context, code int, body interface{}) error {
   260  	r := ContextResponse(ctx)
   261  	if r == nil {
   262  		return fmt.Errorf("no response data in context")
   263  	}
   264  	r.WriteHeader(code)
   265  	return service.EncodeResponse(ctx, body)
   266  }
   267  
   268  // ServeFiles create a "FileServer" controller and calls ServerFiles on it.
   269  func (service *Service) ServeFiles(path, filename string) error {
   270  	ctrl := service.NewController("FileServer")
   271  	return ctrl.ServeFiles(path, filename)
   272  }
   273  
   274  // DecodeRequest uses the HTTP decoder to unmarshal the request body into the provided value based
   275  // on the request Content-Type header.
   276  func (service *Service) DecodeRequest(req *http.Request, v interface{}) error {
   277  	body, contentType := req.Body, req.Header.Get("Content-Type")
   278  	defer body.Close()
   279  
   280  	if err := service.Decoder.Decode(v, body, contentType); err != nil {
   281  		return fmt.Errorf("failed to decode request body with content type %#v: %s", contentType, err)
   282  	}
   283  
   284  	return nil
   285  }
   286  
   287  // EncodeResponse uses the HTTP encoder to marshal and write the response body based on the request
   288  // Accept header.
   289  func (service *Service) EncodeResponse(ctx context.Context, v interface{}) error {
   290  	accept := ContextRequest(ctx).Header.Get("Accept")
   291  	return service.Encoder.Encode(v, ContextResponse(ctx), accept)
   292  }
   293  
   294  // ServeFiles replies to the request with the contents of the named file or directory. See
   295  // FileHandler for details.
   296  func (ctrl *Controller) ServeFiles(path, filename string) error {
   297  	if strings.Contains(path, ":") {
   298  		return fmt.Errorf("path may only include wildcards that match the entire end of the URL (e.g. *filepath)")
   299  	}
   300  	LogInfo(ctrl.Context, "mount file", "name", filename, "route", fmt.Sprintf("GET %s", path))
   301  	handler := func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
   302  		if !ContextResponse(ctx).Written() {
   303  			return ctrl.FileHandler(path, filename)(ctx, rw, req)
   304  		}
   305  		return nil
   306  	}
   307  	ctrl.Service.Mux.Handle("GET", path, ctrl.MuxHandler("serve", handler, nil))
   308  	return nil
   309  }
   310  
   311  // Use adds a middleware to the controller.
   312  // Service-wide middleware should be added via the Service Use method instead.
   313  func (ctrl *Controller) Use(m Middleware) {
   314  	ctrl.middleware = append(ctrl.middleware, m)
   315  }
   316  
   317  // MuxHandler wraps a request handler into a MuxHandler. The MuxHandler initializes the request
   318  // context by loading the request state, invokes the handler and in case of error invokes the
   319  // controller (if there is one) or Service error handler.
   320  // This function is intended for the controller generated code. User code should not need to call
   321  // it directly.
   322  func (ctrl *Controller) MuxHandler(name string, hdlr Handler, unm Unmarshaler) MuxHandler {
   323  	// Use closure to enable late computation of handlers to ensure all middleware has been
   324  	// registered.
   325  	var handler Handler
   326  	var initHandler sync.Once
   327  
   328  	return func(rw http.ResponseWriter, req *http.Request, params url.Values) {
   329  		// Build handler middleware chains on first invocation
   330  		initHandler.Do(func() {
   331  			handler = func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
   332  				if !ContextResponse(ctx).Written() {
   333  					return hdlr(ctx, rw, req)
   334  				}
   335  				return nil
   336  			}
   337  			mwLen := len(ctrl.Service.middleware)
   338  			chain := append(ctrl.Service.middleware[:mwLen:mwLen], ctrl.middleware...)
   339  			ml := len(chain)
   340  			for i := range chain {
   341  				handler = chain[ml-i-1](handler)
   342  			}
   343  		})
   344  
   345  		// Build context
   346  		ctx := NewContext(WithAction(ctrl.Context, name), rw, req, params)
   347  
   348  		// Protect against request bodies with unreasonable length
   349  		if ctrl.MaxRequestBodyLength > 0 {
   350  			req.Body = http.MaxBytesReader(rw, req.Body, ctrl.MaxRequestBodyLength)
   351  		}
   352  
   353  		// Load body if any
   354  		if req.ContentLength > 0 && unm != nil {
   355  			if err := unm(ctx, ctrl.Service, req); err != nil {
   356  				if err.Error() == "http: request body too large" {
   357  					msg := fmt.Sprintf("request body length exceeds %d bytes", ctrl.MaxRequestBodyLength)
   358  					err = ErrRequestBodyTooLarge(msg)
   359  				} else {
   360  					err = ErrBadRequest(err)
   361  				}
   362  				ctx = WithError(ctx, err)
   363  			}
   364  		}
   365  
   366  		// Invoke handler
   367  		if err := handler(ctx, ContextResponse(ctx), req); err != nil {
   368  			LogError(ctx, "uncaught error", "err", err)
   369  			respBody := fmt.Sprintf("Internal error: %s", err) // Sprintf catches panics
   370  			ctrl.Service.Send(ctx, 500, respBody)
   371  		}
   372  	}
   373  }
   374  
   375  // FileHandler returns a handler that serves files under the given filename for the given route path.
   376  // The logic for what to do when the filename points to a file vs. a directory is the same as the
   377  // standard http package ServeFile function. The path may end with a wildcard that matches the rest
   378  // of the URL (e.g. *filepath). If it does the matching path is appended to filename to form the
   379  // full file path, so:
   380  //
   381  //	c.FileHandler("/index.html", "/www/data/index.html")
   382  //
   383  // Returns the content of the file "/www/data/index.html" when requests are sent to "/index.html"
   384  // and:
   385  //
   386  //	c.FileHandler("/assets/*filepath", "/www/data/assets")
   387  //
   388  // returns the content of the file "/www/data/assets/x/y/z" when requests are sent to
   389  // "/assets/x/y/z".
   390  func (ctrl *Controller) FileHandler(path, filename string) Handler {
   391  	var wc string
   392  	if idx := strings.LastIndex(path, "/*"); idx > -1 && idx < len(path)-1 {
   393  		wc = path[idx+2:]
   394  		if strings.Contains(wc, "/") {
   395  			wc = ""
   396  		}
   397  	}
   398  	return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) error {
   399  		// prevent path traversal
   400  		if attemptsPathTraversal(req.URL.Path, path) {
   401  			return ErrNotFound(req.URL.Path)
   402  		}
   403  		fname := filename
   404  		if len(wc) > 0 {
   405  			if m, ok := ContextRequest(ctx).Params[wc]; ok {
   406  				fname = filepath.Join(filename, m[0])
   407  			}
   408  		}
   409  		LogInfo(ctx, "serve file", "name", fname, "route", req.URL.Path)
   410  		dir, name := filepath.Split(fname)
   411  		fs := ctrl.FileSystem(dir)
   412  		f, err := fs.Open(name)
   413  		if err != nil {
   414  			return ErrInvalidFile(err)
   415  		}
   416  		defer f.Close()
   417  		d, err := f.Stat()
   418  		if err != nil {
   419  			return ErrInvalidFile(err)
   420  		}
   421  		// use contents of index.html for directory, if present
   422  		if d.IsDir() {
   423  			index := strings.TrimSuffix(name, "/") + "/index.html"
   424  			ff, err := fs.Open(index)
   425  			if err == nil {
   426  				defer ff.Close()
   427  				dd, err := ff.Stat()
   428  				if err == nil {
   429  					name = index
   430  					d = dd
   431  					f = ff
   432  				}
   433  			}
   434  		}
   435  
   436  		// serveContent will check modification time
   437  		// Still a directory? (we didn't find an index.html file)
   438  		if d.IsDir() {
   439  			return dirList(rw, f)
   440  		}
   441  		http.ServeContent(rw, req, d.Name(), d.ModTime(), f)
   442  		return nil
   443  	}
   444  }
   445  
   446  func attemptsPathTraversal(req string, path string) bool {
   447  	if !strings.Contains(req, "..") {
   448  		return false
   449  	}
   450  
   451  	currentPathIdx := 0
   452  	if idx := strings.LastIndex(path, "/*"); idx > -1 && idx < len(path)-1 {
   453  		req = req[idx+1:]
   454  	}
   455  	for _, runeValue := range strings.FieldsFunc(req, isSlashRune) {
   456  		if runeValue == ".." {
   457  			currentPathIdx--
   458  			if currentPathIdx < 0 {
   459  				return true
   460  			}
   461  		} else {
   462  			currentPathIdx++
   463  		}
   464  	}
   465  	return false
   466  }
   467  
   468  func isSlashRune(r rune) bool {
   469  	return os.IsPathSeparator(uint8(r))
   470  }
   471  
   472  var replacer = strings.NewReplacer(
   473  	"&", "&amp;",
   474  	"<", "&lt;",
   475  	">", "&gt;",
   476  	// "&#34;" is shorter than "&quot;".
   477  	`"`, "&#34;",
   478  	// "&#39;" is shorter than "&apos;" and apos was not in HTML until HTML5.
   479  	"'", "&#39;",
   480  )
   481  
   482  func dirList(w http.ResponseWriter, f http.File) error {
   483  	dirs, err := f.Readdir(-1)
   484  	if err != nil {
   485  		return err
   486  	}
   487  	sort.Sort(byName(dirs))
   488  
   489  	w.Header().Set("Content-Type", "text/html; charset=utf-8")
   490  	fmt.Fprintf(w, "<pre>\n")
   491  	for _, d := range dirs {
   492  		name := d.Name()
   493  		if d.IsDir() {
   494  			name += "/"
   495  		}
   496  		// name may contain '?' or '#', which must be escaped to remain
   497  		// part of the URL path, and not indicate the start of a query
   498  		// string or fragment.
   499  		url := url.URL{Path: name}
   500  		fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", url.String(), replacer.Replace(name))
   501  	}
   502  	fmt.Fprintf(w, "</pre>\n")
   503  	return nil
   504  }
   505  
   506  type byName []os.FileInfo
   507  
   508  func (s byName) Len() int           { return len(s) }
   509  func (s byName) Less(i, j int) bool { return s[i].Name() < s[j].Name() }
   510  func (s byName) Swap(i, j int)      { s[i], s[j] = s[j], s[i] }