
     1  // Package inject gives the ability to copy data and inject a payload before a
     2  // specified marker. In order to let the user respond to the change in length,
     3  // the API is split into two parts - Sniff checks whether the marker occurs
     4  // within a specified number of initial bytes, and Copy sends the data to the
     5  // destination.
     6  //
     7  // The package tries to avoid double-injecting a payload by checking whether
     8  // the payload occurs within the first Within + len(Payload) bytes.
     9  package inject
    11  import (
    12  	"bytes"
    13  	"fmt"
    14  	"html/template"
    15  	"io"
    16  	"net/http"
    17  	"regexp"
    18  	"strings"
    19  )
// CopyInject copies data, and injects a payload before a specified marker.
// It only holds configuration; call Sniff to inspect a particular stream and
// obtain an Injector that performs the copy.
type CopyInject struct {
	// Number of initial bytes within which to search for marker.
	// Sniff reads up to Within+len(Payload) bytes from the source; a zero
	// Within disables injection.
	Within int
	// Only inject in responses with this content type. It is matched as a
	// substring of the content type given to Sniff, so an empty value
	// matches every content type.
	ContentType string
	// A marker, BEFORE which the payload is inserted. Only the first match
	// inside the first Within bytes is used; nil disables injection.
	Marker *regexp.Regexp
	// The payload to be inserted. It is not inserted again if it already
	// occurs in the sniffed bytes.
	Payload []byte
}
// Injector is returned by CopyInject.Sniff and writes the inspected stream,
// with the payload injected when applicable, to a destination.
type Injector interface {
	// Copy writes the data (plus the payload, if it is being injected) to
	// dst and returns the number of bytes written.
	Copy(dst io.Writer) (int64, error)
	// Extra returns how many bytes Copy adds on top of the source data;
	// 0 when nothing is injected.
	Extra() int
	// Found reports whether the marker was found, i.e. whether Copy will
	// inject the payload.
	Found() bool
}
// realInjector keeps injection state for a stream whose content type matched
// CopyInject.ContentType.
type realInjector struct {
	// Has the marker been found?
	found bool
	// Configuration this injector was created from (supplies Payload).
	conf *CopyInject
	// Source reader; Sniff may already have consumed a prefix of it into
	// sniffedData.
	src io.Reader
	// Index into sniffedData where the marker match starts, i.e. where the
	// payload is inserted. Only meaningful when found is true.
	offset int
	// Bytes Sniff already read from src; Copy writes them before the rest
	// of src.
	sniffedData []byte
}
// nopInjector is the pass-through Injector that Sniff returns when the content
// type does not match: it copies src unchanged and never injects.
type nopInjector struct {
	// Source copied verbatim by Copy.
	src io.Reader
}
    53  func (injector *nopInjector) Copy(dst io.Writer) (int64, error) {
    54  	return io.Copy(dst, injector.src)
    55  }
    57  func (injector *nopInjector) Extra() int {
    58  	return 0
    59  }
    61  func (injector *nopInjector) Found() bool {
    62  	return false
    63  }
    65  // Extra reports the number of extra bytes that will be injected
    66  func (injector *realInjector) Extra() int {
    67  	if injector.found {
    68  		return len(injector.conf.Payload)
    69  	}
    70  	return 0
    71  }
    73  func (injector *realInjector) Found() bool {
    74  	return injector.found
    75  }
    77  func min(a int, b int) int {
    78  	if a > b {
    79  		return b
    80  	}
    81  	return a
    82  }
    84  // Sniff reads the first SniffLen bytes of the source, and checks for the
    85  // marker. Returns an Injector instance.
    86  func (ci *CopyInject) Sniff(src io.Reader, contentType string) (Injector, error) {
    87  	if !strings.Contains(contentType, ci.ContentType) {
    88  		return &nopInjector{src: src}, nil
    89  	}
    91  	injector := &realInjector{
    92  		conf: ci,
    93  		src:  src,
    94  	}
    95  	if ci.Within == 0 || ci.Marker == nil {
    96  		return injector, nil
    97  	}
    98  	buf := make([]byte, ci.Within+len(ci.Payload))
    99  	n, err := io.ReadFull(src, buf)
   100  	if err != nil && err != io.ErrUnexpectedEOF && err != io.EOF {
   101  		return nil, fmt.Errorf("inject could not read data to sniff: %s", err)
   102  	}
   103  	injector.sniffedData = buf[:n]
   104  	if bytes.Index(buf, ci.Payload) > -1 {
   105  		return injector, nil
   106  	}
   107  	loc := ci.Marker.FindIndex(injector.sniffedData[:min(n, ci.Within)])
   108  	if loc != nil {
   109  		injector.found = true
   110  		injector.offset = loc[0]
   111  	}
   112  	return injector, nil
   113  }
   115  // ServeTemplate renders and serves a template to an http.ResponseWriter
   116  func (ci *CopyInject) ServeTemplate(statuscode int, w http.ResponseWriter, t *template.Template, data interface{}) error {
   117  	buff := bytes.NewBuffer(make([]byte, 0, 0))
   118  	err := t.Execute(buff, data)
   119  	if err != nil {
   120  		return err
   121  	}
   123  	length := buff.Len()
   124  	inj, err := ci.Sniff(buff, "text/html")
   125  	if err != nil {
   126  		return err
   127  	}
   128  	w.Header().Set(
   129  		"Content-Length", fmt.Sprintf("%d", length+inj.Extra()),
   130  	)
   131  	w.WriteHeader(statuscode)
   132  	_, err = inj.Copy(w)
   133  	if err != nil {
   134  		return err
   135  	}
   136  	return nil
   137  }
   139  // Copy copies the data from src to dst, injecting the Payload if Sniff found
   140  // the marker.
   141  func (injector *realInjector) Copy(dst io.Writer) (int64, error) {
   142  	var preludeLen int64
   143  	if injector.found {
   144  		startn, err := io.Copy(
   145  			dst,
   146  			bytes.NewBuffer(
   147  				injector.sniffedData[:injector.offset],
   148  			),
   149  		)
   150  		if err != nil {
   151  			return startn, err
   152  		}
   153  		payloadn, err := io.Copy(dst, bytes.NewBuffer(injector.conf.Payload))
   154  		if err != nil {
   155  			return startn + payloadn, err
   156  		}
   157  		endn, err := io.Copy(
   158  			dst, bytes.NewBuffer(injector.sniffedData[injector.offset:]),
   159  		)
   160  		if err != nil {
   161  			return startn + payloadn + endn, err
   162  		}
   163  		preludeLen = startn + payloadn + endn
   164  	} else {
   165  		n, err := io.Copy(dst, bytes.NewBuffer(injector.sniffedData))
   166  		if err != nil {
   167  			return n, err
   168  		}
   169  		preludeLen = int64(len(injector.sniffedData))
   170  	}
   171  	n, err := io.Copy(dst, injector.src)
   172  	return n + preludeLen, err
   173  }