github.com/nginxinc/kubernetes-ingress@v1.12.5/pkg/apis/configuration/validation/transportserver.go (about)

     1  package validation
     2  
     3  import (
     4  	"encoding/hex"
     5  	"fmt"
     6  	"regexp"
     7  	"strings"
     8  
     9  	"github.com/nginxinc/kubernetes-ingress/pkg/apis/configuration/v1alpha1"
    10  	"k8s.io/apimachinery/pkg/util/sets"
    11  	"k8s.io/apimachinery/pkg/util/validation"
    12  	"k8s.io/apimachinery/pkg/util/validation/field"
    13  )
    14  
// TransportServerValidator validates a TransportServer resource.
// Its fields are feature switches fixed at construction time by
// NewTransportServerValidator.
type TransportServerValidator struct {
	tlsPassthrough  bool // when false, listeners using the TLS Passthrough name or protocol are rejected
	snippetsEnabled bool // when false, non-empty serverSnippets/streamSnippets are rejected
	isPlus          bool // forwarded to upstream validation, where it selects the accepted load balancing methods
}
    21  
    22  // NewTransportServerValidator creates a new TransportServerValidator.
    23  func NewTransportServerValidator(tlsPassthrough bool, snippetsEnabled bool, isPlus bool) *TransportServerValidator {
    24  	return &TransportServerValidator{
    25  		tlsPassthrough:  tlsPassthrough,
    26  		snippetsEnabled: snippetsEnabled,
    27  		isPlus:          isPlus,
    28  	}
    29  }
    30  
    31  // ValidateTransportServer validates a TransportServer.
    32  func (tsv *TransportServerValidator) ValidateTransportServer(transportServer *v1alpha1.TransportServer) error {
    33  	allErrs := tsv.validateTransportServerSpec(&transportServer.Spec, field.NewPath("spec"))
    34  	return allErrs.ToAggregate()
    35  }
    36  
    37  func (tsv *TransportServerValidator) validateTransportServerSpec(spec *v1alpha1.TransportServerSpec, fieldPath *field.Path) field.ErrorList {
    38  	allErrs := field.ErrorList{}
    39  
    40  	allErrs = append(allErrs, tsv.validateTransportListener(&spec.Listener, fieldPath.Child("listener"))...)
    41  
    42  	isTLSPassthroughListener := isPotentialTLSPassthroughListener(&spec.Listener)
    43  	allErrs = append(allErrs, validateTransportServerHost(spec.Host, fieldPath.Child("host"), isTLSPassthroughListener)...)
    44  
    45  	upstreamErrs, upstreamNames := validateTransportServerUpstreams(spec.Upstreams, fieldPath.Child("upstreams"), tsv.isPlus)
    46  	allErrs = append(allErrs, upstreamErrs...)
    47  
    48  	allErrs = append(allErrs, validateTransportServerUpstreamParameters(spec.UpstreamParameters, fieldPath.Child("upstreamParameters"), spec.Listener.Protocol)...)
    49  
    50  	allErrs = append(allErrs, validateSessionParameters(spec.SessionParameters, fieldPath.Child("sessionParameters"))...)
    51  
    52  	if spec.Action == nil {
    53  		allErrs = append(allErrs, field.Required(fieldPath.Child("action"), "must specify action"))
    54  	} else {
    55  		allErrs = append(allErrs, validateTransportServerAction(spec.Action, fieldPath.Child("action"), upstreamNames)...)
    56  	}
    57  
    58  	allErrs = append(allErrs, validateSnippets(spec.ServerSnippets, fieldPath.Child("serverSnippets"), tsv.snippetsEnabled)...)
    59  
    60  	allErrs = append(allErrs, validateSnippets(spec.StreamSnippets, fieldPath.Child("streamSnippets"), tsv.snippetsEnabled)...)
    61  
    62  	return allErrs
    63  }
    64  
    65  func validateSnippets(serverSnippet string, fieldPath *field.Path, snippetsEnabled bool) field.ErrorList {
    66  	allErrs := field.ErrorList{}
    67  	if !snippetsEnabled && serverSnippet != "" {
    68  		return append(allErrs, field.Forbidden(fieldPath, "snippet specified but snippets feature is not enabled"))
    69  	}
    70  
    71  	return allErrs
    72  }
    73  
    74  func validateTransportServerHost(host string, fieldPath *field.Path, isTLSPassthroughListener bool) field.ErrorList {
    75  	allErrs := field.ErrorList{}
    76  
    77  	if !isTLSPassthroughListener {
    78  		if host != "" {
    79  			return append(allErrs, field.Forbidden(fieldPath, "host field is allowed only for TLS Passthrough TransportServers"))
    80  		}
    81  		return allErrs
    82  	}
    83  
    84  	return validateHost(host, fieldPath)
    85  }
    86  
    87  func (tsv *TransportServerValidator) validateTransportListener(listener *v1alpha1.TransportServerListener, fieldPath *field.Path) field.ErrorList {
    88  	if isPotentialTLSPassthroughListener(listener) {
    89  		return tsv.validateTLSPassthroughListener(listener, fieldPath)
    90  	}
    91  
    92  	return validateRegularListener(listener, fieldPath)
    93  }
    94  
    95  func validateRegularListener(listener *v1alpha1.TransportServerListener, fieldPath *field.Path) field.ErrorList {
    96  	allErrs := field.ErrorList{}
    97  
    98  	allErrs = append(allErrs, validateListenerName(listener.Name, fieldPath.Child("name"))...)
    99  	allErrs = append(allErrs, validateListenerProtocol(listener.Protocol, fieldPath.Child("protocol"))...)
   100  
   101  	return allErrs
   102  }
   103  
   104  func isPotentialTLSPassthroughListener(listener *v1alpha1.TransportServerListener) bool {
   105  	return listener.Name == v1alpha1.TLSPassthroughListenerName || listener.Protocol == v1alpha1.TLSPassthroughListenerProtocol
   106  }
   107  
   108  func (tsv *TransportServerValidator) validateTLSPassthroughListener(listener *v1alpha1.TransportServerListener, fieldPath *field.Path) field.ErrorList {
   109  	allErrs := field.ErrorList{}
   110  
   111  	if !tsv.tlsPassthrough {
   112  		return append(allErrs, field.Forbidden(fieldPath, "TLS Passthrough is not enabled"))
   113  	}
   114  
   115  	if listener.Name == v1alpha1.TLSPassthroughListenerName && listener.Protocol != v1alpha1.TLSPassthroughListenerProtocol {
   116  		msg := fmt.Sprintf("must be '%s' for the built-in %s listener", v1alpha1.TLSPassthroughListenerProtocol, v1alpha1.TLSPassthroughListenerName)
   117  		return append(allErrs, field.Invalid(fieldPath.Child("protocol"), listener.Protocol, msg))
   118  	}
   119  
   120  	if listener.Protocol == v1alpha1.TLSPassthroughListenerProtocol && listener.Name != v1alpha1.TLSPassthroughListenerName {
   121  		msg := fmt.Sprintf("must be '%s' for a listener with the protocol %s", v1alpha1.TLSPassthroughListenerName, v1alpha1.TLSPassthroughListenerProtocol)
   122  		return append(allErrs, field.Invalid(fieldPath.Child("name"), listener.Name, msg))
   123  	}
   124  
   125  	return allErrs
   126  }
   127  
   128  func validateListenerName(name string, fieldPath *field.Path) field.ErrorList {
   129  	return validateDNS1035Label(name, fieldPath)
   130  }
   131  
   132  // listenerProtocols defines the protocols supported by a listener.
   133  var listenerProtocols = map[string]bool{
   134  	"TCP": true,
   135  	"UDP": true,
   136  }
   137  
   138  func validateListenerProtocol(protocol string, fieldPath *field.Path) field.ErrorList {
   139  	allErrs := field.ErrorList{}
   140  
   141  	if protocol == "" {
   142  		msg := fmt.Sprintf("must specify protocol. Accepted values: %s", mapToPrettyString(listenerProtocols))
   143  		return append(allErrs, field.Required(fieldPath, msg))
   144  	}
   145  
   146  	if !listenerProtocols[protocol] {
   147  		msg := fmt.Sprintf("invalid protocol. Accepted values: %s", mapToPrettyString(listenerProtocols))
   148  		allErrs = append(allErrs, field.Invalid(fieldPath, protocol, msg))
   149  	}
   150  
   151  	return allErrs
   152  }
   153  
   154  func validateTransportServerUpstreams(upstreams []v1alpha1.Upstream, fieldPath *field.Path, isPlus bool) (allErrs field.ErrorList, upstreamNames sets.String) {
   155  	allErrs = field.ErrorList{}
   156  	upstreamNames = sets.String{}
   157  
   158  	for i, u := range upstreams {
   159  		idxPath := fieldPath.Index(i)
   160  
   161  		upstreamErrors := validateUpstreamName(u.Name, idxPath.Child("name"))
   162  		if len(upstreamErrors) > 0 {
   163  			allErrs = append(allErrs, upstreamErrors...)
   164  		} else if upstreamNames.Has(u.Name) {
   165  			allErrs = append(allErrs, field.Duplicate(idxPath.Child("name"), u.Name))
   166  		} else {
   167  			upstreamNames.Insert(u.Name)
   168  		}
   169  
   170  		allErrs = append(allErrs, validateServiceName(u.Service, idxPath.Child("service"))...)
   171  		allErrs = append(allErrs, validatePositiveIntOrZeroFromPointer(u.MaxFails, idxPath.Child("maxFails"))...)
   172  		allErrs = append(allErrs, validatePositiveIntOrZeroFromPointer(u.MaxFails, idxPath.Child("maxConns"))...)
   173  		allErrs = append(allErrs, validateTime(u.FailTimeout, idxPath.Child("failTimeout"))...)
   174  
   175  		for _, msg := range validation.IsValidPortNum(u.Port) {
   176  			allErrs = append(allErrs, field.Invalid(idxPath.Child("port"), u.Port, msg))
   177  		}
   178  
   179  		allErrs = append(allErrs, validateTSUpstreamHealthChecks(u.HealthCheck, idxPath.Child("healthChecks"))...)
   180  
   181  		allErrs = append(allErrs, validateLoadBalancingMethod(u.LoadBalancingMethod, idxPath.Child("loadBalancingMethod"), isPlus)...)
   182  	}
   183  
   184  	return allErrs, upstreamNames
   185  }
   186  
   187  func validateLoadBalancingMethod(method string, fieldPath *field.Path, isPlus bool) field.ErrorList {
   188  	allErrs := field.ErrorList{}
   189  	if method == "" {
   190  		return allErrs
   191  	}
   192  
   193  	method = strings.TrimSpace(method)
   194  
   195  	if strings.HasPrefix(method, "hash") {
   196  		return validateHashLoadBalancingMethod(method, fieldPath, isPlus)
   197  	}
   198  
   199  	validMethodValues := nginxStreamLoadBalanceValidInput
   200  	if isPlus {
   201  		validMethodValues = nginxPlusStreamLoadBalanceValidInput
   202  	}
   203  
   204  	if _, exists := validMethodValues[method]; !exists {
   205  		return append(allErrs, field.Invalid(fieldPath, method, fmt.Sprintf("load balancing method is not valid: %v", method)))
   206  	}
   207  
   208  	return allErrs
   209  }
   210  
   211  var nginxStreamLoadBalanceValidInput = map[string]bool{
   212  	"round_robin":           true,
   213  	"least_conn":            true,
   214  	"random":                true,
   215  	"random two":            true,
   216  	"random two least_conn": true,
   217  }
   218  
   219  var nginxPlusStreamLoadBalanceValidInput = map[string]bool{
   220  	"round_robin":                   true,
   221  	"least_conn":                    true,
   222  	"random":                        true,
   223  	"random two":                    true,
   224  	"random two least_conn":         true,
   225  	"random least_conn":             true,
   226  	"least_time connect":            true,
   227  	"least_time first_byte":         true,
   228  	"least_time last_byte":          true,
   229  	"least_time last_byte inflight": true,
   230  }
   231  
   232  var loadBalancingVariables = map[string]bool{
   233  	"remote_addr": true,
   234  }
   235  
   236  var hashMethodRegexp = regexp.MustCompile(`^hash (\S+)(?: consistent)?$`)
   237  
   238  func validateHashLoadBalancingMethod(method string, fieldPath *field.Path, isPlus bool) field.ErrorList {
   239  	allErrs := field.ErrorList{}
   240  	matches := hashMethodRegexp.FindStringSubmatch(method)
   241  	if len(matches) != 2 {
   242  		msg := fmt.Sprintf("invalid value for load balancing method: %v", method)
   243  		return append(allErrs, field.Invalid(fieldPath, method, msg))
   244  	}
   245  
   246  	hashKey := matches[1]
   247  	if strings.Contains(hashKey, "$") {
   248  		varErrs := validateStringWithVariables(hashKey, fieldPath, []string{}, loadBalancingVariables, isPlus)
   249  		allErrs = append(allErrs, varErrs...)
   250  	}
   251  
   252  	if !escapedStringsFmtRegexp.MatchString(method) {
   253  		msg := fmt.Sprintf("invalid value for hash: %v", validation.RegexError(escapedStringsErrMsg, escapedStringsFmt))
   254  		return append(allErrs, field.Invalid(fieldPath, method, msg))
   255  	}
   256  
   257  	return allErrs
   258  }
   259  
   260  func validateTSUpstreamHealthChecks(hc *v1alpha1.HealthCheck, fieldPath *field.Path) field.ErrorList {
   261  	allErrs := field.ErrorList{}
   262  
   263  	if hc == nil {
   264  		return allErrs
   265  	}
   266  
   267  	allErrs = append(allErrs, validateTime(hc.Timeout, fieldPath.Child("timeout"))...)
   268  	allErrs = append(allErrs, validateTime(hc.Interval, fieldPath.Child("interval"))...)
   269  	allErrs = append(allErrs, validateTime(hc.Jitter, fieldPath.Child("jitter"))...)
   270  	allErrs = append(allErrs, validatePositiveIntOrZero(hc.Fails, fieldPath.Child("fails"))...)
   271  	allErrs = append(allErrs, validatePositiveIntOrZero(hc.Passes, fieldPath.Child("passes"))...)
   272  
   273  	if hc.Port > 0 {
   274  		for _, msg := range validation.IsValidPortNum(hc.Port) {
   275  			allErrs = append(allErrs, field.Invalid(fieldPath.Child("port"), hc.Port, msg))
   276  		}
   277  	}
   278  
   279  	allErrs = append(allErrs, validateHealthCheckMatch(hc.Match, fieldPath.Child("match"))...)
   280  
   281  	return allErrs
   282  }
   283  
   284  func validateHealthCheckMatch(match *v1alpha1.Match, fieldPath *field.Path) field.ErrorList {
   285  	allErrs := field.ErrorList{}
   286  	if match == nil {
   287  		return allErrs
   288  	}
   289  	allErrs = append(allErrs, validateMatchExpect(match.Expect, fieldPath.Child("expect"))...)
   290  	allErrs = append(allErrs, validateMatchSend(match.Expect, fieldPath.Child("send"))...)
   291  	return allErrs
   292  }
   293  
   294  func validateMatchExpect(expect string, fieldPath *field.Path) field.ErrorList {
   295  	allErrs := field.ErrorList{}
   296  	if expect == "" {
   297  		return allErrs
   298  	}
   299  
   300  	if !escapedStringsFmtRegexp.MatchString(expect) {
   301  		msg := validation.RegexError(escapedStringsErrMsg, escapedStringsFmt)
   302  		return append(allErrs, field.Invalid(fieldPath, expect, msg))
   303  	}
   304  
   305  	if strings.HasPrefix(expect, "~") {
   306  		var expr string
   307  		if strings.HasPrefix(expect, "~*") {
   308  			expr = strings.TrimPrefix(expect, "~*")
   309  		} else {
   310  			expr = strings.TrimPrefix(expect, "~")
   311  		}
   312  
   313  		// compile also validates hex literals
   314  		if _, err := regexp.Compile(expr); err != nil {
   315  			return append(allErrs, field.Invalid(fieldPath, expr, fmt.Sprintf("must be a valid regular expression: %v", err)))
   316  		}
   317  	} else {
   318  		if err := validateHexString(expect); err != nil {
   319  			return append(allErrs, field.Invalid(fieldPath, expect, err.Error()))
   320  		}
   321  	}
   322  
   323  	return allErrs
   324  }
   325  
   326  func validateMatchSend(send string, fieldPath *field.Path) field.ErrorList {
   327  	allErrs := field.ErrorList{}
   328  	if send == "" {
   329  		return allErrs
   330  	}
   331  	if !escapedStringsFmtRegexp.MatchString(send) {
   332  		msg := validation.RegexError(escapedStringsErrMsg, escapedStringsFmt)
   333  		return append(allErrs, field.Invalid(fieldPath, send, msg))
   334  	}
   335  
   336  	err := validateHexString(send)
   337  	if err != nil {
   338  		return append(allErrs, field.Invalid(fieldPath, send, err.Error()))
   339  	}
   340  
   341  	return allErrs
   342  }
   343  
   344  var hexLiteralRegexp = regexp.MustCompile(`\\x(.{0,2})`)
   345  
   346  func validateHexString(s string) error {
   347  	literals := hexLiteralRegexp.FindAllStringSubmatch(s, -1)
   348  	for _, match := range literals {
   349  		lit := match[0]
   350  		digits := match[1]
   351  
   352  		if len(digits) != 2 {
   353  			return fmt.Errorf("hex literal '%s' must contain two hex digits", lit)
   354  		}
   355  
   356  		_, err := hex.DecodeString(digits)
   357  		if err != nil {
   358  			return fmt.Errorf("hex literal '%s' must contain two hex digits: %w", lit, err)
   359  		}
   360  	}
   361  
   362  	return nil
   363  }
   364  
   365  func validateTransportServerUpstreamParameters(upstreamParameters *v1alpha1.UpstreamParameters, fieldPath *field.Path, protocol string) field.ErrorList {
   366  	allErrs := field.ErrorList{}
   367  
   368  	if upstreamParameters == nil {
   369  		return allErrs
   370  	}
   371  
   372  	allErrs = append(allErrs, validateUDPUpstreamParameter(upstreamParameters.UDPRequests, fieldPath.Child("udpRequests"), protocol)...)
   373  	allErrs = append(allErrs, validateUDPUpstreamParameter(upstreamParameters.UDPResponses, fieldPath.Child("udpResponses"), protocol)...)
   374  	allErrs = append(allErrs, validateTime(upstreamParameters.ConnectTimeout, fieldPath.Child("connectTimeout"))...)
   375  	allErrs = append(allErrs, validateTime(upstreamParameters.NextUpstreamTimeout, fieldPath.Child("nextUpstreamTimeout"))...)
   376  	allErrs = append(allErrs, validatePositiveIntOrZero(upstreamParameters.NextUpstreamTries, fieldPath.Child("nextUpstreamTries"))...)
   377  
   378  	return allErrs
   379  }
   380  
   381  func validateSessionParameters(sessionParameters *v1alpha1.SessionParameters, fieldPath *field.Path) field.ErrorList {
   382  	allErrs := field.ErrorList{}
   383  
   384  	if sessionParameters == nil {
   385  		return allErrs
   386  	}
   387  
   388  	allErrs = append(allErrs, validateTime(sessionParameters.Timeout, fieldPath.Child("timeout"))...)
   389  
   390  	return allErrs
   391  }
   392  
   393  func validateUDPUpstreamParameter(parameter *int, fieldPath *field.Path, protocol string) field.ErrorList {
   394  	allErrs := field.ErrorList{}
   395  
   396  	if parameter != nil && protocol != "UDP" {
   397  		return append(allErrs, field.Forbidden(fieldPath, "is not allowed for non-UDP TransportServers"))
   398  	}
   399  
   400  	return validatePositiveIntOrZeroFromPointer(parameter, fieldPath)
   401  }
   402  
   403  func validateTransportServerAction(action *v1alpha1.Action, fieldPath *field.Path, upstreamNames sets.String) field.ErrorList {
   404  	allErrs := field.ErrorList{}
   405  
   406  	if action.Pass == "" {
   407  		return append(allErrs, field.Required(fieldPath, "must specify pass"))
   408  	}
   409  
   410  	return validateReferencedUpstream(action.Pass, fieldPath.Child("pass"), upstreamNames)
   411  }