github.com/vseinstrumentiru/lego@v1.0.2/pkg/lehttp/middleware/trace.go (about)

     1  package middleware
     2  
import (
	"bytes"
	"encoding/json"
	"io"
	"io/ioutil"
	"net/http"

	"github.com/gorilla/mux"
	"go.opencensus.io/trace"
)
    11  
    12  type traceWriter struct {
    13  	http.ResponseWriter
    14  	buf []byte
    15  }
    16  
    17  func (w *traceWriter) Write(b []byte) (int, error) {
    18  	i, err := w.ResponseWriter.Write(b)
    19  
    20  	if err == nil {
    21  		w.buf = []byte(string(w.buf) + string(b))
    22  	}
    23  
    24  	return i, err
    25  }
    26  
    27  type TraceRequestOptions struct {
    28  	LogRequest        bool
    29  	RequestBodyLimit  int
    30  	LogResponse       bool
    31  	ResponseBodyLimit int
    32  	HeadersToTags     map[string]string
    33  }
    34  
    35  func TraceRequestResponse(opt TraceRequestOptions) func(http.Handler) http.Handler {
    36  	return func(next http.Handler) http.Handler {
    37  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    38  			span := trace.FromContext(r.Context())
    39  
    40  			if span == nil {
    41  				next.ServeHTTP(w, r)
    42  				return
    43  			}
    44  
    45  			if opt.HeadersToTags != nil {
    46  				for header, tag := range opt.HeadersToTags {
    47  					span.AddAttributes(trace.StringAttribute(tag, r.Header.Get(header)))
    48  				}
    49  			}
    50  
    51  			if opt.LogRequest {
    52  				var bodyBytes []byte
    53  				if r.Method == http.MethodGet {
    54  					req := struct {
    55  						Query string            `json:"query"`
    56  						Vars  map[string]string `json:"vars"`
    57  					}{
    58  						Query: r.URL.RawQuery,
    59  						Vars:  mux.Vars(r),
    60  					}
    61  					bodyBytes, _ = json.Marshal(req)
    62  
    63  				} else {
    64  					bodyBytes, _ = ioutil.ReadAll(r.Body)
    65  					_ = r.Body.Close() //  must close
    66  					r.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
    67  				}
    68  
    69  				if opt.RequestBodyLimit > 0 && len(bodyBytes) > opt.RequestBodyLimit {
    70  					bodyBytes = bodyBytes[:opt.RequestBodyLimit-1]
    71  				}
    72  
    73  				reqLog := string(bodyBytes)
    74  
    75  				span.Annotate(
    76  					[]trace.Attribute{
    77  						trace.StringAttribute("request", reqLog),
    78  					},
    79  					"request log",
    80  				)
    81  			}
    82  
    83  			if !opt.LogResponse {
    84  				next.ServeHTTP(w, r)
    85  				return
    86  			}
    87  
    88  			tw := &traceWriter{
    89  				ResponseWriter: w,
    90  			}
    91  			next.ServeHTTP(tw, r)
    92  
    93  			var bodyBytes []byte
    94  			bodyBytes = tw.buf
    95  			if opt.ResponseBodyLimit > 0 && len(bodyBytes) > opt.ResponseBodyLimit {
    96  				bodyBytes = bodyBytes[:opt.ResponseBodyLimit-1]
    97  			}
    98  
    99  			span.Annotate(
   100  				[]trace.Attribute{
   101  					trace.StringAttribute("response", string(bodyBytes)),
   102  				},
   103  				"response log",
   104  			)
   105  		})
   106  	}
   107  }
   108  
   109  func TraceTagFromHeaders(headersTagNames map[string]string) func(http.Handler) http.Handler {
   110  	return func(next http.Handler) http.Handler {
   111  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   112  			span := trace.FromContext(r.Context())
   113  
   114  			if span == nil {
   115  				next.ServeHTTP(w, r)
   116  				return
   117  			}
   118  		})
   119  	}
   120  }