github.com/crewjam/saml@v0.4.14/samlsp/fetch_metadata.go (about) 1 package samlsp 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/xml" 7 "errors" 8 "io" 9 "net/http" 10 "net/url" 11 12 "github.com/crewjam/httperr" 13 xrv "github.com/mattermost/xml-roundtrip-validator" 14 15 "github.com/crewjam/saml/logger" 16 17 "github.com/crewjam/saml" 18 ) 19 20 // ParseMetadata parses arbitrary SAML IDP metadata. 21 // 22 // Note: this is needed because IDP metadata is sometimes wrapped in 23 // an <EntitiesDescriptor>, and sometimes the top level element is an 24 // <EntityDescriptor>. 25 func ParseMetadata(data []byte) (*saml.EntityDescriptor, error) { 26 entity := &saml.EntityDescriptor{} 27 28 if err := xrv.Validate(bytes.NewBuffer(data)); err != nil { 29 return nil, err 30 } 31 32 err := xml.Unmarshal(data, entity) 33 34 // this comparison is ugly, but it is how the error is generated in encoding/xml 35 if err != nil && err.Error() == "expected element type <EntityDescriptor> but have <EntitiesDescriptor>" { 36 entities := &saml.EntitiesDescriptor{} 37 if err := xml.Unmarshal(data, entities); err != nil { 38 return nil, err 39 } 40 41 for i, e := range entities.EntityDescriptors { 42 if len(e.IDPSSODescriptors) > 0 { 43 return &entities.EntityDescriptors[i], nil 44 } 45 } 46 return nil, errors.New("no entity found with IDPSSODescriptor") 47 } 48 if err != nil { 49 return nil, err 50 } 51 return entity, nil 52 } 53 54 // FetchMetadata returns metadata from an IDP metadata URL. 55 func FetchMetadata(ctx context.Context, httpClient *http.Client, metadataURL url.URL) (*saml.EntityDescriptor, error) { 56 req, err := http.NewRequest("GET", metadataURL.String(), nil) 57 if err != nil { 58 return nil, err 59 } 60 req = req.WithContext(ctx) 61 62 resp, err := httpClient.Do(req) 63 if err != nil { 64 return nil, err 65 } 66 defer func() { 67 if err := resp.Body.Close(); err != nil { 68 logger.DefaultLogger.Printf("Error while closing response body during fetch metadata: %v", err) 69 } 70 }() 71 if resp.StatusCode >= 400 { 72 return nil, httperr.Response(*resp) 73 } 74 75 data, err := io.ReadAll(resp.Body) 76 if err != nil { 77 return nil, err 78 } 79 80 return ParseMetadata(data) 81 }