github.com/sirkon/goproxy@v1.4.8/middleware.go (about)

     1  package goproxy
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"io"
     7  	"net/http"
     8  	"strings"
     9  	"time"
    10  	"unicode/utf8"
    11  
    12  	"github.com/rs/zerolog"
    13  	"github.com/spaolacci/murmur3"
    14  
    15  	"github.com/sirkon/goproxy/internal/errors"
    16  	"github.com/sirkon/goproxy/semver"
    17  )
    18  
    19  // Middleware acts as go proxy with given router.
    20  //   transportPrefix is a head part of URL path which refers to address of go proxy before the module info. For example,
    21  // if we serving go proxy at https://0.0.0.0:8081/goproxy/..., transportPrefix will be "/goproxy"
    22  func Middleware(r *Router, transportPrefix string, logger *zerolog.Logger) http.Handler {
    23  	return &middleware{
    24  		prefix: transportPrefix,
    25  		router: r,
    26  		logger: logger,
    27  	}
    28  }
    29  
// middleware is the http.Handler built by Middleware. It maps the module path of each request to a plugin factory
// through router and answers the request from the source that factory provides.
type middleware struct {
	// prefix is the transportPrefix given to Middleware: the URL path part that precedes the module information.
	prefix string
	// router selects the plugin factory responsible for a module path.
	router *Router
	// logger is the base logger; ServeHTTP derives a per-request logger from it.
	logger *zerolog.Logger
}
    36  
