github.com/crewjam/saml@v0.4.14/samlsp/request_tracker_cookie.go (about)

     1  package samlsp
     2  
     3  import (
     4  	"encoding/base64"
     5  	"fmt"
     6  	"net/http"
     7  	"strings"
     8  	"time"
     9  
    10  	"github.com/crewjam/saml"
    11  )
    12  
    13  var _ RequestTracker = CookieRequestTracker{} // compile-time check that the value type CookieRequestTracker implements RequestTracker
    14  
    15  // CookieRequestTracker tracks requests by setting a uniquely named
    16  // cookie for each request; each cookie is named NamePrefix followed by the request index.
    17  type CookieRequestTracker struct {
    18  	ServiceProvider *saml.ServiceProvider // its AcsURL supplies the tracking cookie's Path, and Secure is set when its scheme is "https"
    19  	NamePrefix      string // prepended to the request index to form the cookie name; also used to recognise tracker cookies on read
    20  	Codec           TrackedRequestCodec // encodes a TrackedRequest into the cookie value and decodes it back
    21  	MaxAge          time.Duration // used as the cookie MaxAge, converted to whole seconds (fractional part truncated)
    22  	RelayStateFunc  func(w http.ResponseWriter, r *http.Request) string // optional; a non-empty result replaces the random index (and so the RelayState and cookie name)
    23  	SameSite        http.SameSite // SameSite attribute applied to the tracking cookie
    24  }
    25  
    26  // TrackRequest starts tracking the SAML request with the given ID. It returns an
    27  // `index` that should be used as the RelayState in the SAMl request flow.
    28  func (t CookieRequestTracker) TrackRequest(w http.ResponseWriter, r *http.Request, samlRequestID string) (string, error) {
    29  	trackedRequest := TrackedRequest{
    30  		Index:         base64.RawURLEncoding.EncodeToString(randomBytes(42)),
    31  		SAMLRequestID: samlRequestID,
    32  		URI:           r.URL.String(),
    33  	}
    34  
    35  	if t.RelayStateFunc != nil {
    36  		relayState := t.RelayStateFunc(w, r)
    37  		if relayState != "" {
    38  			trackedRequest.Index = relayState
    39  		}
    40  	}
    41  
    42  	signedTrackedRequest, err := t.Codec.Encode(trackedRequest)
    43  	if err != nil {
    44  		return "", err
    45  	}
    46  
    47  	http.SetCookie(w, &http.Cookie{
    48  		Name:     t.NamePrefix + trackedRequest.Index,
    49  		Value:    signedTrackedRequest,
    50  		MaxAge:   int(t.MaxAge.Seconds()),
    51  		HttpOnly: true,
    52  		SameSite: t.SameSite,
    53  		Secure:   t.ServiceProvider.AcsURL.Scheme == "https",
    54  		Path:     t.ServiceProvider.AcsURL.Path,
    55  	})
    56  
    57  	return trackedRequest.Index, nil
    58  }
    59  
    60  // StopTrackingRequest stops tracking the SAML request given by index, which is a string
    61  // previously returned from TrackRequest
    62  func (t CookieRequestTracker) StopTrackingRequest(w http.ResponseWriter, r *http.Request, index string) error {
    63  	cookie, err := r.Cookie(t.NamePrefix + index)
    64  	if err != nil {
    65  		return err
    66  	}
    67  	cookie.Value = ""
    68  	cookie.Domain = t.ServiceProvider.AcsURL.Hostname()
    69  	cookie.Expires = time.Unix(1, 0) // past time as close to epoch as possible, but not zero time.Time{}
    70  	http.SetCookie(w, cookie)
    71  	return nil
    72  }
    73  
    74  // GetTrackedRequests returns all the pending tracked requests
    75  func (t CookieRequestTracker) GetTrackedRequests(r *http.Request) []TrackedRequest {
    76  	rv := []TrackedRequest{}
    77  	for _, cookie := range r.Cookies() {
    78  		if !strings.HasPrefix(cookie.Name, t.NamePrefix) {
    79  			continue
    80  		}
    81  
    82  		trackedRequest, err := t.Codec.Decode(cookie.Value)
    83  		if err != nil {
    84  			continue
    85  		}
    86  		index := strings.TrimPrefix(cookie.Name, t.NamePrefix)
    87  		if index != trackedRequest.Index {
    88  			continue
    89  		}
    90  
    91  		rv = append(rv, *trackedRequest)
    92  	}
    93  	return rv
    94  }
    95  
    96  // GetTrackedRequest returns a pending tracked request.
    97  func (t CookieRequestTracker) GetTrackedRequest(r *http.Request, index string) (*TrackedRequest, error) {
    98  	cookie, err := r.Cookie(t.NamePrefix + index)
    99  	if err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	trackedRequest, err := t.Codec.Decode(cookie.Value)
   104  	if err != nil {
   105  		return nil, err
   106  	}
   107  	if trackedRequest.Index != index {
   108  		return nil, fmt.Errorf("expected index %q, got %q", index, trackedRequest.Index)
   109  	}
   110  	return trackedRequest, nil
   111  }