github.com/crewjam/saml@v0.4.14/samlsp/request_tracker_cookie.go (about) 1 package samlsp 2 3 import ( 4 "encoding/base64" 5 "fmt" 6 "net/http" 7 "strings" 8 "time" 9 10 "github.com/crewjam/saml" 11 ) 12 13 var _ RequestTracker = CookieRequestTracker{} 14 15 // CookieRequestTracker tracks requests by setting a uniquely named 16 // cookie for each request. 17 type CookieRequestTracker struct { 18 ServiceProvider *saml.ServiceProvider 19 NamePrefix string 20 Codec TrackedRequestCodec 21 MaxAge time.Duration 22 RelayStateFunc func(w http.ResponseWriter, r *http.Request) string 23 SameSite http.SameSite 24 } 25 26 // TrackRequest starts tracking the SAML request with the given ID. It returns an 27 // `index` that should be used as the RelayState in the SAMl request flow. 28 func (t CookieRequestTracker) TrackRequest(w http.ResponseWriter, r *http.Request, samlRequestID string) (string, error) { 29 trackedRequest := TrackedRequest{ 30 Index: base64.RawURLEncoding.EncodeToString(randomBytes(42)), 31 SAMLRequestID: samlRequestID, 32 URI: r.URL.String(), 33 } 34 35 if t.RelayStateFunc != nil { 36 relayState := t.RelayStateFunc(w, r) 37 if relayState != "" { 38 trackedRequest.Index = relayState 39 } 40 } 41 42 signedTrackedRequest, err := t.Codec.Encode(trackedRequest) 43 if err != nil { 44 return "", err 45 } 46 47 http.SetCookie(w, &http.Cookie{ 48 Name: t.NamePrefix + trackedRequest.Index, 49 Value: signedTrackedRequest, 50 MaxAge: int(t.MaxAge.Seconds()), 51 HttpOnly: true, 52 SameSite: t.SameSite, 53 Secure: t.ServiceProvider.AcsURL.Scheme == "https", 54 Path: t.ServiceProvider.AcsURL.Path, 55 }) 56 57 return trackedRequest.Index, nil 58 } 59 60 // StopTrackingRequest stops tracking the SAML request given by index, which is a string 61 // previously returned from TrackRequest 62 func (t CookieRequestTracker) StopTrackingRequest(w http.ResponseWriter, r *http.Request, index string) error { 63 cookie, err := r.Cookie(t.NamePrefix + index) 64 if err != nil { 65 return err 66 } 67 cookie.Value = "" 68 cookie.Domain = t.ServiceProvider.AcsURL.Hostname() 69 cookie.Expires = time.Unix(1, 0) // past time as close to epoch as possible, but not zero time.Time{} 70 http.SetCookie(w, cookie) 71 return nil 72 } 73 74 // GetTrackedRequests returns all the pending tracked requests 75 func (t CookieRequestTracker) GetTrackedRequests(r *http.Request) []TrackedRequest { 76 rv := []TrackedRequest{} 77 for _, cookie := range r.Cookies() { 78 if !strings.HasPrefix(cookie.Name, t.NamePrefix) { 79 continue 80 } 81 82 trackedRequest, err := t.Codec.Decode(cookie.Value) 83 if err != nil { 84 continue 85 } 86 index := strings.TrimPrefix(cookie.Name, t.NamePrefix) 87 if index != trackedRequest.Index { 88 continue 89 } 90 91 rv = append(rv, *trackedRequest) 92 } 93 return rv 94 } 95 96 // GetTrackedRequest returns a pending tracked request. 97 func (t CookieRequestTracker) GetTrackedRequest(r *http.Request, index string) (*TrackedRequest, error) { 98 cookie, err := r.Cookie(t.NamePrefix + index) 99 if err != nil { 100 return nil, err 101 } 102 103 trackedRequest, err := t.Codec.Decode(cookie.Value) 104 if err != nil { 105 return nil, err 106 } 107 if trackedRequest.Index != index { 108 return nil, fmt.Errorf("expected index %q, got %q", index, trackedRequest.Index) 109 } 110 return trackedRequest, nil 111 }