github.com/snyk/vervet/v5@v5.11.1-0.20240202085829-ad4dd7fb6101/remove_elements.go (about)

     1  package vervet
     2  
     3  import (
     4  	"reflect"
     5  	"regexp"
     6  
     7  	"github.com/getkin/kin-openapi/openapi3"
     8  	"github.com/mitchellh/reflectwalk"
     9  )
    10  
// ExcludePatterns defines patterns matching elements to be removed from an
// OpenAPI document.
type ExcludePatterns struct {
	// ExtensionPatterns are regular expressions, compiled with
	// regexp.Compile, that are matched against extension names. Extensions
	// whose name matches any pattern are removed.
	ExtensionPatterns []string
	// HeaderPatterns are regular expressions, compiled with regexp.Compile,
	// that are matched against the names of operation header parameters and
	// operation response headers. Headers whose name matches any pattern are
	// removed.
	HeaderPatterns []string
	// Paths lists document paths to remove. Entries are compared with the
	// document's path keys by exact string equality; they are not patterns.
	Paths []string
}
    18  
// excluder carries the compiled exclusion rules for one document. Its Struct
// and StructField methods implement reflectwalk.StructWalker, which is how
// the rules are applied while the document is walked.
type excluder struct {
	// doc is the document that is walked and modified in place.
	doc *openapi3.T

	// extensionPatterns are matched against extension names.
	extensionPatterns []*regexp.Regexp
	// headerPatterns are matched against header parameter names and response
	// header names.
	headerPatterns []*regexp.Regexp
	// paths are the document paths to remove, compared by exact equality.
	paths []string
}
    26  
    27  // RemoveElements removes those elements from an OpenAPI document matching the
    28  // given exclude patterns.
    29  func RemoveElements(doc *openapi3.T, excludes ExcludePatterns) error {
    30  	ex := &excluder{
    31  		doc:               doc,
    32  		extensionPatterns: make([]*regexp.Regexp, len(excludes.ExtensionPatterns)),
    33  		headerPatterns:    make([]*regexp.Regexp, len(excludes.HeaderPatterns)),
    34  		paths:             excludes.Paths,
    35  	}
    36  	for i, pat := range excludes.ExtensionPatterns {
    37  		re, err := regexp.Compile(pat)
    38  		if err != nil {
    39  			return err
    40  		}
    41  		ex.extensionPatterns[i] = re
    42  	}
    43  	for i, pat := range excludes.HeaderPatterns {
    44  		re, err := regexp.Compile(pat)
    45  		if err != nil {
    46  			return err
    47  		}
    48  		ex.headerPatterns[i] = re
    49  	}
    50  	// Remove excluded paths
    51  	excludedPaths := map[string]struct{}{}
    52  	for path := range doc.Paths {
    53  		if ex.isExcludedPath(path) {
    54  			excludedPaths[path] = struct{}{}
    55  		}
    56  	}
    57  	for path := range excludedPaths {
    58  		delete(doc.Paths, path)
    59  	}
    60  	// Remove excluded elements
    61  	if err := ex.apply(); err != nil {
    62  		return err
    63  	}
    64  	return nil
    65  }
    66  
    67  func (ex *excluder) apply() error {
    68  	return reflectwalk.Walk(ex.doc, ex)
    69  }
    70  
    71  // Struct implements reflectwalk.StructWalker.
    72  func (ex *excluder) Struct(v reflect.Value) error {
    73  	if !v.CanInterface() {
    74  		return nil
    75  	}
    76  
    77  	switch v.Interface().(type) {
    78  	case openapi3.Operation:
    79  		ex.applyOperation(v.Addr().Interface().(*openapi3.Operation))
    80  	}
    81  
    82  	return nil
    83  }
    84  
    85  // StructField implements reflectwalk.StructWalker.
    86  func (ex *excluder) StructField(field reflect.StructField, v reflect.Value) error {
    87  	if field.Name != "Extensions" || !v.CanInterface() {
    88  		return nil
    89  	}
    90  
    91  	switch v.Interface().(type) {
    92  	case map[string]interface{}:
    93  		ex.applyExtensions(v.Addr().Interface().(*map[string]interface{}))
    94  	}
    95  
    96  	return nil
    97  }
    98  
    99  func (ex *excluder) applyExtensions(extensions *map[string]interface{}) {
   100  	exts := make(map[string]interface{}, len(*extensions))
   101  	for k, v := range *extensions {
   102  		if !ex.isExcludedExtension(k) {
   103  			exts[k] = v
   104  		}
   105  	}
   106  	*extensions = exts
   107  }
   108  
   109  func (ex *excluder) applyOperation(op *openapi3.Operation) {
   110  	var params []*openapi3.ParameterRef
   111  	for _, p := range op.Parameters {
   112  		if !ex.isExcludedHeaderParam(p) {
   113  			params = append(params, p)
   114  		}
   115  	}
   116  	op.Parameters = params
   117  
   118  	for _, resp := range op.Responses {
   119  		if resp.Value == nil {
   120  			continue
   121  		}
   122  		headers := openapi3.Headers{}
   123  		for headerName, header := range resp.Value.Headers {
   124  			var matched bool
   125  			for _, re := range ex.headerPatterns {
   126  				if re.MatchString(headerName) {
   127  					matched = true
   128  					break
   129  				}
   130  			}
   131  			if !matched {
   132  				headers[headerName] = header
   133  			}
   134  		}
   135  		resp.Value.Headers = headers
   136  	}
   137  }
   138  
   139  func (ex *excluder) isExcludedExtension(name string) bool {
   140  	for _, re := range ex.extensionPatterns {
   141  		if re.MatchString(name) {
   142  			return true
   143  		}
   144  	}
   145  	return false
   146  }
   147  
   148  func (ex *excluder) isExcludedPath(path string) bool {
   149  	for _, matchPath := range ex.paths {
   150  		if matchPath == path {
   151  			return true
   152  		}
   153  	}
   154  	return false
   155  }
   156  
   157  func (ex *excluder) isExcludedHeaderParam(p *openapi3.ParameterRef) bool {
   158  	if p.Value == nil {
   159  		return false
   160  	}
   161  	if p.Value.In != "header" {
   162  		return false
   163  	}
   164  	for _, re := range ex.headerPatterns {
   165  		if re.MatchString(p.Value.Name) {
   166  			return true
   167  		}
   168  	}
   169  	return false
   170  }