github.com/Axway/agent-sdk@v1.1.101/pkg/apic/specoas3processor.go (about)

     1  package apic
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"net/url"
     7  	"sort"
     8  	"strconv"
     9  	"strings"
    10  
    11  	"github.com/Axway/agent-sdk/pkg/util"
    12  	coreerrors "github.com/Axway/agent-sdk/pkg/util/errors"
    13  	"github.com/Axway/agent-sdk/pkg/util/log"
    14  	"github.com/getkin/kin-openapi/openapi3"
    15  )
    16  
// oas3SpecProcessor parses and validates an OAS3 spec, and exposes methods to modify the content of the spec.
// NOTE(review): within this file the processor receives an already-loaded *openapi3.T
// (see newOas3Processor); no validation is performed by the methods visible here.
type oas3SpecProcessor struct {
	// spec is the OAS3 document; StripSpecAuth and AddSecuritySchemes modify it in place.
	spec *openapi3.T
	// scopes holds the OAuth scopes collected by ParseAuthInfo.
	scopes map[string]string
	// authPolicies holds the de-duplicated, sorted policy names set by ParseAuthInfo.
	authPolicies []string
	// apiKeyInfo holds the location and name of each apiKey scheme, set by ParseAuthInfo.
	apiKeyInfo []APIKeyInfo
}
    24  
    25  func newOas3Processor(oas3Obj *openapi3.T) *oas3SpecProcessor {
    26  	return &oas3SpecProcessor{spec: oas3Obj}
    27  }
    28  
    29  func (p *oas3SpecProcessor) GetResourceType() string {
    30  	return Oas3
    31  }
    32  
    33  // GetVersion -
    34  func (p *oas3SpecProcessor) GetVersion() string {
    35  	return p.spec.Info.Version
    36  }
    37  
    38  // GetEndpoints -
    39  func (p *oas3SpecProcessor) GetEndpoints() ([]EndpointDefinition, error) {
    40  	endPoints := []EndpointDefinition{}
    41  	if len(p.spec.Servers) > 0 {
    42  		var err error
    43  		endPoints, err = p.parseEndpoints(p.spec.Servers)
    44  		if err != nil {
    45  			return nil, coreerrors.Wrap(ErrSetSpecEndPoints, err.Error())
    46  		}
    47  		return endPoints, nil
    48  	}
    49  	if len(endPoints) == 0 {
    50  		return nil, coreerrors.Wrap(ErrSetSpecEndPoints, "no server endpoints defined")
    51  	}
    52  	return endPoints, nil
    53  }
    54  
    55  func (p *oas3SpecProcessor) parseEndpoints(servers []*openapi3.Server) ([]EndpointDefinition, error) {
    56  	endPoints := []EndpointDefinition{}
    57  	for _, server := range servers {
    58  		// Add the URL string to the array
    59  		allURLs := []string{
    60  			server.URL,
    61  		}
    62  
    63  		defaultURL := ""
    64  		var err error
    65  		if server.Variables != nil {
    66  			defaultURL, allURLs, err = p.handleURLSubstitutions(server, allURLs)
    67  			if err != nil {
    68  				return nil, err
    69  			}
    70  		}
    71  
    72  		parsedEndPoints, err := p.parseURLsIntoEndpoints(defaultURL, allURLs)
    73  		if err != nil {
    74  			return nil, err
    75  		}
    76  		endPoints = append(endPoints, parsedEndPoints...)
    77  	}
    78  	return endPoints, nil
    79  }
    80  
    81  func (p *oas3SpecProcessor) handleURLSubstitutions(server *openapi3.Server, allURLs []string) (string, []string, error) {
    82  	defaultURL := server.URL
    83  	// Handle substitutions
    84  	for serverKey, serverVar := range server.Variables {
    85  		newURLs := []string{}
    86  		if serverVar.Default == "" {
    87  			err := fmt.Errorf("server variable in OAS3 %s does not have a default value, spec not valid", serverKey)
    88  			log.Errorf(err.Error())
    89  			return "", nil, err
    90  		}
    91  		defaultURL = strings.ReplaceAll(defaultURL, fmt.Sprintf("{%s}", serverKey), serverVar.Default)
    92  		if len(serverVar.Enum) == 0 {
    93  			newURLs = p.processURLSubstitutions(allURLs, newURLs, serverKey, serverVar.Default)
    94  		} else {
    95  			for _, enumVal := range serverVar.Enum {
    96  				newURLs = p.processURLSubstitutions(allURLs, newURLs, serverKey, enumVal)
    97  			}
    98  		}
    99  		allURLs = newURLs
   100  	}
   101  
   102  	return defaultURL, allURLs, nil
   103  }
   104  
   105  func (p *oas3SpecProcessor) processURLSubstitutions(allURLs, newURLs []string, varName, varValue string) []string {
   106  	for _, template := range allURLs {
   107  		newURLs = append(newURLs, strings.ReplaceAll(template, fmt.Sprintf("{%s}", varName), varValue))
   108  	}
   109  	return newURLs
   110  }
   111  
   112  func (p *oas3SpecProcessor) parseURL(urlStr string) (*url.URL, error) {
   113  	urlObj, err := url.Parse(urlStr)
   114  	if err != nil {
   115  		return nil, err
   116  	}
   117  	if urlObj.Scheme == "" {
   118  		urlObj, err = p.parseURL("https://" + urlStr)
   119  	}
   120  	return urlObj, err
   121  }
   122  
   123  func (p *oas3SpecProcessor) parseURLsIntoEndpoints(defaultURL string, allURLs []string) ([]EndpointDefinition, error) {
   124  	endPoints := []EndpointDefinition{}
   125  	for _, urlStr := range allURLs {
   126  		if urlStr == "" {
   127  			return nil, fmt.Errorf("server definition cannot have empty url")
   128  		}
   129  		urlObj, err := p.parseURL(urlStr)
   130  		if err != nil {
   131  			return nil, err
   132  		}
   133  		if urlObj.Hostname() == "" {
   134  			err = fmt.Errorf("could not parse url: %s", urlStr)
   135  			return nil, err
   136  		}
   137  		port := 0
   138  		if urlObj.Port() != "" {
   139  			port, _ = strconv.Atoi(urlObj.Port())
   140  		}
   141  		endPoint := createEndpointDefinition(urlObj.Scheme, urlObj.Hostname(), port, urlObj.Path)
   142  		// If the URL is the default URL put it at the front of the array
   143  		if urlStr == defaultURL {
   144  			newEndPoints := []EndpointDefinition{endPoint}
   145  			newEndPoints = append(newEndPoints, endPoints...)
   146  			endPoints = newEndPoints
   147  		} else {
   148  			endPoints = append(endPoints, endPoint)
   149  		}
   150  	}
   151  
   152  	return endPoints, nil
   153  }
   154  
   155  func (p *oas3SpecProcessor) ParseAuthInfo() {
   156  	p.authPolicies = []string{}
   157  	p.apiKeyInfo = []APIKeyInfo{}
   158  	p.scopes = make(map[string]string)
   159  
   160  	if p.spec.Components == nil {
   161  		return
   162  	}
   163  	for _, scheme := range p.spec.Components.SecuritySchemes {
   164  		switch scheme.Value.Type {
   165  		case oasSecurityHttp:
   166  			if scheme.Value.Scheme == oasSecurityBasic {
   167  				p.authPolicies = append(p.authPolicies, Basic)
   168  			}
   169  		case oasSecurityAPIKey:
   170  			p.authPolicies = append(p.authPolicies, Apikey)
   171  			p.apiKeyInfo = append(p.apiKeyInfo, APIKeyInfo{
   172  				Location: scheme.Value.In,
   173  				Name:     scheme.Value.Name,
   174  			})
   175  		case oasSecurityOauth:
   176  			p.authPolicies = append(p.authPolicies, Oauth)
   177  			if scheme.Value.Flows != nil {
   178  				if scheme.Value.Flows.ClientCredentials != nil {
   179  					p.scopes = util.MergeMapStringString(p.scopes, scheme.Value.Flows.ClientCredentials.Scopes)
   180  				}
   181  				if scheme.Value.Flows.Implicit != nil {
   182  					p.scopes = util.MergeMapStringString(p.scopes, scheme.Value.Flows.Implicit.Scopes)
   183  				}
   184  				if scheme.Value.Flows.AuthorizationCode != nil {
   185  					p.scopes = util.MergeMapStringString(p.scopes, scheme.Value.Flows.AuthorizationCode.Scopes)
   186  				}
   187  			}
   188  		}
   189  	}
   190  	p.authPolicies = util.RemoveDuplicateValuesFromStringSlice(p.authPolicies)
   191  	sort.Strings(p.authPolicies)
   192  }
   193  
   194  func (p *oas3SpecProcessor) GetAuthPolicies() []string {
   195  	return p.authPolicies
   196  }
   197  
   198  func (p *oas3SpecProcessor) GetOAuthScopes() map[string]string {
   199  	return p.scopes
   200  }
   201  
   202  func (p *oas3SpecProcessor) GetAPIKeyInfo() []APIKeyInfo {
   203  	return p.apiKeyInfo
   204  }
   205  
   206  func (p *oas3SpecProcessor) GetTitle() string {
   207  	return p.spec.Info.Title
   208  }
   209  
   210  func (p *oas3SpecProcessor) GetDescription() string {
   211  	return p.spec.Info.Description
   212  }
   213  
   214  func (p *oas3SpecProcessor) StripSpecAuth() {
   215  	p.spec.Components.SecuritySchemes = openapi3.SecuritySchemes{}
   216  }
   217  
   218  func (p *oas3SpecProcessor) GetSpecBytes() []byte {
   219  	s, _ := json.Marshal(p.spec)
   220  	return s
   221  }
   222  
   223  func (p *oas3SpecProcessor) GetSecurityBuilder() SecurityBuilder {
   224  	return newSpecSecurityBuilder(oas3)
   225  }
   226  
   227  func (p *oas3SpecProcessor) AddSecuritySchemes(authSchemes map[string]interface{}) {
   228  	for name, scheme := range util.OrderStringsInMap(authSchemes) {
   229  		s, ok := scheme.(*openapi3.SecurityScheme)
   230  		if !ok {
   231  			continue
   232  		}
   233  		p.spec.Components.SecuritySchemes[name] = &openapi3.SecuritySchemeRef{
   234  			Value: s,
   235  		}
   236  	}
   237  }