// latestSuffix is the "/@latest" URL path suffix.
// NOTE(review): not referenced in this file; presumably consumed elsewhere in the package (e.g. by GetModInfo) — confirm.
const latestSuffix = "/@latest"
    38  
    39  func errResp(w http.ResponseWriter, logger zerolog.Logger, code int, err error, msg string) {
    40  	w.WriteHeader(code)
    41  	var errMsg string
    42  	if err != nil {
    43  		logger.Error().Err(err).Msg(msg)
    44  		errMsg = errors.Wrap(err, msg).Error()
    45  	} else {
    46  		logger.Error().Msg(msg)
    47  		errMsg = msg
    48  	}
    49  
    50  	if _, wErr := io.WriteString(w, errMsg); wErr != nil {
    51  		logger.Error().Err(wErr).Msg("failed to respond")
    52  	}
    53  }
    54  
    55  func errRespf(w http.ResponseWriter, logger zerolog.Logger, code int, err error, format string, a ...interface{}) {
    56  	errResp(w, logger, code, err, fmt.Sprintf(format, a...))
    57  }
    58  
    59  func (m *middleware) ServeHTTP(w http.ResponseWriter, req *http.Request) {
    60  	hasher := murmur3.New64()
    61  	_, _ = io.WriteString(hasher, req.URL.String())
    62  	_, _ = io.WriteString(hasher, time.Now().Format(time.RFC3339Nano))
    63  	logger := m.logger.With().Hex("request-id", hasher.Sum(nil)).Str("request", req.URL.String()).Logger()
    64  
    65  	path, suffix, err := GetModInfo(req, m.prefix)
    66  	if err != nil {
    67  		errResp(w, logger, http.StatusBadRequest, err, "getting mod info")
    68  		return
    69  	}
    70  
    71  	logger = logger.With().Str("module", path).Logger()
    72  
    73  	factory := m.router.Factory(path)
    74  	if factory == nil {
    75  		errRespf(w, logger, http.StatusBadRequest, nil, "no plugin registered for %s", path)
    76  		return
    77  	}
    78  
    79  	logger = logger.With().Str("plugin", factory.String()).Logger()
    80  
    81  	src, err := factory.Module(req, m.prefix)
    82  	if err != nil {
    83  		errResp(w, logger, http.StatusBadRequest, err, "failed to get a source from plugin")
    84  		return
    85  	}
    86  
    87  	switch {
    88  	case suffix == "list":
    89  		ctx := logger.WithContext(req.Context())
    90  		logger.Debug().Msg("version list requested")
    91  		version, err := src.Versions(ctx, "")
    92  		if err != nil {
    93  			errResp(w, logger, http.StatusBadRequest, err, "getting version list")
    94  			return
    95  		}
    96  		w.WriteHeader(http.StatusOK)
    97  		if _, err := io.WriteString(w, strings.Join(version, "\n")); err != nil {
    98  			logger.Error().Err(err).Msg("writing version list response")
    99  		} else {
   100  			logger.Debug().Msg("version list done")
   101  		}
   102  
   103  	case strings.HasSuffix(suffix, ".info"):
   104  		version := getVersion(suffix)
   105  		tmpLogger := logger.With().Str("version", version).Logger()
   106  		ctx := tmpLogger.WithContext(req.Context())
   107  		tmpLogger.Debug().Msg("version info requested")
   108  		info, err := src.Stat(ctx, version)
   109  		if err != nil {
   110  			errResp(w, tmpLogger, http.StatusBadRequest, err, "getting revision info from source beneath")
   111  			w.WriteHeader(http.StatusBadRequest)
   112  			return
   113  		}
   114  		je := json.NewEncoder(w)
   115  		if err := je.Encode(info); err != nil {
   116  			tmpLogger.Error().Err(err).Msg("writing version info response")
   117  		} else {
   118  			tmpLogger.Debug().Msg("version info done")
   119  		}
   120  
   121  	case strings.HasSuffix(suffix, ".mod"):
   122  		version := getVersion(suffix)
   123  		tmpLogger := logger.With().Str("version", version).Logger()
   124  		ctx := tmpLogger.WithContext(req.Context())
   125  		tmpLogger.Debug().Msg("go.mod requested")
   126  		gomod, err := src.GoMod(ctx, version)
   127  		if err != nil {
   128  			errResp(w, tmpLogger, http.StatusBadRequest, err, "getting go.mod from a source beneath")
   129  			return
   130  		}
   131  		if _, err := w.Write(gomod); err != nil {
   132  			tmpLogger.Error().Err(err).Msg("writing go.mod response")
   133  			return
   134  		} else {
   135  			tmpLogger.Debug().Msg("go.mod done")
   136  		}
   137  
   138  	case strings.HasSuffix(suffix, ".zip"):
   139  		version := getVersion(suffix)
   140  		tmpLogger := logger.With().Str("version", version).Logger()
   141  		ctx := tmpLogger.WithContext(req.Context())
   142  		tmpLogger.Debug().Msg("zip archive requested")
   143  		archiveReader, err := src.Zip(ctx, version)
   144  		if err != nil {
   145  			errResp(w, tmpLogger, http.StatusBadRequest, err, "getting zip archive")
   146  			return
   147  		}
   148  		defer func() {
   149  			if err := archiveReader.Close(); err != nil {
   150  				tmpLogger.Error().Err(err).Msgf("closing zip reachive reader")
   151  			}
   152  		}()
   153  		if _, err := io.Copy(w, archiveReader); err != nil {
   154  			tmpLogger.Error().Err(err).Msg("writing zip archive response")
   155  		} else {
   156  			tmpLogger.Debug().Msg("zip done")
   157  		}
   158  
   159  	case suffix == "latest":
   160  		ctx := logger.WithContext(req.Context())
   161  		logger.Debug().Msg("latest")
   162  		version, err := src.Versions(ctx, "")
   163  		var revision string
   164  		if err != nil {
   165  			logger.Error().Err(err).Msg("getting version list for @latest")
   166  			revision = "master"
   167  		} else {
   168  			for _, v := range version {
   169  				if semver.IsValid(v) && (len(revision) == 0 || semver.Compare(v, revision) > 0) {
   170  					revision = v
   171  				}
   172  			}
   173  			if len(revision) == 0 {
   174  				revision = "master"
   175  			}
   176  		}
   177  		tmpLogger := logger.With().Str("version", revision).Logger()
   178  		tmpLogger.Debug().Msg("version info requested")
   179  		info, err := src.Stat(ctx, revision)
   180  		if err != nil {
   181  			errResp(w, tmpLogger, http.StatusBadRequest, err, "getting revision info from source beneath for @latest")
   182  			return
   183  		}
   184  		je := json.NewEncoder(w)
   185  		if err := je.Encode(info); err != nil {
   186  			tmpLogger.Error().Err(err).Msg("writing version info response for @latest")
   187  		} else {
   188  			tmpLogger.Debug().Msgf("latest done")
   189  		}
   190  	default:
   191  		logger.Error().Msgf("unsupported suffix %s", suffix)
   192  		w.WriteHeader(http.StatusBadRequest)
   193  		return
   194  	}
   195  }
   196  
   197  // getVersion we have something like v0.1.2.zip or v0.1.2.info or v0.1.2.zip in the suffix and need to cut the
   198  func getVersion(suffix string) string {
   199  	off := strings.LastIndex(suffix, ".")
   200  	encoding := suffix[:off]
   201  
   202  	var buf []byte
   203  	bang := false
   204  	for _, r := range encoding {
   205  		if r >= utf8.RuneSelf {
   206  			return encoding
   207  		}
   208  		if bang {
   209  			bang = false
   210  			if r < 'a' || 'z' < r {
   211  				return encoding
   212  			}
   213  			buf = append(buf, byte(r+'A'-'a'))
   214  			continue
   215  		}
   216  		if r == '!' {
   217  			bang = true
   218  			continue
   219  		}
   220  		if 'A' <= r && r <= 'Z' {
   221  			return encoding
   222  		}
   223  		buf = append(buf, byte(r))
   224  	}
   225  	if bang {
   226  		return encoding
   227  	}
   228  	return string(buf)
   229  }