github.com/lastbackend/toolkit@v0.0.0-20241020043710-cafa37b95aad/pkg/server/http/middleware.go (about)

     1  /*
     2  Copyright [2014] - [2023] The Last.Backend authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package http
    18  
    19  import (
    20  	"fmt"
    21  	"net/http"
    22  	"regexp"
    23  	"sort"
    24  
    25  	"github.com/lastbackend/toolkit/pkg/runtime/logger"
    26  	"github.com/lastbackend/toolkit/pkg/server"
    27  )
    28  
    29  const MiddlewareNotFoundError string = "Can not apply middleware router: %s Can not find global server middleware: %s. To " +
    30  	"register middleware, please add Server().HTTP().SetMiddleware(\"%s\", http.Handler) to runtime."
    31  
    32  type Middlewares struct {
    33  	log          logger.Logger
    34  	global       []server.KindMiddleware
    35  	constructors []interface{}
    36  	items        map[server.KindMiddleware]server.HttpServerMiddleware
    37  }
    38  
    39  func (m *Middlewares) SetGlobal(middlewares ...server.KindMiddleware) {
    40  	for _, item := range middlewares {
    41  		if item != "" {
    42  			m.global = append([]server.KindMiddleware{item}, m.global...)
    43  		}
    44  	}
    45  }
    46  
    47  func (m *Middlewares) AddConstructor(h interface{}) {
    48  	m.constructors = append(m.constructors, h)
    49  }
    50  
    51  func (m *Middlewares) Add(h server.HttpServerMiddleware) {
    52  	m.items[h.Kind()] = h
    53  }
    54  
    55  func (m *Middlewares) apply(handler server.HTTPServerHandler) (http.HandlerFunc, error) {
    56  
    57  	h := handler.Handler
    58  
    59  	var (
    60  		exclude = make([]*regexp.Regexp, 0)
    61  		mws     = make([]server.HttpServerMiddleware, 0)
    62  	)
    63  
    64  	for _, opt := range handler.Options {
    65  		if opt.Kind() != optionKindMiddleware {
    66  			continue
    67  		}
    68  
    69  		o, ok := opt.(*optionMiddleware)
    70  
    71  		if !ok {
    72  			continue
    73  		}
    74  
    75  		for _, g := range m.global {
    76  			if g == o.middleware {
    77  				continue
    78  			}
    79  		}
    80  
    81  		middleware, ok := m.items[o.middleware]
    82  		if !ok {
    83  			m.log.Errorf(MiddlewareNotFoundError, handler.Path, o.middleware, o.middleware)
    84  			return h, fmt.Errorf("can not find global server middleware: %s", o.middleware)
    85  		}
    86  
    87  		mws = append(mws, middleware)
    88  	}
    89  
    90  	for _, opt := range handler.Options {
    91  		if opt.Kind() != optionKindExcludeGlobalMiddleware {
    92  			continue
    93  		}
    94  
    95  		o, ok := opt.(*optionExcludeGlobalMiddleware)
    96  
    97  		if !ok {
    98  			continue
    99  		}
   100  
   101  		exclude = append(exclude, regexp.MustCompile(``+o.regexp))
   102  	}
   103  
   104  	for _, g := range m.global {
   105  		var skip bool
   106  
   107  		switch g {
   108  		case corsMiddlewareKind:
   109  		default:
   110  			for _, re := range exclude {
   111  
   112  				if re.MatchString(string(g)) {
   113  					skip = true
   114  					break
   115  				}
   116  
   117  			}
   118  		}
   119  
   120  		if skip {
   121  			continue
   122  		}
   123  
   124  		middleware, ok := m.items[g]
   125  		if !ok {
   126  			m.log.Errorf(MiddlewareNotFoundError, handler.Path, g, g)
   127  			return h, fmt.Errorf("can not find global server middleware: %s", g)
   128  		}
   129  
   130  		mws = append(mws, middleware)
   131  	}
   132  
   133  	sort.Slice(mws, func(i, j int) bool {
   134  		return mws[i].Order() < mws[j].Order()
   135  	})
   136  
   137  	for _, mw := range mws {
   138  		m.log.V(5).Infof("apply middleware %s to %s", mw.Kind(), handler.Path)
   139  		h = mw.Apply(h)
   140  	}
   141  
   142  	return h, nil
   143  }
   144  
   145  func newMiddlewares(log logger.Logger) *Middlewares {
   146  	middlewares := Middlewares{
   147  		log:          log,
   148  		global:       make([]server.KindMiddleware, 0),
   149  		constructors: make([]interface{}, 0),
   150  		items:        make(map[server.KindMiddleware]server.HttpServerMiddleware),
   151  	}
   152  
   153  	return &middlewares
   154  }