k8s.io/kube-openapi@v0.0.0-20240228011516-70dd3763d340/pkg/generators/union.go (about)

     1  /*
     2  Copyright 2016 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package generators
    18  
    19  import (
    20  	"fmt"
    21  	"sort"
    22  
    23  	"k8s.io/gengo/v2"
    24  	"k8s.io/gengo/v2/types"
    25  )
    26  
    27  const tagUnionMember = "union"
    28  const tagUnionDeprecated = "unionDeprecated"
    29  const tagUnionDiscriminator = "unionDiscriminator"
    30  
    31  type union struct {
    32  	discriminator         string
    33  	fieldsToDiscriminated map[string]string
    34  }
    35  
    36  // emit prints the union, can be called on a nil union (emits nothing)
    37  func (u *union) emit(g openAPITypeWriter) {
    38  	if u == nil {
    39  		return
    40  	}
    41  	g.Do("map[string]interface{}{\n", nil)
    42  	if u.discriminator != "" {
    43  		g.Do("\"discriminator\": \"$.$\",\n", u.discriminator)
    44  	}
    45  	g.Do("\"fields-to-discriminateBy\": map[string]interface{}{\n", nil)
    46  	keys := []string{}
    47  	for field := range u.fieldsToDiscriminated {
    48  		keys = append(keys, field)
    49  	}
    50  	sort.Strings(keys)
    51  	for _, field := range keys {
    52  		g.Do("\"$.$\": ", field)
    53  		g.Do("\"$.$\",\n", u.fieldsToDiscriminated[field])
    54  	}
    55  	g.Do("},\n", nil)
    56  	g.Do("},\n", nil)
    57  }
    58  
    59  // Sets the discriminator if it's not set yet, otherwise return an error
    60  func (u *union) setDiscriminator(value string) []error {
    61  	errors := []error{}
    62  	if u.discriminator != "" {
    63  		errors = append(errors, fmt.Errorf("at least two discriminators found: %v and %v", value, u.discriminator))
    64  	}
    65  	u.discriminator = value
    66  	return errors
    67  }
    68  
    69  // Add a new member to the union
    70  func (u *union) addMember(jsonName, variableName string) {
    71  	if _, ok := u.fieldsToDiscriminated[jsonName]; ok {
    72  		panic(fmt.Errorf("same field (%v) found multiple times", jsonName))
    73  	}
    74  	u.fieldsToDiscriminated[jsonName] = variableName
    75  }
    76  
    77  // Makes sure that the union is valid, specifically looking for re-used discriminated
    78  func (u *union) isValid() []error {
    79  	errors := []error{}
    80  	// Case 1: discriminator but no fields
    81  	if u.discriminator != "" && len(u.fieldsToDiscriminated) == 0 {
    82  		errors = append(errors, fmt.Errorf("discriminator set with no fields in union"))
    83  	}
    84  	// Case 2: two fields have the same discriminated value
    85  	discriminated := map[string]struct{}{}
    86  	for _, d := range u.fieldsToDiscriminated {
    87  		if _, ok := discriminated[d]; ok {
    88  			errors = append(errors, fmt.Errorf("discriminated value is used twice: %v", d))
    89  		}
    90  		discriminated[d] = struct{}{}
    91  	}
    92  	// Case 3: a field is both discriminator AND part of the union
    93  	if u.discriminator != "" {
    94  		if _, ok := u.fieldsToDiscriminated[u.discriminator]; ok {
    95  			errors = append(errors, fmt.Errorf("%v can't be both discriminator and part of the union", u.discriminator))
    96  		}
    97  	}
    98  	return errors
    99  }
   100  
   101  // Find unions either directly on the members (or inlined members, not
   102  // going across types) or on the type itself, or on embedded types.
   103  func parseUnions(t *types.Type) ([]union, []error) {
   104  	errors := []error{}
   105  	unions := []union{}
   106  	su, err := parseUnionStruct(t)
   107  	if su != nil {
   108  		unions = append(unions, *su)
   109  	}
   110  	errors = append(errors, err...)
   111  	eu, err := parseEmbeddedUnion(t)
   112  	unions = append(unions, eu...)
   113  	errors = append(errors, err...)
   114  	mu, err := parseUnionMembers(t)
   115  	if mu != nil {
   116  		unions = append(unions, *mu)
   117  	}
   118  	errors = append(errors, err...)
   119  	return unions, errors
   120  }
   121  
   122  // Find unions in embedded types, unions shouldn't go across types.
   123  func parseEmbeddedUnion(t *types.Type) ([]union, []error) {
   124  	errors := []error{}
   125  	unions := []union{}
   126  	for _, m := range t.Members {
   127  		if hasOpenAPITagValue(m.CommentLines, tagValueFalse) {
   128  			continue
   129  		}
   130  		if !shouldInlineMembers(&m) {
   131  			continue
   132  		}
   133  		u, err := parseUnions(m.Type)
   134  		unions = append(unions, u...)
   135  		errors = append(errors, err...)
   136  	}
   137  	return unions, errors
   138  }
   139  
   140  // Look for union tag on a struct, and then include all the fields
   141  // (except the discriminator if there is one). The struct shouldn't have
   142  // embedded types.
   143  func parseUnionStruct(t *types.Type) (*union, []error) {
   144  	errors := []error{}
   145  	if gengo.ExtractCommentTags("+", t.CommentLines)[tagUnionMember] == nil {
   146  		return nil, nil
   147  	}
   148  
   149  	u := &union{fieldsToDiscriminated: map[string]string{}}
   150  
   151  	for _, m := range t.Members {
   152  		jsonName := getReferableName(&m)
   153  		if jsonName == "" {
   154  			continue
   155  		}
   156  		if shouldInlineMembers(&m) {
   157  			errors = append(errors, fmt.Errorf("union structures can't have embedded fields: %v.%v", t.Name, m.Name))
   158  			continue
   159  		}
   160  		if gengo.ExtractCommentTags("+", m.CommentLines)[tagUnionDeprecated] != nil {
   161  			errors = append(errors, fmt.Errorf("union struct can't have unionDeprecated members: %v.%v", t.Name, m.Name))
   162  			continue
   163  		}
   164  		if gengo.ExtractCommentTags("+", m.CommentLines)[tagUnionDiscriminator] != nil {
   165  			errors = append(errors, u.setDiscriminator(jsonName)...)
   166  		} else {
   167  			if optional, err := isOptional(&m); !optional || err != nil {
   168  				errors = append(errors, fmt.Errorf("union members must be optional: %v.%v", t.Name, m.Name))
   169  			}
   170  			u.addMember(jsonName, m.Name)
   171  		}
   172  	}
   173  
   174  	return u, errors
   175  }
   176  
   177  // Find unions specifically on members.
   178  func parseUnionMembers(t *types.Type) (*union, []error) {
   179  	errors := []error{}
   180  	u := &union{fieldsToDiscriminated: map[string]string{}}
   181  
   182  	for _, m := range t.Members {
   183  		jsonName := getReferableName(&m)
   184  		if jsonName == "" {
   185  			continue
   186  		}
   187  		if shouldInlineMembers(&m) {
   188  			continue
   189  		}
   190  		if gengo.ExtractCommentTags("+", m.CommentLines)[tagUnionDiscriminator] != nil {
   191  			errors = append(errors, u.setDiscriminator(jsonName)...)
   192  		}
   193  		if gengo.ExtractCommentTags("+", m.CommentLines)[tagUnionMember] != nil {
   194  			errors = append(errors, fmt.Errorf("union tag is not accepted on struct members: %v.%v", t.Name, m.Name))
   195  			continue
   196  		}
   197  		if gengo.ExtractCommentTags("+", m.CommentLines)[tagUnionDeprecated] != nil {
   198  			if optional, err := isOptional(&m); !optional || err != nil {
   199  				errors = append(errors, fmt.Errorf("union members must be optional: %v.%v", t.Name, m.Name))
   200  			}
   201  			u.addMember(jsonName, m.Name)
   202  		}
   203  	}
   204  	if len(u.fieldsToDiscriminated) == 0 {
   205  		return nil, nil
   206  	}
   207  	return u, append(errors, u.isValid()...)
   208  }