github.com/lingyao2333/mo-zero@v1.4.1/rest/handler/sheddinghandler.go (about)

     1  package handler
     2  
     3  import (
     4  	"net/http"
     5  	"sync"
     6  
     7  	"github.com/lingyao2333/mo-zero/core/load"
     8  	"github.com/lingyao2333/mo-zero/core/logx"
     9  	"github.com/lingyao2333/mo-zero/core/stat"
    10  	"github.com/lingyao2333/mo-zero/rest/httpx"
    11  	"github.com/lingyao2333/mo-zero/rest/internal/response"
    12  )
    13  
    14  const serviceType = "api"
    15  
    16  var (
    17  	sheddingStat *load.SheddingStat
    18  	lock         sync.Mutex
    19  )
    20  
    21  // SheddingHandler returns a middleware that does load shedding.
    22  func SheddingHandler(shedder load.Shedder, metrics *stat.Metrics) func(http.Handler) http.Handler {
    23  	if shedder == nil {
    24  		return func(next http.Handler) http.Handler {
    25  			return next
    26  		}
    27  	}
    28  
    29  	ensureSheddingStat()
    30  
    31  	return func(next http.Handler) http.Handler {
    32  		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    33  			sheddingStat.IncrementTotal()
    34  			promise, err := shedder.Allow()
    35  			if err != nil {
    36  				metrics.AddDrop()
    37  				sheddingStat.IncrementDrop()
    38  				logx.Errorf("[http] dropped, %s - %s - %s",
    39  					r.RequestURI, httpx.GetRemoteAddr(r), r.UserAgent())
    40  				w.WriteHeader(http.StatusServiceUnavailable)
    41  				return
    42  			}
    43  
    44  			cw := &response.WithCodeResponseWriter{Writer: w}
    45  			defer func() {
    46  				if cw.Code == http.StatusServiceUnavailable {
    47  					promise.Fail()
    48  				} else {
    49  					sheddingStat.IncrementPass()
    50  					promise.Pass()
    51  				}
    52  			}()
    53  			next.ServeHTTP(cw, r)
    54  		})
    55  	}
    56  }
    57  
    58  func ensureSheddingStat() {
    59  	lock.Lock()
    60  	if sheddingStat == nil {
    61  		sheddingStat = load.NewSheddingStat(serviceType)
    62  	}
    63  	lock.Unlock()
    64  }