github.com/nginxinc/kubernetes-ingress@v1.12.5/pkg/apis/configuration/validation/transportserver.go (about) 1 package validation 2 3 import ( 4 "encoding/hex" 5 "fmt" 6 "regexp" 7 "strings" 8 9 "github.com/nginxinc/kubernetes-ingress/pkg/apis/configuration/v1alpha1" 10 "k8s.io/apimachinery/pkg/util/sets" 11 "k8s.io/apimachinery/pkg/util/validation" 12 "k8s.io/apimachinery/pkg/util/validation/field" 13 ) 14 15 // TransportServerValidator validates a TransportServer resource. 16 type TransportServerValidator struct { 17 tlsPassthrough bool 18 snippetsEnabled bool 19 isPlus bool 20 } 21 22 // NewTransportServerValidator creates a new TransportServerValidator. 23 func NewTransportServerValidator(tlsPassthrough bool, snippetsEnabled bool, isPlus bool) *TransportServerValidator { 24 return &TransportServerValidator{ 25 tlsPassthrough: tlsPassthrough, 26 snippetsEnabled: snippetsEnabled, 27 isPlus: isPlus, 28 } 29 } 30 31 // ValidateTransportServer validates a TransportServer. 32 func (tsv *TransportServerValidator) ValidateTransportServer(transportServer *v1alpha1.TransportServer) error { 33 allErrs := tsv.validateTransportServerSpec(&transportServer.Spec, field.NewPath("spec")) 34 return allErrs.ToAggregate() 35 } 36 37 func (tsv *TransportServerValidator) validateTransportServerSpec(spec *v1alpha1.TransportServerSpec, fieldPath *field.Path) field.ErrorList { 38 allErrs := field.ErrorList{} 39 40 allErrs = append(allErrs, tsv.validateTransportListener(&spec.Listener, fieldPath.Child("listener"))...) 41 42 isTLSPassthroughListener := isPotentialTLSPassthroughListener(&spec.Listener) 43 allErrs = append(allErrs, validateTransportServerHost(spec.Host, fieldPath.Child("host"), isTLSPassthroughListener)...) 44 45 upstreamErrs, upstreamNames := validateTransportServerUpstreams(spec.Upstreams, fieldPath.Child("upstreams"), tsv.isPlus) 46 allErrs = append(allErrs, upstreamErrs...) 47 48 allErrs = append(allErrs, validateTransportServerUpstreamParameters(spec.UpstreamParameters, fieldPath.Child("upstreamParameters"), spec.Listener.Protocol)...) 49 50 allErrs = append(allErrs, validateSessionParameters(spec.SessionParameters, fieldPath.Child("sessionParameters"))...) 51 52 if spec.Action == nil { 53 allErrs = append(allErrs, field.Required(fieldPath.Child("action"), "must specify action")) 54 } else { 55 allErrs = append(allErrs, validateTransportServerAction(spec.Action, fieldPath.Child("action"), upstreamNames)...) 56 } 57 58 allErrs = append(allErrs, validateSnippets(spec.ServerSnippets, fieldPath.Child("serverSnippets"), tsv.snippetsEnabled)...) 59 60 allErrs = append(allErrs, validateSnippets(spec.StreamSnippets, fieldPath.Child("streamSnippets"), tsv.snippetsEnabled)...) 61 62 return allErrs 63 } 64 65 func validateSnippets(serverSnippet string, fieldPath *field.Path, snippetsEnabled bool) field.ErrorList { 66 allErrs := field.ErrorList{} 67 if !snippetsEnabled && serverSnippet != "" { 68 return append(allErrs, field.Forbidden(fieldPath, "snippet specified but snippets feature is not enabled")) 69 } 70 71 return allErrs 72 } 73 74 func validateTransportServerHost(host string, fieldPath *field.Path, isTLSPassthroughListener bool) field.ErrorList { 75 allErrs := field.ErrorList{} 76 77 if !isTLSPassthroughListener { 78 if host != "" { 79 return append(allErrs, field.Forbidden(fieldPath, "host field is allowed only for TLS Passthrough TransportServers")) 80 } 81 return allErrs 82 } 83 84 return validateHost(host, fieldPath) 85 } 86 87 func (tsv *TransportServerValidator) validateTransportListener(listener *v1alpha1.TransportServerListener, fieldPath *field.Path) field.ErrorList { 88 if isPotentialTLSPassthroughListener(listener) { 89 return tsv.validateTLSPassthroughListener(listener, fieldPath) 90 } 91 92 return validateRegularListener(listener, fieldPath) 93 } 94 95 func validateRegularListener(listener *v1alpha1.TransportServerListener, fieldPath *field.Path) field.ErrorList { 96 allErrs := field.ErrorList{} 97 98 allErrs = append(allErrs, validateListenerName(listener.Name, fieldPath.Child("name"))...) 99 allErrs = append(allErrs, validateListenerProtocol(listener.Protocol, fieldPath.Child("protocol"))...) 100 101 return allErrs 102 } 103 104 func isPotentialTLSPassthroughListener(listener *v1alpha1.TransportServerListener) bool { 105 return listener.Name == v1alpha1.TLSPassthroughListenerName || listener.Protocol == v1alpha1.TLSPassthroughListenerProtocol 106 } 107 108 func (tsv *TransportServerValidator) validateTLSPassthroughListener(listener *v1alpha1.TransportServerListener, fieldPath *field.Path) field.ErrorList { 109 allErrs := field.ErrorList{} 110 111 if !tsv.tlsPassthrough { 112 return append(allErrs, field.Forbidden(fieldPath, "TLS Passthrough is not enabled")) 113 } 114 115 if listener.Name == v1alpha1.TLSPassthroughListenerName && listener.Protocol != v1alpha1.TLSPassthroughListenerProtocol { 116 msg := fmt.Sprintf("must be '%s' for the built-in %s listener", v1alpha1.TLSPassthroughListenerProtocol, v1alpha1.TLSPassthroughListenerName) 117 return append(allErrs, field.Invalid(fieldPath.Child("protocol"), listener.Protocol, msg)) 118 } 119 120 if listener.Protocol == v1alpha1.TLSPassthroughListenerProtocol && listener.Name != v1alpha1.TLSPassthroughListenerName { 121 msg := fmt.Sprintf("must be '%s' for a listener with the protocol %s", v1alpha1.TLSPassthroughListenerName, v1alpha1.TLSPassthroughListenerProtocol) 122 return append(allErrs, field.Invalid(fieldPath.Child("name"), listener.Name, msg)) 123 } 124 125 return allErrs 126 } 127 128 func validateListenerName(name string, fieldPath *field.Path) field.ErrorList { 129 return validateDNS1035Label(name, fieldPath) 130 } 131 132 // listenerProtocols defines the protocols supported by a listener. 133 var listenerProtocols = map[string]bool{ 134 "TCP": true, 135 "UDP": true, 136 } 137 138 func validateListenerProtocol(protocol string, fieldPath *field.Path) field.ErrorList { 139 allErrs := field.ErrorList{} 140 141 if protocol == "" { 142 msg := fmt.Sprintf("must specify protocol. Accepted values: %s", mapToPrettyString(listenerProtocols)) 143 return append(allErrs, field.Required(fieldPath, msg)) 144 } 145 146 if !listenerProtocols[protocol] { 147 msg := fmt.Sprintf("invalid protocol. Accepted values: %s", mapToPrettyString(listenerProtocols)) 148 allErrs = append(allErrs, field.Invalid(fieldPath, protocol, msg)) 149 } 150 151 return allErrs 152 } 153 154 func validateTransportServerUpstreams(upstreams []v1alpha1.Upstream, fieldPath *field.Path, isPlus bool) (allErrs field.ErrorList, upstreamNames sets.String) { 155 allErrs = field.ErrorList{} 156 upstreamNames = sets.String{} 157 158 for i, u := range upstreams { 159 idxPath := fieldPath.Index(i) 160 161 upstreamErrors := validateUpstreamName(u.Name, idxPath.Child("name")) 162 if len(upstreamErrors) > 0 { 163 allErrs = append(allErrs, upstreamErrors...) 164 } else if upstreamNames.Has(u.Name) { 165 allErrs = append(allErrs, field.Duplicate(idxPath.Child("name"), u.Name)) 166 } else { 167 upstreamNames.Insert(u.Name) 168 } 169 170 allErrs = append(allErrs, validateServiceName(u.Service, idxPath.Child("service"))...) 171 allErrs = append(allErrs, validatePositiveIntOrZeroFromPointer(u.MaxFails, idxPath.Child("maxFails"))...) 172 allErrs = append(allErrs, validatePositiveIntOrZeroFromPointer(u.MaxFails, idxPath.Child("maxConns"))...) 173 allErrs = append(allErrs, validateTime(u.FailTimeout, idxPath.Child("failTimeout"))...) 174 175 for _, msg := range validation.IsValidPortNum(u.Port) { 176 allErrs = append(allErrs, field.Invalid(idxPath.Child("port"), u.Port, msg)) 177 } 178 179 allErrs = append(allErrs, validateTSUpstreamHealthChecks(u.HealthCheck, idxPath.Child("healthChecks"))...) 180 181 allErrs = append(allErrs, validateLoadBalancingMethod(u.LoadBalancingMethod, idxPath.Child("loadBalancingMethod"), isPlus)...) 182 } 183 184 return allErrs, upstreamNames 185 } 186 187 func validateLoadBalancingMethod(method string, fieldPath *field.Path, isPlus bool) field.ErrorList { 188 allErrs := field.ErrorList{} 189 if method == "" { 190 return allErrs 191 } 192 193 method = strings.TrimSpace(method) 194 195 if strings.HasPrefix(method, "hash") { 196 return validateHashLoadBalancingMethod(method, fieldPath, isPlus) 197 } 198 199 validMethodValues := nginxStreamLoadBalanceValidInput 200 if isPlus { 201 validMethodValues = nginxPlusStreamLoadBalanceValidInput 202 } 203 204 if _, exists := validMethodValues[method]; !exists { 205 return append(allErrs, field.Invalid(fieldPath, method, fmt.Sprintf("load balancing method is not valid: %v", method))) 206 } 207 208 return allErrs 209 } 210 211 var nginxStreamLoadBalanceValidInput = map[string]bool{ 212 "round_robin": true, 213 "least_conn": true, 214 "random": true, 215 "random two": true, 216 "random two least_conn": true, 217 } 218 219 var nginxPlusStreamLoadBalanceValidInput = map[string]bool{ 220 "round_robin": true, 221 "least_conn": true, 222 "random": true, 223 "random two": true, 224 "random two least_conn": true, 225 "random least_conn": true, 226 "least_time connect": true, 227 "least_time first_byte": true, 228 "least_time last_byte": true, 229 "least_time last_byte inflight": true, 230 } 231 232 var loadBalancingVariables = map[string]bool{ 233 "remote_addr": true, 234 } 235 236 var hashMethodRegexp = regexp.MustCompile(`^hash (\S+)(?: consistent)?$`) 237 238 func validateHashLoadBalancingMethod(method string, fieldPath *field.Path, isPlus bool) field.ErrorList { 239 allErrs := field.ErrorList{} 240 matches := hashMethodRegexp.FindStringSubmatch(method) 241 if len(matches) != 2 { 242 msg := fmt.Sprintf("invalid value for load balancing method: %v", method) 243 return append(allErrs, field.Invalid(fieldPath, method, msg)) 244 } 245 246 hashKey := matches[1] 247 if strings.Contains(hashKey, "$") { 248 varErrs := validateStringWithVariables(hashKey, fieldPath, []string{}, loadBalancingVariables, isPlus) 249 allErrs = append(allErrs, varErrs...) 250 } 251 252 if !escapedStringsFmtRegexp.MatchString(method) { 253 msg := fmt.Sprintf("invalid value for hash: %v", validation.RegexError(escapedStringsErrMsg, escapedStringsFmt)) 254 return append(allErrs, field.Invalid(fieldPath, method, msg)) 255 } 256 257 return allErrs 258 } 259 260 func validateTSUpstreamHealthChecks(hc *v1alpha1.HealthCheck, fieldPath *field.Path) field.ErrorList { 261 allErrs := field.ErrorList{} 262 263 if hc == nil { 264 return allErrs 265 } 266 267 allErrs = append(allErrs, validateTime(hc.Timeout, fieldPath.Child("timeout"))...) 268 allErrs = append(allErrs, validateTime(hc.Interval, fieldPath.Child("interval"))...) 269 allErrs = append(allErrs, validateTime(hc.Jitter, fieldPath.Child("jitter"))...) 270 allErrs = append(allErrs, validatePositiveIntOrZero(hc.Fails, fieldPath.Child("fails"))...) 271 allErrs = append(allErrs, validatePositiveIntOrZero(hc.Passes, fieldPath.Child("passes"))...) 272 273 if hc.Port > 0 { 274 for _, msg := range validation.IsValidPortNum(hc.Port) { 275 allErrs = append(allErrs, field.Invalid(fieldPath.Child("port"), hc.Port, msg)) 276 } 277 } 278 279 allErrs = append(allErrs, validateHealthCheckMatch(hc.Match, fieldPath.Child("match"))...) 280 281 return allErrs 282 } 283 284 func validateHealthCheckMatch(match *v1alpha1.Match, fieldPath *field.Path) field.ErrorList { 285 allErrs := field.ErrorList{} 286 if match == nil { 287 return allErrs 288 } 289 allErrs = append(allErrs, validateMatchExpect(match.Expect, fieldPath.Child("expect"))...) 290 allErrs = append(allErrs, validateMatchSend(match.Expect, fieldPath.Child("send"))...) 291 return allErrs 292 } 293 294 func validateMatchExpect(expect string, fieldPath *field.Path) field.ErrorList { 295 allErrs := field.ErrorList{} 296 if expect == "" { 297 return allErrs 298 } 299 300 if !escapedStringsFmtRegexp.MatchString(expect) { 301 msg := validation.RegexError(escapedStringsErrMsg, escapedStringsFmt) 302 return append(allErrs, field.Invalid(fieldPath, expect, msg)) 303 } 304 305 if strings.HasPrefix(expect, "~") { 306 var expr string 307 if strings.HasPrefix(expect, "~*") { 308 expr = strings.TrimPrefix(expect, "~*") 309 } else { 310 expr = strings.TrimPrefix(expect, "~") 311 } 312 313 // compile also validates hex literals 314 if _, err := regexp.Compile(expr); err != nil { 315 return append(allErrs, field.Invalid(fieldPath, expr, fmt.Sprintf("must be a valid regular expression: %v", err))) 316 } 317 } else { 318 if err := validateHexString(expect); err != nil { 319 return append(allErrs, field.Invalid(fieldPath, expect, err.Error())) 320 } 321 } 322 323 return allErrs 324 } 325 326 func validateMatchSend(send string, fieldPath *field.Path) field.ErrorList { 327 allErrs := field.ErrorList{} 328 if send == "" { 329 return allErrs 330 } 331 if !escapedStringsFmtRegexp.MatchString(send) { 332 msg := validation.RegexError(escapedStringsErrMsg, escapedStringsFmt) 333 return append(allErrs, field.Invalid(fieldPath, send, msg)) 334 } 335 336 err := validateHexString(send) 337 if err != nil { 338 return append(allErrs, field.Invalid(fieldPath, send, err.Error())) 339 } 340 341 return allErrs 342 } 343 344 var hexLiteralRegexp = regexp.MustCompile(`\\x(.{0,2})`) 345 346 func validateHexString(s string) error { 347 literals := hexLiteralRegexp.FindAllStringSubmatch(s, -1) 348 for _, match := range literals { 349 lit := match[0] 350 digits := match[1] 351 352 if len(digits) != 2 { 353 return fmt.Errorf("hex literal '%s' must contain two hex digits", lit) 354 } 355 356 _, err := hex.DecodeString(digits) 357 if err != nil { 358 return fmt.Errorf("hex literal '%s' must contain two hex digits: %w", lit, err) 359 } 360 } 361 362 return nil 363 } 364 365 func validateTransportServerUpstreamParameters(upstreamParameters *v1alpha1.UpstreamParameters, fieldPath *field.Path, protocol string) field.ErrorList { 366 allErrs := field.ErrorList{} 367 368 if upstreamParameters == nil { 369 return allErrs 370 } 371 372 allErrs = append(allErrs, validateUDPUpstreamParameter(upstreamParameters.UDPRequests, fieldPath.Child("udpRequests"), protocol)...) 373 allErrs = append(allErrs, validateUDPUpstreamParameter(upstreamParameters.UDPResponses, fieldPath.Child("udpResponses"), protocol)...) 374 allErrs = append(allErrs, validateTime(upstreamParameters.ConnectTimeout, fieldPath.Child("connectTimeout"))...) 375 allErrs = append(allErrs, validateTime(upstreamParameters.NextUpstreamTimeout, fieldPath.Child("nextUpstreamTimeout"))...) 376 allErrs = append(allErrs, validatePositiveIntOrZero(upstreamParameters.NextUpstreamTries, fieldPath.Child("nextUpstreamTries"))...) 377 378 return allErrs 379 } 380 381 func validateSessionParameters(sessionParameters *v1alpha1.SessionParameters, fieldPath *field.Path) field.ErrorList { 382 allErrs := field.ErrorList{} 383 384 if sessionParameters == nil { 385 return allErrs 386 } 387 388 allErrs = append(allErrs, validateTime(sessionParameters.Timeout, fieldPath.Child("timeout"))...) 389 390 return allErrs 391 } 392 393 func validateUDPUpstreamParameter(parameter *int, fieldPath *field.Path, protocol string) field.ErrorList { 394 allErrs := field.ErrorList{} 395 396 if parameter != nil && protocol != "UDP" { 397 return append(allErrs, field.Forbidden(fieldPath, "is not allowed for non-UDP TransportServers")) 398 } 399 400 return validatePositiveIntOrZeroFromPointer(parameter, fieldPath) 401 } 402 403 func validateTransportServerAction(action *v1alpha1.Action, fieldPath *field.Path, upstreamNames sets.String) field.ErrorList { 404 allErrs := field.ErrorList{} 405 406 if action.Pass == "" { 407 return append(allErrs, field.Required(fieldPath, "must specify pass")) 408 } 409 410 return validateReferencedUpstream(action.Pass, fieldPath.Child("pass"), upstreamNames) 411 }