github.com/hellofresh/janus@v0.0.0-20230925145208-ce8de8183c67/pkg/middleware/host_matcher.go (about)

     1  package middleware
     2  
     3  import (
     4  	"fmt"
     5  	"net/http"
     6  	"regexp"
     7  	"strings"
     8  
     9  	"github.com/hellofresh/janus/pkg/errors"
    10  	log "github.com/sirupsen/logrus"
    11  )
    12  
    13  // HostMatcher is a middleware that matches any host with the given list of hosts.
    14  // It also supports regex host like *.example.com
    15  type HostMatcher struct {
    16  	plainHosts    map[string]bool
    17  	wildcardHosts []*regexp.Regexp
    18  }
    19  
    20  // NewHostMatcher creates a new instance of HostMatcher
    21  func NewHostMatcher(hosts []string) *HostMatcher {
    22  	matcher := &HostMatcher{plainHosts: make(map[string]bool)}
    23  	matcher.prepareIndexes(hosts)
    24  	return matcher
    25  }
    26  
    27  // Handler is the middleware function
    28  func (h *HostMatcher) Handler(handler http.Handler) http.Handler {
    29  	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    30  		log.WithField("path", r.URL.Path).Debug("Starting host matcher middleware")
    31  		host := r.Host
    32  
    33  		if _, ok := h.plainHosts[host]; ok {
    34  			log.WithField("host", host).Debug("Plain host matched")
    35  			handler.ServeHTTP(w, r)
    36  			return
    37  		}
    38  
    39  		for _, hostRegex := range h.wildcardHosts {
    40  			if hostRegex.MatchString(host) {
    41  				log.WithField("host", host).Debug("Wildcard host matched")
    42  				handler.ServeHTTP(w, r)
    43  				return
    44  			}
    45  		}
    46  
    47  		err := errors.ErrRouteNotFound
    48  		log.WithError(err).Error("The host didn't match any of the provided hosts")
    49  		errors.Handler(w, r, err)
    50  	})
    51  }
    52  
    53  func (h *HostMatcher) prepareIndexes(hosts []string) {
    54  	if len(hosts) > 0 {
    55  		for _, host := range hosts {
    56  			if strings.Contains(host, "*") {
    57  				regexStr := strings.Replace(host, ".", "\\.", -1)
    58  				regexStr = strings.Replace(regexStr, "*", ".+", -1)
    59  				h.wildcardHosts = append(h.wildcardHosts, regexp.MustCompile(fmt.Sprintf("^%s$", regexStr)))
    60  			} else {
    61  				h.plainHosts[host] = true
    62  			}
    63  		}
    64  	}
    65  }