github.com/Diggs/controller-tools@v0.4.2/pkg/crd/flatten.go (about)

     1  /*
     2  Copyright 2019 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package crd
    18  
    19  import (
    20  	"fmt"
    21  	"reflect"
    22  	"sort"
    23  	"strings"
    24  	"sync"
    25  
    26  	apiext "k8s.io/apiextensions-apiserver/pkg/apis/apiextensions/v1"
    27  
    28  	"github.com/Diggs/controller-tools/pkg/loader"
    29  )
    30  
    31  // ErrorRecorder knows how to record errors.  It wraps the part of
    32  // pkg/loader.Package that we need to record errors in places were it might not
    33  // make sense to have a loader.Package
    34  type ErrorRecorder interface {
    35  	// AddError records that the given error occurred.
    36  	// See the documentation on loader.Package.AddError for more information.
    37  	AddError(error)
    38  }
    39  
    40  // isOrNil checks if val is nil if val is of a nillable type, otherwise,
    41  // it compares val to valInt (which should probably be the zero value).
    42  func isOrNil(val reflect.Value, valInt interface{}, zeroInt interface{}) bool {
    43  	switch valKind := val.Kind(); valKind {
    44  	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
    45  		return val.IsNil()
    46  	default:
    47  		return valInt == zeroInt
    48  	}
    49  }
    50  
    51  // flattenAllOfInto copies properties from src to dst, then copies the properties
    52  // of each item in src's allOf to dst's properties as well.
    53  func flattenAllOfInto(dst *apiext.JSONSchemaProps, src apiext.JSONSchemaProps, errRec ErrorRecorder) {
    54  	if len(src.AllOf) > 0 {
    55  		for _, embedded := range src.AllOf {
    56  			flattenAllOfInto(dst, embedded, errRec)
    57  		}
    58  	}
    59  
    60  	dstVal := reflect.Indirect(reflect.ValueOf(dst))
    61  	srcVal := reflect.ValueOf(src)
    62  	typ := dstVal.Type()
    63  
    64  	srcRemainder := apiext.JSONSchemaProps{}
    65  	srcRemVal := reflect.Indirect(reflect.ValueOf(&srcRemainder))
    66  	dstRemainder := apiext.JSONSchemaProps{}
    67  	dstRemVal := reflect.Indirect(reflect.ValueOf(&dstRemainder))
    68  	hoisted := false
    69  
    70  	for i := 0; i < srcVal.NumField(); i++ {
    71  		fieldName := typ.Field(i).Name
    72  		switch fieldName {
    73  		case "AllOf":
    74  			// don't merge because we deal with it above
    75  			continue
    76  		case "Title", "Description", "Example", "ExternalDocs":
    77  			// don't merge because we pre-merge to properly preserve field docs
    78  			continue
    79  		}
    80  		srcField := srcVal.Field(i)
    81  		fldTyp := srcField.Type()
    82  		zeroVal := reflect.Zero(fldTyp)
    83  		zeroInt := zeroVal.Interface()
    84  		srcInt := srcField.Interface()
    85  
    86  		if isOrNil(srcField, srcInt, zeroInt) {
    87  			// nothing to copy from src, continue
    88  			continue
    89  		}
    90  
    91  		dstField := dstVal.Field(i)
    92  		dstInt := dstField.Interface()
    93  		if isOrNil(dstField, dstInt, zeroInt) {
    94  			// dst is empty, continue
    95  			dstField.Set(srcField)
    96  			continue
    97  		}
    98  
    99  		if fldTyp.Comparable() && srcInt == dstInt {
   100  			// same value, continue
   101  			continue
   102  		}
   103  
   104  		// resolve conflict
   105  		switch fieldName {
   106  		case "Properties":
   107  			// merge if possible, use all of otherwise
   108  			srcMap := srcInt.(map[string]apiext.JSONSchemaProps)
   109  			dstMap := dstInt.(map[string]apiext.JSONSchemaProps)
   110  
   111  			for k, v := range srcMap {
   112  				dstProp, exists := dstMap[k]
   113  				if !exists {
   114  					dstMap[k] = v
   115  					continue
   116  				}
   117  				flattenAllOfInto(&dstProp, v, errRec)
   118  				dstMap[k] = dstProp
   119  			}
   120  		case "Required":
   121  			// merge
   122  			dstField.Set(reflect.AppendSlice(dstField, srcField))
   123  		case "Type":
   124  			if srcInt != dstInt {
   125  				// TODO(directxman12): figure out how to attach this back to a useful point in the Go source or in the schema
   126  				errRec.AddError(fmt.Errorf("conflicting types in allOf branches in schema: %s vs %s", dstInt, srcInt))
   127  			}
   128  			// keep the destination value, for now
   129  		// TODO(directxman12): Default -- use field?
   130  		// TODO(directxman12):
   131  		// - Dependencies: if field x is present, then either schema validates or all props are present
   132  		// - AdditionalItems: like AdditionalProperties
   133  		// - Definitions: common named validation sets that can be references (merge, bail if duplicate)
   134  		case "AdditionalProperties":
   135  			// as of the time of writing, `allows: false` is not allowed, so we don't have to handle it
   136  			srcProps := srcInt.(*apiext.JSONSchemaPropsOrBool)
   137  			if srcProps.Schema == nil {
   138  				// nothing to merge
   139  				continue
   140  			}
   141  			dstProps := dstInt.(*apiext.JSONSchemaPropsOrBool)
   142  			if dstProps.Schema == nil {
   143  				dstProps.Schema = &apiext.JSONSchemaProps{}
   144  			}
   145  			flattenAllOfInto(dstProps.Schema, *srcProps.Schema, errRec)
   146  		// NB(directxman12): no need to explicitly handle nullable -- false is considered to be the zero value
   147  		// TODO(directxman12): src isn't necessarily the field value -- it's just the most recent allOf entry
   148  		default:
   149  			// hoist into allOf...
   150  			hoisted = true
   151  
   152  			srcRemVal.Field(i).Set(srcField)
   153  			dstRemVal.Field(i).Set(dstField)
   154  			// ...and clear the original
   155  			dstField.Set(zeroVal)
   156  		}
   157  	}
   158  
   159  	if hoisted {
   160  		dst.AllOf = append(dst.AllOf, dstRemainder, srcRemainder)
   161  	}
   162  
   163  	// dedup required
   164  	if len(dst.Required) > 0 {
   165  		reqUniq := make(map[string]struct{})
   166  		for _, req := range dst.Required {
   167  			reqUniq[req] = struct{}{}
   168  		}
   169  		dst.Required = make([]string, 0, len(reqUniq))
   170  		for req := range reqUniq {
   171  			dst.Required = append(dst.Required, req)
   172  		}
   173  		// be deterministic
   174  		sort.Strings(dst.Required)
   175  	}
   176  }
   177  
   178  // allOfVisitor recursively visits allOf fields in the schema,
   179  // merging nested allOf properties into the root schema.
   180  type allOfVisitor struct {
   181  	// errRec is used to record errors while flattening (like two conflicting
   182  	// field values used in an allOf)
   183  	errRec ErrorRecorder
   184  }
   185  
   186  func (v *allOfVisitor) Visit(schema *apiext.JSONSchemaProps) SchemaVisitor {
   187  	if schema == nil {
   188  		return v
   189  	}
   190  
   191  	// clear this now so that we can safely preserve edits made my flattenAllOfInto
   192  	origAllOf := schema.AllOf
   193  	schema.AllOf = nil
   194  
   195  	for _, embedded := range origAllOf {
   196  		flattenAllOfInto(schema, embedded, v.errRec)
   197  	}
   198  	return v
   199  }
   200  
   201  // NB(directxman12): FlattenEmbedded is separate from Flattener because
   202  // some tooling wants to flatten out embedded fields, but only actually
   203  // flatten a few specific types first.
   204  
   205  // FlattenEmbedded flattens embedded fields (represented via AllOf) which have
   206  // already had their references resolved into simple properties in the containing
   207  // schema.
   208  func FlattenEmbedded(schema *apiext.JSONSchemaProps, errRec ErrorRecorder) *apiext.JSONSchemaProps {
   209  	outSchema := schema.DeepCopy()
   210  	EditSchema(outSchema, &allOfVisitor{errRec: errRec})
   211  	return outSchema
   212  }
   213  
   214  // Flattener knows how to take a root type, and flatten all references in it
   215  // into a single, flat type.  Flattened types are cached, so it's relatively
   216  // cheap to make repeated calls with the same type.
   217  type Flattener struct {
   218  	// Parser is used to lookup package and type details, and parse in new packages.
   219  	Parser *Parser
   220  
   221  	LookupReference func(ref string, contextPkg *loader.Package) (TypeIdent, error)
   222  
   223  	// flattenedTypes hold the flattened version of each seen type for later reuse.
   224  	flattenedTypes map[TypeIdent]apiext.JSONSchemaProps
   225  	initOnce       sync.Once
   226  }
   227  
   228  func (f *Flattener) init() {
   229  	f.initOnce.Do(func() {
   230  		f.flattenedTypes = make(map[TypeIdent]apiext.JSONSchemaProps)
   231  		if f.LookupReference == nil {
   232  			f.LookupReference = identFromRef
   233  		}
   234  	})
   235  }
   236  
   237  // cacheType saves the flattened version of the given type for later reuse
   238  func (f *Flattener) cacheType(typ TypeIdent, schema apiext.JSONSchemaProps) {
   239  	f.init()
   240  	f.flattenedTypes[typ] = schema
   241  }
   242  
   243  // loadUnflattenedSchema fetches a fresh, unflattened schema from the parser.
   244  func (f *Flattener) loadUnflattenedSchema(typ TypeIdent) (*apiext.JSONSchemaProps, error) {
   245  	f.Parser.NeedSchemaFor(typ)
   246  
   247  	baseSchema, found := f.Parser.Schemata[typ]
   248  	if !found {
   249  		return nil, fmt.Errorf("unable to locate schema for type %s", typ)
   250  	}
   251  	return &baseSchema, nil
   252  }
   253  
   254  // FlattenType flattens the given pre-loaded type, removing any references from it.
   255  // It deep-copies the schema first, so it won't affect the parser's version of the schema.
   256  func (f *Flattener) FlattenType(typ TypeIdent) *apiext.JSONSchemaProps {
   257  	f.init()
   258  	if cachedSchema, isCached := f.flattenedTypes[typ]; isCached {
   259  		return &cachedSchema
   260  	}
   261  	baseSchema, err := f.loadUnflattenedSchema(typ)
   262  	if err != nil {
   263  		typ.Package.AddError(err)
   264  		return nil
   265  	}
   266  	resSchema := f.FlattenSchema(*baseSchema, typ.Package)
   267  	f.cacheType(typ, *resSchema)
   268  	return resSchema
   269  }
   270  
   271  // FlattenSchema flattens the given schema, removing any references.
   272  // It deep-copies the schema first, so the input schema won't be affected.
   273  func (f *Flattener) FlattenSchema(baseSchema apiext.JSONSchemaProps, currentPackage *loader.Package) *apiext.JSONSchemaProps {
   274  	resSchema := baseSchema.DeepCopy()
   275  	EditSchema(resSchema, &flattenVisitor{
   276  		Flattener:      f,
   277  		currentPackage: currentPackage,
   278  	})
   279  
   280  	return resSchema
   281  }
   282  
   283  // RefParts splits a reference produced by the schema generator into its component
   284  // type name and package name (if it's a cross-package reference).  Note that
   285  // referenced packages *must* be looked up relative to the current package.
   286  func RefParts(ref string) (typ string, pkgName string, err error) {
   287  	if !strings.HasPrefix(ref, defPrefix) {
   288  		return "", "", fmt.Errorf("non-standard reference link %q", ref)
   289  	}
   290  	ref = ref[len(defPrefix):]
   291  	// decode the json pointer encodings
   292  	ref = strings.Replace(ref, "~1", "/", -1)
   293  	ref = strings.Replace(ref, "~0", "~", -1)
   294  	nameParts := strings.SplitN(ref, "~", 2)
   295  
   296  	if len(nameParts) == 1 {
   297  		// local reference
   298  		return nameParts[0], "", nil
   299  	}
   300  	// cross-package reference
   301  	return nameParts[1], nameParts[0], nil
   302  }
   303  
   304  // identFromRef converts the given schema ref from the given package back
   305  // into the TypeIdent that it represents.
   306  func identFromRef(ref string, contextPkg *loader.Package) (TypeIdent, error) {
   307  	typ, pkgName, err := RefParts(ref)
   308  	if err != nil {
   309  		return TypeIdent{}, err
   310  	}
   311  
   312  	if pkgName == "" {
   313  		// a local reference
   314  		return TypeIdent{
   315  			Name:    typ,
   316  			Package: contextPkg,
   317  		}, nil
   318  	}
   319  
   320  	// an external reference
   321  	return TypeIdent{
   322  		Name:    typ,
   323  		Package: contextPkg.Imports()[pkgName],
   324  	}, nil
   325  }
   326  
   327  // preserveFields copies documentation fields from src into dst, preserving
   328  // field-level documentation when flattening, and preserving field-level validation
   329  // as allOf entries.
   330  func preserveFields(dst *apiext.JSONSchemaProps, src apiext.JSONSchemaProps) {
   331  	srcDesc := src.Description
   332  	srcTitle := src.Title
   333  	srcExDoc := src.ExternalDocs
   334  	srcEx := src.Example
   335  
   336  	src.Description, src.Title, src.ExternalDocs, src.Example = "", "", nil, nil
   337  
   338  	src.Ref = nil
   339  	*dst = apiext.JSONSchemaProps{
   340  		AllOf: []apiext.JSONSchemaProps{*dst, src},
   341  
   342  		// keep these, in case the source field doesn't specify anything useful
   343  		Description:  dst.Description,
   344  		Title:        dst.Title,
   345  		ExternalDocs: dst.ExternalDocs,
   346  		Example:      dst.Example,
   347  	}
   348  
   349  	if srcDesc != "" {
   350  		dst.Description = srcDesc
   351  	}
   352  	if srcTitle != "" {
   353  		dst.Title = srcTitle
   354  	}
   355  	if srcExDoc != nil {
   356  		dst.ExternalDocs = srcExDoc
   357  	}
   358  	if srcEx != nil {
   359  		dst.Example = srcEx
   360  	}
   361  }
   362  
   363  // flattenVisitor visits each node in the schema, recursively flattening references.
   364  type flattenVisitor struct {
   365  	*Flattener
   366  
   367  	currentPackage *loader.Package
   368  	currentType    *TypeIdent
   369  	currentSchema  *apiext.JSONSchemaProps
   370  	originalField  apiext.JSONSchemaProps
   371  }
   372  
   373  func (f *flattenVisitor) Visit(baseSchema *apiext.JSONSchemaProps) SchemaVisitor {
   374  	if baseSchema == nil {
   375  		// end-of-node marker, cache the results
   376  		if f.currentType != nil {
   377  			f.cacheType(*f.currentType, *f.currentSchema)
   378  			// preserve field information *after* caching so that we don't
   379  			// accidentally cache field-level information onto the schema for
   380  			// the type in general.
   381  			preserveFields(f.currentSchema, f.originalField)
   382  		}
   383  		return f
   384  	}
   385  
   386  	// if we get a type that's just a ref, resolve it
   387  	if baseSchema.Ref != nil && len(*baseSchema.Ref) > 0 {
   388  		// resolve this ref
   389  		refIdent, err := f.LookupReference(*baseSchema.Ref, f.currentPackage)
   390  		if err != nil {
   391  			f.currentPackage.AddError(err)
   392  			return nil
   393  		}
   394  
   395  		// load and potentially flatten the schema
   396  
   397  		// check the cache first...
   398  		if refSchemaCached, isCached := f.flattenedTypes[refIdent]; isCached {
   399  			// shallow copy is fine, it's just to avoid overwriting the doc fields
   400  			preserveFields(&refSchemaCached, *baseSchema)
   401  			*baseSchema = refSchemaCached
   402  			return nil // don't recurse, we're done
   403  		}
   404  
   405  		// ...otherwise, we need to flatten
   406  		refSchema, err := f.loadUnflattenedSchema(refIdent)
   407  		if err != nil {
   408  			f.currentPackage.AddError(err)
   409  			return nil
   410  		}
   411  		refSchema = refSchema.DeepCopy()
   412  
   413  		// keep field around to preserve field-level validation, docs, etc
   414  		origField := *baseSchema
   415  		*baseSchema = *refSchema
   416  
   417  		// avoid loops (which shouldn't exist, but just in case)
   418  		// by marking a nil cached pointer before we start recursing
   419  		f.cacheType(refIdent, apiext.JSONSchemaProps{})
   420  
   421  		return &flattenVisitor{
   422  			Flattener: f.Flattener,
   423  
   424  			currentPackage: refIdent.Package,
   425  			currentType:    &refIdent,
   426  			currentSchema:  baseSchema,
   427  			originalField:  origField,
   428  		}
   429  	}
   430  
   431  	// otherwise, continue recursing...
   432  	if f.currentType != nil {
   433  		// ...but don't accidentally end this node early (for caching purposes)
   434  		return &flattenVisitor{
   435  			Flattener:      f.Flattener,
   436  			currentPackage: f.currentPackage,
   437  		}
   438  	}
   439  
   440  	return f
   441  }