github.com/mendersoftware/go-lib-micro@v0.0.0-20240304135804-e8e39c59b148/accesslog/middleware.go (about)

     1  // Copyright 2023 Northern.tech AS
     2  //
     3  //	Licensed under the Apache License, Version 2.0 (the "License");
     4  //	you may not use this file except in compliance with the License.
     5  //	You may obtain a copy of the License at
     6  //
     7  //	    http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  //	Unless required by applicable law or agreed to in writing, software
    10  //	distributed under the License is distributed on an "AS IS" BASIS,
    11  //	WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  //	See the License for the specific language governing permissions and
    13  //	limitations under the License.
    14  package accesslog
    15  
    16  import (
    17  	"bytes"
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"net"
    22  	"net/http"
    23  	"os"
    24  	"path"
    25  	"runtime"
    26  	"strings"
    27  	"text/template"
    28  	"time"
    29  
    30  	"github.com/ant0ine/go-json-rest/rest"
    31  	"github.com/sirupsen/logrus"
    32  
    33  	"github.com/mendersoftware/go-lib-micro/requestlog"
    34  )
    35  
    36  const (
    37  	StatusClientClosedConnection = 499
    38  
    39  	DefaultLogFormat = "%t %S\033[0m \033[36;1m%Dμs\033[0m \"%r\" \033[1;30m%u \"%{User-Agent}i\"\033[0m"
    40  	SimpleLogFormat  = "%s %Dμs %r %u %{User-Agent}i"
    41  
    42  	TypeHTTP = "http"
    43  )
    44  
    45  // AccesLogMiddleware is a customized version of the AccessLogApacheMiddleware.
    46  // It uses the request-specific custom logger (created by requestlog), which appends the Mender-specific request context.
    47  type AccessLogMiddleware struct {
    48  	Format       AccessLogFormat
    49  	textTemplate *template.Template
    50  
    51  	DisableLog func(statusCode int, r *rest.Request) bool
    52  
    53  	recorder *rest.RecorderMiddleware
    54  }
    55  
    56  const MaxTraceback = 32
    57  
    58  func collectTrace() string {
    59  	var (
    60  		trace     [MaxTraceback]uintptr
    61  		traceback strings.Builder
    62  	)
    63  	// Skip 4
    64  	// = accesslog.LogFunc
    65  	// + accesslog.collectTrace
    66  	// + runtime.Callers
    67  	// + runtime.gopanic
    68  	n := runtime.Callers(4, trace[:])
    69  	frames := runtime.CallersFrames(trace[:n])
    70  	for frame, more := frames.Next(); frame.PC != 0 &&
    71  		n >= 0; frame, more = frames.Next() {
    72  		funcName := frame.Function
    73  		if funcName == "" {
    74  			fmt.Fprint(&traceback, "???\n")
    75  		} else {
    76  			fmt.Fprintf(&traceback, "%s@%s:%d",
    77  				frame.Function,
    78  				path.Base(frame.File),
    79  				frame.Line,
    80  			)
    81  		}
    82  		if more {
    83  			fmt.Fprintln(&traceback)
    84  		}
    85  		n--
    86  	}
    87  	return traceback.String()
    88  }
    89  
    90  func (mw *AccessLogMiddleware) LogFunc(
    91  	ctx context.Context, startTime time.Time,
    92  	w rest.ResponseWriter, r *rest.Request) {
    93  	util := &accessLogUtil{w, r}
    94  	fields := logrus.Fields{
    95  		"type": r.Proto,
    96  		"ts": startTime.
    97  			Truncate(time.Millisecond).
    98  			Format(time.RFC3339Nano),
    99  		"method": r.Method,
   100  		"path":   r.URL.Path,
   101  		"qs":     r.URL.RawQuery,
   102  	}
   103  	lc := fromContext(ctx)
   104  	if lc != nil {
   105  		lc.addFields(fields)
   106  	}
   107  	statusCode := util.StatusCode()
   108  	select {
   109  	case <-ctx.Done():
   110  		if errors.Is(ctx.Err(), context.Canceled) {
   111  			statusCode = StatusClientClosedConnection
   112  		}
   113  	default:
   114  	}
   115  
   116  	if panic := recover(); panic != nil {
   117  		trace := collectTrace()
   118  		fields["panic"] = panic
   119  		fields["trace"] = trace
   120  		// Wrap in recorder middleware to make sure the response is recorded
   121  		mw.recorder.MiddlewareFunc(func(w rest.ResponseWriter, r *rest.Request) {
   122  			rest.Error(w, "Internal Server Error", http.StatusInternalServerError)
   123  		})(w, r)
   124  		statusCode = http.StatusInternalServerError
   125  	} else if mw.DisableLog != nil && mw.DisableLog(statusCode, r) {
   126  		return
   127  	}
   128  	rspTime := time.Since(startTime)
   129  	r.Env["ELAPSED_TIME"] = &rspTime
   130  	// We do not need more than 3 digit fraction
   131  	if rspTime > time.Second {
   132  		rspTime = rspTime.Round(time.Millisecond)
   133  	} else if rspTime > time.Millisecond {
   134  		rspTime = rspTime.Round(time.Microsecond)
   135  	}
   136  	fields["responsetime"] = rspTime.String()
   137  	fields["byteswritten"] = util.BytesWritten()
   138  	fields["status"] = statusCode
   139  
   140  	logger := requestlog.GetRequestLogger(r)
   141  	var level logrus.Level = logrus.InfoLevel
   142  	if statusCode >= 500 {
   143  		level = logrus.ErrorLevel
   144  	} else if statusCode >= 300 {
   145  		level = logrus.WarnLevel
   146  	}
   147  	logger.WithFields(fields).
   148  		Log(level, mw.executeTextTemplate(util))
   149  }
   150  
   151  // MiddlewareFunc makes AccessLogMiddleware implement the Middleware interface.
   152  func (mw *AccessLogMiddleware) MiddlewareFunc(h rest.HandlerFunc) rest.HandlerFunc {
   153  	if mw.Format == "" {
   154  		mw.Format = DefaultLogFormat
   155  	}
   156  
   157  	mw.convertFormat()
   158  
   159  	// This middleware depends on RecorderMiddleware to work
   160  	mw.recorder = new(rest.RecorderMiddleware)
   161  	return func(w rest.ResponseWriter, r *rest.Request) {
   162  		ctx := r.Request.Context()
   163  		startTime := time.Now()
   164  		r.Env["START_TIME"] = &startTime
   165  		ctx = withContext(ctx, &logContext{maxErrors: DefaultMaxErrors})
   166  		r.Request = r.Request.WithContext(ctx)
   167  		defer mw.LogFunc(ctx, startTime, w, r)
   168  		// call the handler inside recorder context
   169  		mw.recorder.MiddlewareFunc(h)(w, r)
   170  	}
   171  }
   172  
   173  var apacheAdapter = strings.NewReplacer(
   174  	"%b", "{{.BytesWritten | dashIf0}}",
   175  	"%B", "{{.BytesWritten}}",
   176  	"%D", "{{.ResponseTime | microseconds}}",
   177  	"%h", "{{.ApacheRemoteAddr}}",
   178  	"%H", "{{.R.Proto}}",
   179  	"%l", "-",
   180  	"%m", "{{.R.Method}}",
   181  	"%P", "{{.Pid}}",
   182  	"%q", "{{.ApacheQueryString}}",
   183  	"%r", "{{.R.Method}} {{.R.URL.RequestURI}} {{.R.Proto}}",
   184  	"%s", "{{.StatusCode}}",
   185  	"%S", "\033[{{.StatusCode | statusCodeColor}}m{{.StatusCode}}",
   186  	"%t", "{{if .StartTime}}{{.StartTime.Format \"02/Jan/2006:15:04:05 -0700\"}}{{end}}",
   187  	"%T", "{{if .ResponseTime}}{{.ResponseTime.Seconds | printf \"%.3f\"}}{{end}}",
   188  	"%u", "{{.RemoteUser | dashIfEmptyStr}}",
   189  	"%{User-Agent}i", "{{.R.UserAgent | dashIfEmptyStr}}",
   190  	"%{Referer}i", "{{.R.Referer | dashIfEmptyStr}}",
   191  )
   192  
   193  // Execute the text template with the data derived from the request, and return a string.
   194  func (mw *AccessLogMiddleware) executeTextTemplate(util *accessLogUtil) string {
   195  	buf := bytes.NewBufferString("")
   196  	err := mw.textTemplate.Execute(buf, util)
   197  	if err != nil {
   198  		panic(err)
   199  	}
   200  	return buf.String()
   201  }
   202  
   203  func (mw *AccessLogMiddleware) convertFormat() {
   204  
   205  	tmplText := apacheAdapter.Replace(string(mw.Format))
   206  
   207  	funcMap := template.FuncMap{
   208  		"dashIfEmptyStr": func(value string) string {
   209  			if value == "" {
   210  				return "-"
   211  			}
   212  			return value
   213  		},
   214  		"dashIf0": func(value int64) string {
   215  			if value == 0 {
   216  				return "-"
   217  			}
   218  			return fmt.Sprintf("%d", value)
   219  		},
   220  		"microseconds": func(dur *time.Duration) string {
   221  			if dur != nil {
   222  				return fmt.Sprintf("%d", dur.Nanoseconds()/1000)
   223  			}
   224  			return ""
   225  		},
   226  		"statusCodeColor": func(statusCode int) string {
   227  			if statusCode >= 400 && statusCode < 500 {
   228  				return "1;33"
   229  			} else if statusCode >= 500 {
   230  				return "0;31"
   231  			}
   232  			return "0;32"
   233  		},
   234  	}
   235  
   236  	var err error
   237  	mw.textTemplate, err = template.New("accessLog").Funcs(funcMap).Parse(tmplText)
   238  	if err != nil {
   239  		panic(err)
   240  	}
   241  }
   242  
   243  // accessLogUtil provides a collection of utility functions that devrive data from the Request object.
   244  // This object is used to provide data to the Apache Style template and the the JSON log record.
   245  type accessLogUtil struct {
   246  	W rest.ResponseWriter
   247  	R *rest.Request
   248  }
   249  
   250  // As stored by the auth middlewares.
   251  func (u *accessLogUtil) RemoteUser() string {
   252  	if u.R.Env["REMOTE_USER"] != nil {
   253  		return u.R.Env["REMOTE_USER"].(string)
   254  	}
   255  	return ""
   256  }
   257  
   258  // If qs exists then return it with a leadin "?", apache log style.
   259  func (u *accessLogUtil) ApacheQueryString() string {
   260  	if u.R.URL.RawQuery != "" {
   261  		return "?" + u.R.URL.RawQuery
   262  	}
   263  	return ""
   264  }
   265  
   266  // When the request entered the timer middleware.
   267  func (u *accessLogUtil) StartTime() *time.Time {
   268  	if u.R.Env["START_TIME"] != nil {
   269  		return u.R.Env["START_TIME"].(*time.Time)
   270  	}
   271  	return nil
   272  }
   273  
   274  // If remoteAddr is set then return is without the port number, apache log style.
   275  func (u *accessLogUtil) ApacheRemoteAddr() string {
   276  	remoteAddr := u.R.RemoteAddr
   277  	if remoteAddr != "" {
   278  		if ip, _, err := net.SplitHostPort(remoteAddr); err == nil {
   279  			return ip
   280  		}
   281  	}
   282  	return ""
   283  }
   284  
   285  // As recorded by the recorder middleware.
   286  func (u *accessLogUtil) StatusCode() int {
   287  	if u.R.Env["STATUS_CODE"] != nil {
   288  		return u.R.Env["STATUS_CODE"].(int)
   289  	}
   290  	return 0
   291  }
   292  
   293  // As mesured by the timer middleware.
   294  func (u *accessLogUtil) ResponseTime() *time.Duration {
   295  	if u.R.Env["ELAPSED_TIME"] != nil {
   296  		return u.R.Env["ELAPSED_TIME"].(*time.Duration)
   297  	}
   298  	return nil
   299  }
   300  
   301  // Process id.
   302  func (u *accessLogUtil) Pid() int {
   303  	return os.Getpid()
   304  }
   305  
   306  // As recorded by the recorder middleware.
   307  func (u *accessLogUtil) BytesWritten() int64 {
   308  	if u.R.Env["BYTES_WRITTEN"] != nil {
   309  		return u.R.Env["BYTES_WRITTEN"].(int64)
   310  	}
   311  	return 0
   312  }