github.com/zak-blake/goa@v1.4.1/middleware/gzip/middleware.go (about)

     1  package gzip
     2  
     3  import (
     4  	"bytes"
     5  	"compress/gzip"
     6  	"context"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"strconv"
    10  	"strings"
    11  	"sync"
    12  
    13  	"github.com/goadesign/goa"
    14  )
    15  
    16  // These compression constants are copied from the compress/gzip package.
    17  const (
    18  	encodingGzip = "gzip"
    19  
    20  	headerAcceptEncoding  = "Accept-Encoding"
    21  	headerContentEncoding = "Content-Encoding"
    22  	headerContentLength   = "Content-Length"
    23  	headerContentType     = "Content-Type"
    24  	headerVary            = "Vary"
    25  	headerRange           = "Range"
    26  	headerAcceptRanges    = "Accept-Ranges"
    27  	headerSecWebSocketKey = "Sec-WebSocket-Key"
    28  )
    29  
    30  // gzipResponseWriter wraps the http.ResponseWriter to provide gzip
    31  // capabilities.
    32  type gzipResponseWriter struct {
    33  	http.ResponseWriter
    34  	gzw            *gzip.Writer
    35  	buf            bytes.Buffer
    36  	pool           *sync.Pool
    37  	statusCode     int
    38  	shouldCompress *bool
    39  	o              options
    40  }
    41  
    42  // Write writes bytes to the gzip.Writer. It will also set the Content-Type
    43  // header using the net/http library content type detection if the Content-Type
    44  // header was not set yet.
    45  func (grw *gzipResponseWriter) Write(b []byte) (int, error) {
    46  	if len(grw.Header().Get(headerContentType)) == 0 {
    47  		grw.Header().Set(headerContentType, http.DetectContentType(b))
    48  	}
    49  
    50  	// If we already decided to gzip, do that.
    51  	if grw.gzw != nil {
    52  		return grw.gzw.Write(b)
    53  	}
    54  
    55  	// If we have already decided not to gzip, do that.
    56  	if grw.shouldCompress != nil && !*grw.shouldCompress {
    57  		return grw.ResponseWriter.Write(b)
    58  	}
    59  
    60  	// Detect types, check status code.
    61  	if grw.shouldCompress == nil {
    62  		s := grw.o.shouldCompress(grw.Header().Get(headerContentType), grw.statusCode)
    63  		grw.shouldCompress = &s
    64  		if !s {
    65  			grw.ResponseWriter.WriteHeader(grw.statusCode)
    66  			return grw.ResponseWriter.Write(b)
    67  		}
    68  	}
    69  
    70  	// Check if length is above minimum,
    71  	// if not save to buffer.
    72  	size := len(b) + grw.buf.Len()
    73  	if size < grw.o.minSize {
    74  		return grw.buf.Write(b)
    75  	}
    76  
    77  	// Reset our gzip writer to use the http.ResponseWriter
    78  	// Retrieve gzip writer from the pool. Reset it to use the ResponseWriter.
    79  	// This allows us to re-use an already allocated buffer rather than
    80  	// allocating a new buffer for every request.
    81  	gz := grw.pool.Get().(*gzip.Writer)
    82  
    83  	// We must write header now
    84  	grw.Header().Set(headerContentEncoding, encodingGzip)
    85  	grw.Header().Set(headerVary, headerAcceptEncoding)
    86  	grw.Header().Del(headerContentLength)
    87  	grw.Header().Del(headerAcceptRanges)
    88  	grw.ResponseWriter.WriteHeader(grw.statusCode)
    89  	gz.Reset(grw.ResponseWriter)
    90  	grw.gzw = gz
    91  
    92  	// Write buffer
    93  	if grw.buf.Len() > 0 {
    94  		_, err := gz.Write(grw.buf.Bytes())
    95  		if err != nil {
    96  			return 0, err
    97  		}
    98  		grw.buf.Reset()
    99  	}
   100  	return gz.Write(b)
   101  }
   102  
   103  func (grw *gzipResponseWriter) WriteHeader(n int) {
   104  	grw.statusCode = n
   105  }
   106  
   107  type (
   108  	// Option allows to override default parameters.
   109  	Option func(*options) error
   110  
   111  	// options contains final options
   112  	options struct {
   113  		ignoreRange  bool
   114  		minSize      int
   115  		contentTypes []string
   116  		statusCodes  map[int]struct{}
   117  	}
   118  )
   119  
   120  // defaultContentTypes is the default list of content types for which
   121  // a Handler considers gzip compression. This list originates from the
   122  // file compression.conf within the Apache configuration found at
   123  // https://html5boilerplate.com/
   124  var defaultContentTypes = []string{
   125  	"application/atom+xml",
   126  	"application/font-sfnt",
   127  	"application/javascript",
   128  	"application/json",
   129  	"application/ld+json",
   130  	"application/manifest+json",
   131  	"application/rdf+xml",
   132  	"application/rss+xml",
   133  	"application/schema+json",
   134  	"application/vnd.", // All custom vendor types
   135  	"application/x-font-ttf",
   136  	"application/x-javascript",
   137  	"application/x-web-app-manifest+json",
   138  	"application/xhtml+xml",
   139  	"application/xml",
   140  	"font/eot",
   141  	"font/opentype",
   142  	"image/bmp",
   143  	"image/svg+xml",
   144  	"image/vnd.microsoft.icon",
   145  	"image/x-icon",
   146  	"text/cache-manifest",
   147  	"text/css",
   148  	"text/html",
   149  	"text/javascript",
   150  	"text/plain",
   151  	"text/vcard",
   152  	"text/vnd.rim.location.xloc",
   153  	"text/vtt",
   154  	"text/x-component",
   155  	"text/x-cross-domain-policy",
   156  	"text/xml",
   157  }
   158  
   159  // defaultStatusCodes are the status codes that will be compressed.
   160  var defaultStatusCodes = []int{
   161  	http.StatusOK,
   162  	http.StatusCreated,
   163  	http.StatusAccepted,
   164  }
   165  
   166  // AddContentTypes allows to specify specific content types to encode.
   167  // Adds to previous content types.
   168  func AddContentTypes(types ...string) Option {
   169  	return func(c *options) error {
   170  		dst := make([]string, len(c.contentTypes)+len(types))
   171  		copy(dst, c.contentTypes)
   172  		copy(dst[len(c.contentTypes):], types)
   173  		c.contentTypes = dst
   174  		return nil
   175  	}
   176  }
   177  
   178  // OnlyContentTypes allows to specify specific content types to encode.
   179  // Overrides previous content types.
   180  // no types = ignore content types (always compress).
   181  func OnlyContentTypes(types ...string) Option {
   182  	return func(c *options) error {
   183  		if len(types) == 0 {
   184  			c.contentTypes = nil
   185  			return nil
   186  		}
   187  		c.contentTypes = types
   188  		return nil
   189  	}
   190  }
   191  
   192  // AddStatusCodes allows to specify specific content types to encode.
   193  // All content types that has the supplied prefixes are compressed.
   194  func AddStatusCodes(codes ...int) Option {
   195  	return func(c *options) error {
   196  		dst := make(map[int]struct{}, len(c.statusCodes)+len(codes))
   197  		for code := range c.statusCodes {
   198  			dst[code] = struct{}{}
   199  		}
   200  		for _, code := range codes {
   201  			c.statusCodes[code] = struct{}{}
   202  		}
   203  		return nil
   204  	}
   205  }
   206  
   207  // OnlyStatusCodes allows to specify specific content types to encode.
   208  // All content types that has the supplied prefixes are compressed.
   209  // No codes = ignore content types (always compress).
   210  func OnlyStatusCodes(codes ...int) Option {
   211  	return func(c *options) error {
   212  		if len(codes) == 0 {
   213  			c.statusCodes = nil
   214  			return nil
   215  		}
   216  		c.statusCodes = make(map[int]struct{}, len(codes))
   217  		for _, code := range codes {
   218  			c.statusCodes[code] = struct{}{}
   219  		}
   220  		return nil
   221  	}
   222  }
   223  
   224  // MinSize will set a minimum size for compression.
   225  func MinSize(n int) Option {
   226  	return func(c *options) error {
   227  		if n <= 0 {
   228  			c.minSize = 0
   229  			return nil
   230  		}
   231  		c.minSize = n
   232  		return nil
   233  	}
   234  }
   235  
   236  // IgnoreRange will set make the compressor ignore Range requests.
   237  // Range requests are incompatible with compressed content,
   238  // so if this is set to true "Range" headers will be ignored.
   239  // If set to false, compression is disabled for all requests with Range header.
   240  func IgnoreRange(b bool) Option {
   241  	return func(c *options) error {
   242  		c.ignoreRange = b
   243  		return nil
   244  	}
   245  }
   246  
   247  // Middleware encodes the response using Gzip encoding and sets all the
   248  // appropriate headers. If the Content-Type is not set, it will be set by
   249  // calling http.DetectContentType on the data being written.
   250  func Middleware(level int, o ...Option) goa.Middleware {
   251  	opts := options{
   252  		ignoreRange:  true,
   253  		minSize:      256,
   254  		contentTypes: defaultContentTypes,
   255  	}
   256  	opts.statusCodes = make(map[int]struct{}, len(defaultStatusCodes))
   257  	for _, v := range defaultStatusCodes {
   258  		opts.statusCodes[v] = struct{}{}
   259  	}
   260  	for _, opt := range o {
   261  		err := opt(&opts)
   262  		if err != nil {
   263  			panic(err)
   264  		}
   265  	}
   266  	gzipPool := sync.Pool{
   267  		New: func() interface{} {
   268  			gz, err := gzip.NewWriterLevel(ioutil.Discard, level)
   269  			if err != nil {
   270  				panic(err)
   271  			}
   272  			return gz
   273  		},
   274  	}
   275  	return func(h goa.Handler) goa.Handler {
   276  		return func(ctx context.Context, rw http.ResponseWriter, req *http.Request) (err error) {
   277  			// Skip compression if the client doesn't accept gzip encoding, is
   278  			// requesting a WebSocket or the data is already compressed.
   279  			if !strings.Contains(req.Header.Get(headerAcceptEncoding), encodingGzip) ||
   280  				len(req.Header.Get(headerSecWebSocketKey)) > 0 ||
   281  				rw.Header().Get(headerContentEncoding) == encodingGzip ||
   282  				(!opts.ignoreRange && req.Header.Get(headerRange) != "") {
   283  				return h(ctx, rw, req)
   284  			}
   285  
   286  			// Set the appropriate gzip headers.
   287  			resp := goa.ContextResponse(ctx)
   288  
   289  			// Get the original http.ResponseWriter
   290  			w := resp.SwitchWriter(nil)
   291  
   292  			// Wrap the original http.ResponseWriter with our gzipResponseWriter
   293  			grw := &gzipResponseWriter{
   294  				ResponseWriter: w,
   295  				pool:           &gzipPool,
   296  				statusCode:     http.StatusOK,
   297  				o:              opts,
   298  			}
   299  
   300  			// Set the new http.ResponseWriter
   301  			resp.SwitchWriter(grw)
   302  
   303  			// We cannot do ranges, if possibly gzipped responses.
   304  			req.Header.Del("Range")
   305  
   306  			// Call the next handler supplying the gzipResponseWriter instead of
   307  			// the original.
   308  			err = h(ctx, rw, req)
   309  			if err != nil {
   310  				return
   311  			}
   312  
   313  			// Check for uncompressed data
   314  			if grw.buf.Len() > 0 {
   315  				w.Header().Set(headerContentLength, strconv.Itoa(grw.buf.Len()))
   316  				w.WriteHeader(grw.statusCode)
   317  				_, err = w.Write(grw.buf.Bytes())
   318  				return
   319  			}
   320  
   321  			// Flush compressor.
   322  			if grw.gzw != nil {
   323  				if err = grw.gzw.Close(); err != nil {
   324  					return
   325  				}
   326  				gzipPool.Put(grw.gzw)
   327  				return
   328  			}
   329  			// No writes, set status code.
   330  			if grw.shouldCompress == nil {
   331  				w.WriteHeader(grw.statusCode)
   332  			}
   333  			return
   334  		}
   335  	}
   336  }
   337  
   338  // returns true if we've been configured to compress the specific content type.
   339  func (o options) shouldCompress(contentType string, statusCode int) bool {
   340  	// If contentTypes is nil we handle all content types.
   341  	if len(o.contentTypes) > 0 {
   342  		ct := strings.ToLower(contentType)
   343  		ct = strings.Split(ct, ";")[0]
   344  		found := false
   345  		for _, v := range o.contentTypes {
   346  			if strings.HasPrefix(ct, v) {
   347  				found = true
   348  				break
   349  			}
   350  		}
   351  		if !found {
   352  			return false
   353  		}
   354  	}
   355  	if len(o.statusCodes) > 0 {
   356  		_, ok := o.statusCodes[statusCode]
   357  		if !ok {
   358  			return false
   359  		}
   360  	}
   361  
   362  	return true
   363  }