github.com/gofiber/fiber/v2@v2.47.0/middleware/idempotency/idempotency.go (about)

     1  package idempotency
     2  
     3  import (
     4  	"fmt"
     5  	"log"
     6  	"strings"
     7  
     8  	"github.com/gofiber/fiber/v2"
     9  	"github.com/gofiber/fiber/v2/utils"
    10  )
    11  
    12  // Inspired by https://datatracker.ietf.org/doc/html/draft-ietf-httpapi-idempotency-key-header-02
    13  // and https://github.com/penguin-statistics/backend-next/blob/f2f7d5ba54fc8a58f168d153baa17b2ad4a14e45/internal/pkg/middlewares/idempotency.go
    14  
    15  const (
    16  	localsKeyIsFromCache   = "idempotency_isfromcache"
    17  	localsKeyWasPutToCache = "idempotency_wasputtocache"
    18  )
    19  
    20  func IsFromCache(c *fiber.Ctx) bool {
    21  	return c.Locals(localsKeyIsFromCache) != nil
    22  }
    23  
    24  func WasPutToCache(c *fiber.Ctx) bool {
    25  	return c.Locals(localsKeyWasPutToCache) != nil
    26  }
    27  
    28  func New(config ...Config) fiber.Handler {
    29  	// Set default config
    30  	cfg := configDefault(config...)
    31  
    32  	keepResponseHeadersMap := make(map[string]struct{}, len(cfg.KeepResponseHeaders))
    33  	for _, h := range cfg.KeepResponseHeaders {
    34  		keepResponseHeadersMap[strings.ToLower(h)] = struct{}{}
    35  	}
    36  
    37  	maybeWriteCachedResponse := func(c *fiber.Ctx, key string) (bool, error) {
    38  		if val, err := cfg.Storage.Get(key); err != nil {
    39  			return false, fmt.Errorf("failed to read response: %w", err)
    40  		} else if val != nil {
    41  			var res response
    42  			if _, err := res.UnmarshalMsg(val); err != nil {
    43  				return false, fmt.Errorf("failed to unmarshal response: %w", err)
    44  			}
    45  
    46  			_ = c.Status(res.StatusCode)
    47  
    48  			for header, val := range res.Headers {
    49  				c.Set(header, val)
    50  			}
    51  
    52  			if len(res.Body) != 0 {
    53  				if err := c.Send(res.Body); err != nil {
    54  					return true, err
    55  				}
    56  			}
    57  
    58  			_ = c.Locals(localsKeyIsFromCache, true)
    59  
    60  			return true, nil
    61  		}
    62  
    63  		return false, nil
    64  	}
    65  
    66  	return func(c *fiber.Ctx) error {
    67  		// Don't execute middleware if Next returns true
    68  		if cfg.Next != nil && cfg.Next(c) {
    69  			return c.Next()
    70  		}
    71  
    72  		// Don't execute middleware if the idempotency key is empty
    73  		key := utils.CopyString(c.Get(cfg.KeyHeader))
    74  		if key == "" {
    75  			return c.Next()
    76  		}
    77  
    78  		// Validate key
    79  		if err := cfg.KeyHeaderValidate(key); err != nil {
    80  			return err
    81  		}
    82  
    83  		// First-pass: if the idempotency key is in the storage, get and return the response
    84  		if ok, err := maybeWriteCachedResponse(c, key); err != nil {
    85  			return fmt.Errorf("failed to write cached response at fastpath: %w", err)
    86  		} else if ok {
    87  			return nil
    88  		}
    89  
    90  		if err := cfg.Lock.Lock(key); err != nil {
    91  			return fmt.Errorf("failed to lock: %w", err)
    92  		}
    93  		defer func() {
    94  			if err := cfg.Lock.Unlock(key); err != nil {
    95  				log.Printf("[Error] - [IDEMPOTENCY] failed to unlock key %q: %v", key, err)
    96  			}
    97  		}()
    98  
    99  		// Lock acquired. If the idempotency key now is in the storage, get and return the response
   100  		if ok, err := maybeWriteCachedResponse(c, key); err != nil {
   101  			return fmt.Errorf("failed to write cached response while locked: %w", err)
   102  		} else if ok {
   103  			return nil
   104  		}
   105  
   106  		// Execute the request handler
   107  		if err := c.Next(); err != nil {
   108  			// If the request handler returned an error, return it and skip idempotency
   109  			return err
   110  		}
   111  
   112  		// Construct response
   113  		res := &response{
   114  			StatusCode: c.Response().StatusCode(),
   115  
   116  			Body: utils.CopyBytes(c.Response().Body()),
   117  		}
   118  		{
   119  			headers := c.GetRespHeaders()
   120  			if cfg.KeepResponseHeaders == nil {
   121  				// Keep all
   122  				res.Headers = headers
   123  			} else {
   124  				// Filter
   125  				res.Headers = make(map[string]string)
   126  				for h := range headers {
   127  					if _, ok := keepResponseHeadersMap[utils.ToLower(h)]; ok {
   128  						res.Headers[h] = headers[h]
   129  					}
   130  				}
   131  			}
   132  		}
   133  
   134  		// Marshal response
   135  		bs, err := res.MarshalMsg(nil)
   136  		if err != nil {
   137  			return fmt.Errorf("failed to marshal response: %w", err)
   138  		}
   139  
   140  		// Store response
   141  		if err := cfg.Storage.Set(key, bs, cfg.Lifetime); err != nil {
   142  			return fmt.Errorf("failed to save response: %w", err)
   143  		}
   144  
   145  		_ = c.Locals(localsKeyWasPutToCache, true)
   146  
   147  		return nil
   148  	}
   149  }