istio.io/istio@v0.0.0-20240520182934-d79c90f27776/pkg/test/framework/components/echo/check/checkers.go (about)

     1  //  Copyright Istio Authors
     2  //
     3  //  Licensed under the Apache License, Version 2.0 (the "License");
     4  //  you may not use this file except in compliance with the License.
     5  //  You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  //  Unless required by applicable law or agreed to in writing, software
    10  //  distributed under the License is distributed on an "AS IS" BASIS,
    11  //  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  //  See the License for the specific language governing permissions and
    13  //  limitations under the License.
    14  
    15  package check
    16  
    17  import (
    18  	"errors"
    19  	"fmt"
    20  	"net/http"
    21  	"strconv"
    22  	"strings"
    23  
    24  	"github.com/hashicorp/go-multierror"
    25  	"google.golang.org/grpc/codes"
    26  
    27  	"istio.io/istio/pkg/config/protocol"
    28  	echoClient "istio.io/istio/pkg/test/echo"
    29  	"istio.io/istio/pkg/test/framework"
    30  	"istio.io/istio/pkg/test/framework/components/cluster"
    31  	"istio.io/istio/pkg/test/framework/components/echo"
    32  	"istio.io/istio/pkg/test/framework/components/istio"
    33  	"istio.io/istio/pkg/test/framework/components/istio/ingress"
    34  	"istio.io/istio/pkg/util/istiomultierror"
    35  )
    36  
    37  // Each applies the given per-response function across all responses.
    38  func Each(v Visitor) echo.Checker {
    39  	return v.Checker()
    40  }
    41  
    42  // And is an aggregate Checker that requires all Checkers succeed. Any nil Checkers are ignored.
    43  func And(checkers ...echo.Checker) echo.Checker {
    44  	return func(result echo.CallResult, err error) error {
    45  		for _, c := range filterNil(checkers) {
    46  			if err := c(result, err); err != nil {
    47  				return err
    48  			}
    49  		}
    50  		return nil
    51  	}
    52  }
    53  
    54  // Or is an aggregate Checker that requires at least one Checker succeeds.
    55  func Or(checkers ...echo.Checker) echo.Checker {
    56  	return func(result echo.CallResult, err error) error {
    57  		out := istiomultierror.New()
    58  		for _, c := range checkers {
    59  			err := c(result, err)
    60  			if err == nil {
    61  				return nil
    62  			}
    63  			out = multierror.Append(out, err)
    64  		}
    65  		return out.ErrorOrNil()
    66  	}
    67  }
    68  
    69  func filterNil(checkers []echo.Checker) []echo.Checker {
    70  	var out []echo.Checker
    71  	for _, c := range checkers {
    72  		if c != nil {
    73  			out = append(out, c)
    74  		}
    75  	}
    76  	return out
    77  }
    78  
    79  // NoError is similar to echo.NoChecker, but provides additional context information.
    80  func NoError() echo.Checker {
    81  	return func(_ echo.CallResult, err error) error {
    82  		if err != nil {
    83  			return fmt.Errorf("expected no error, but encountered %v", err)
    84  		}
    85  		return nil
    86  	}
    87  }
    88  
    89  // Error provides a checker that returns an error if the call succeeds.
    90  func Error() echo.Checker {
    91  	return func(_ echo.CallResult, err error) error {
    92  		if err == nil {
    93  			return errors.New("expected error, but none occurred")
    94  		}
    95  		return nil
    96  	}
    97  }
    98  
    99  // ErrorContains is similar to Error, but checks that the error message contains the given string.
   100  func ErrorContains(expected string) echo.Checker {
   101  	return func(_ echo.CallResult, err error) error {
   102  		if err == nil {
   103  			return errors.New("expected error, but none occurred")
   104  		}
   105  		if !strings.Contains(err.Error(), expected) {
   106  			return fmt.Errorf("expected error to contain %s: %v", expected, err)
   107  		}
   108  		return nil
   109  	}
   110  }
   111  
   112  func ErrorOrStatus(expected int) echo.Checker {
   113  	return Or(Error(), Status(expected))
   114  }
   115  
   116  func ErrorOrNotStatus(expected int) echo.Checker {
   117  	return Or(Error(), NotStatus(expected))
   118  }
   119  
   120  // OK is shorthand for NoErrorAndStatus(200).
   121  func OK() echo.Checker {
   122  	return NoErrorAndStatus(http.StatusOK)
   123  }
   124  
   125  // NotOK is shorthand for ErrorOrNotStatus(http.StatusOK).
   126  func NotOK() echo.Checker {
   127  	return ErrorOrNotStatus(http.StatusOK)
   128  }
   129  
   130  // NoErrorAndStatus is checks that no error occurred and that the returned status code matches the expected
   131  // value.
   132  func NoErrorAndStatus(expected int) echo.Checker {
   133  	return And(NoError(), Status(expected))
   134  }
   135  
   136  // Status checks that the response status code matches the expected value. If the expected value is zero,
   137  // checks that the response code is unset.
   138  func Status(expected int) echo.Checker {
   139  	return Each(VStatus(expected))
   140  }
   141  
   142  // NotStatus checks that the response status code does not match the expected value.
   143  func NotStatus(expected int) echo.Checker {
   144  	return Each(VNotStatus(expected))
   145  }
   146  
   147  // VStatus is a Visitor-based version of Status.
   148  func VStatus(expected int) Visitor {
   149  	expectedStr := ""
   150  	if expected > 0 {
   151  		expectedStr = strconv.Itoa(expected)
   152  	}
   153  	return func(r echoClient.Response) error {
   154  		if r.Code != expectedStr {
   155  			return fmt.Errorf("expected response code `%s`, got %q. Response: %s", expectedStr, r.Code, r)
   156  		}
   157  		return nil
   158  	}
   159  }
   160  
   161  // VNotStatus is a Visitor-based version of NotStatus.
   162  func VNotStatus(notExpected int) Visitor {
   163  	notExpectedStr := ""
   164  	if notExpected > 0 {
   165  		notExpectedStr = strconv.Itoa(notExpected)
   166  	}
   167  	return func(r echoClient.Response) error {
   168  		if r.Code == notExpectedStr {
   169  			return fmt.Errorf("received unexpected response code `%s`. Response: %s", notExpectedStr, r)
   170  		}
   171  		return nil
   172  	}
   173  }
   174  
   175  // GRPCStatus checks that the gRPC response status code matches the expected value.
   176  func GRPCStatus(expected codes.Code) echo.Checker {
   177  	return func(result echo.CallResult, err error) error {
   178  		if expected == codes.OK {
   179  			if err != nil {
   180  				return fmt.Errorf("unexpected error: %w", err)
   181  			}
   182  			return nil
   183  		}
   184  		if err == nil {
   185  			return fmt.Errorf("expected gRPC error with status %s, but got OK", expected.String())
   186  		}
   187  		expectedSubstr := fmt.Sprintf("code = %s", expected.String())
   188  		if strings.Contains(err.Error(), expectedSubstr) {
   189  			return nil
   190  		}
   191  		return fmt.Errorf("expected gRPC response code %q. Instead got: %w", expected.String(), err)
   192  	}
   193  }
   194  
   195  // BodyContains checks that the response body contains the given string.
   196  func BodyContains(expected string) echo.Checker {
   197  	return Each(func(r echoClient.Response) error {
   198  		if !strings.Contains(r.RawContent, expected) {
   199  			return fmt.Errorf("want %q in body but not found: %s", expected, r.RawContent)
   200  		}
   201  		return nil
   202  	})
   203  }
   204  
   205  // Forbidden checks that the response indicates that the request was rejected by RBAC.
   206  func Forbidden(p protocol.Instance) echo.Checker {
   207  	switch {
   208  	case p.IsGRPC():
   209  		return ErrorContains("rpc error: code = PermissionDenied")
   210  	case p.IsTCP():
   211  		return ErrorContains("EOF")
   212  	default:
   213  		return NoErrorAndStatus(http.StatusForbidden)
   214  	}
   215  }
   216  
   217  // TooManyRequests checks that at least one message receives a StatusTooManyRequests status code.
   218  func TooManyRequests() echo.Checker {
   219  	codeStr := strconv.Itoa(http.StatusTooManyRequests)
   220  	return func(result echo.CallResult, _ error) error {
   221  		for _, r := range result.Responses {
   222  			if codeStr == r.Code {
   223  				// Successfully received too many requests.
   224  				return nil
   225  			}
   226  		}
   227  		return errors.New("no request received StatusTooManyRequest error")
   228  	}
   229  }
   230  
   231  func Host(expected string) echo.Checker {
   232  	return Each(func(r echoClient.Response) error {
   233  		if r.Host != expected {
   234  			return fmt.Errorf("expected host %s, received %s", expected, r.Host)
   235  		}
   236  		return nil
   237  	})
   238  }
   239  
   240  // Hostname checks the hostname the request landed on. This differs from Host which is the request we called.
   241  func Hostname(expected string) echo.Checker {
   242  	return Each(func(r echoClient.Response) error {
   243  		if r.Hostname != expected {
   244  			return fmt.Errorf("expected hostname %s, received %s", expected, r.Hostname)
   245  		}
   246  		return nil
   247  	})
   248  }
   249  
   250  func Protocol(expected string) echo.Checker {
   251  	return Each(func(r echoClient.Response) error {
   252  		if r.Protocol != expected {
   253  			return fmt.Errorf("expected protocol %s, received %s", expected, r.Protocol)
   254  		}
   255  		return nil
   256  	})
   257  }
   258  
   259  func Alpn(expected string) echo.Checker {
   260  	return Each(func(r echoClient.Response) error {
   261  		if r.Alpn != expected {
   262  			return fmt.Errorf("expected alpn %s, received %s", expected, r.Alpn)
   263  		}
   264  		return nil
   265  	})
   266  }
   267  
   268  func isHTTPProtocol(r echoClient.Response) bool {
   269  	return strings.HasPrefix(r.RequestURL, "http://") ||
   270  		strings.HasPrefix(r.RequestURL, "grpc://") ||
   271  		strings.HasPrefix(r.RequestURL, "ws://")
   272  }
   273  
   274  func isMTLS(r echoClient.Response) bool {
   275  	_, f1 := r.RequestHeaders["X-Forwarded-Client-Cert"]
   276  	// nolint: staticcheck
   277  	_, f2 := r.RequestHeaders["x-forwarded-client-cert"] // grpc has different casing
   278  	return f1 || f2
   279  }
   280  
   281  func MTLSForHTTP() echo.Checker {
   282  	return Each(func(r echoClient.Response) error {
   283  		if !isHTTPProtocol(r) {
   284  			// Non-HTTP traffic. Fail open, we cannot check mTLS.
   285  			return nil
   286  		}
   287  		if isMTLS(r) {
   288  			return nil
   289  		}
   290  		return fmt.Errorf("expected X-Forwarded-Client-Cert but not found: %v", r)
   291  	})
   292  }
   293  
   294  func PlaintextForHTTP() echo.Checker {
   295  	return Each(func(r echoClient.Response) error {
   296  		if !isHTTPProtocol(r) {
   297  			// Non-HTTP traffic. Fail open, we cannot check mTLS.
   298  			return nil
   299  		}
   300  		if !isMTLS(r) {
   301  			return nil
   302  		}
   303  		return fmt.Errorf("expected plaintext but found X-Forwarded-Client-Cert header: %v", r)
   304  	})
   305  }
   306  
   307  func Port(expected int) echo.Checker {
   308  	return Each(func(r echoClient.Response) error {
   309  		expectedStr := strconv.Itoa(expected)
   310  		if r.Port != expectedStr {
   311  			return fmt.Errorf("expected port %s, received %s", expectedStr, r.Port)
   312  		}
   313  		return nil
   314  	})
   315  }
   316  
   317  func requestHeader(r echoClient.Response, key, expected string) error {
   318  	actual := r.RequestHeaders.Get(key)
   319  	if actual != expected {
   320  		return fmt.Errorf("request header %s: expected `%s`, received `%s`", key, expected, actual)
   321  	}
   322  	return nil
   323  }
   324  
   325  func responseHeader(r echoClient.Response, key, expected string) error {
   326  	actual := r.ResponseHeaders.Get(key)
   327  	if actual != expected {
   328  		return fmt.Errorf("response header %s: expected `%s`, received `%s`", key, expected, actual)
   329  	}
   330  	return nil
   331  }
   332  
   333  func RequestHeader(key, expected string) echo.Checker {
   334  	return Each(func(r echoClient.Response) error {
   335  		return requestHeader(r, key, expected)
   336  	})
   337  }
   338  
   339  func ResponseHeader(key, expected string) echo.Checker {
   340  	return Each(func(r echoClient.Response) error {
   341  		return responseHeader(r, key, expected)
   342  	})
   343  }
   344  
   345  func RequestHeaders(expected map[string]string) echo.Checker {
   346  	return Each(func(r echoClient.Response) error {
   347  		outErr := istiomultierror.New()
   348  		for k, v := range expected {
   349  			outErr = multierror.Append(outErr, requestHeader(r, k, v))
   350  		}
   351  		return outErr.ErrorOrNil()
   352  	})
   353  }
   354  
   355  func ResponseHeaders(expected map[string]string) echo.Checker {
   356  	return Each(func(r echoClient.Response) error {
   357  		outErr := istiomultierror.New()
   358  		for k, v := range expected {
   359  			outErr = multierror.Append(outErr, responseHeader(r, k, v))
   360  		}
   361  		return outErr.ErrorOrNil()
   362  	})
   363  }
   364  
   365  func Cluster(expected string) echo.Checker {
   366  	return Each(func(r echoClient.Response) error {
   367  		if r.Cluster != expected {
   368  			return fmt.Errorf("expected cluster %s, received %s", expected, r.Cluster)
   369  		}
   370  		return nil
   371  	})
   372  }
   373  
   374  func URL(expected string) echo.Checker {
   375  	return Each(func(r echoClient.Response) error {
   376  		if r.URL != expected {
   377  			return fmt.Errorf("expected URL %s, received %s", expected, r.URL)
   378  		}
   379  		return nil
   380  	})
   381  }
   382  
   383  func IsDNSCaptureEnabled(t framework.TestContext) bool {
   384  	t.Helper()
   385  	mc := istio.GetOrFail(t, t).MeshConfigOrFail(t)
   386  	if mc.DefaultConfig != nil && mc.DefaultConfig.ProxyMetadata != nil {
   387  		return mc.DefaultConfig.ProxyMetadata["ISTIO_META_DNS_CAPTURE"] == "true"
   388  	}
   389  	return false
   390  }
   391  
   392  // ReachedTargetClusters is similar to ReachedClusters, except that the set of expected clusters is
   393  // retrieved from the Target of the request.
   394  func ReachedTargetClusters(t framework.TestContext) echo.Checker {
   395  	dnsCaptureEnabled := IsDNSCaptureEnabled(t)
   396  	return func(result echo.CallResult, err error) error {
   397  		from := result.From
   398  		to := result.Opts.To
   399  		if from == nil || to == nil {
   400  			// We need src and target in order to determine which clusters should be reached.
   401  			return nil
   402  		}
   403  
   404  		if result.Opts.Count < to.Clusters().Len() {
   405  			// There weren't enough calls to hit all the target clusters. Don't bother
   406  			// checking which clusters were reached.
   407  			return nil
   408  		}
   409  
   410  		allClusters := t.Clusters()
   411  		if isNaked(from) {
   412  			// Naked clients rely on k8s DNS to lookup endpoint IPs. This
   413  			// means that they will only ever reach endpoint in the same cluster.
   414  			return checkReachedSourceClusterOnly(result, allClusters)
   415  		}
   416  
   417  		if to.Config().IsAllNaked() {
   418  			// Requests to naked services will not cross network boundaries.
   419  			// Istio filters out cross-network endpoints.
   420  			return checkReachedSourceNetworkOnly(result, allClusters)
   421  		}
   422  
   423  		if !dnsCaptureEnabled && to.Config().IsHeadless() {
   424  			// Headless services rely on DNS resolution. If DNS capture is
   425  			// enabled, DNS will return all endpoints in the mesh, which will
   426  			// allow requests to go cross-cluster. Otherwise, k8s DNS will
   427  			// only return the endpoints within the same cluster as the source
   428  			// pod.
   429  			return checkReachedSourceClusterOnly(result, allClusters)
   430  		}
   431  
   432  		toClusters := to.Clusters()
   433  		return checkReachedClusters(result, allClusters, toClusters.ByNetwork())
   434  	}
   435  }
   436  
   437  // ReachedClusters returns an error if requests did not load balance as expected.
   438  //
   439  // For cases where all clusters are on the same network, verifies that each of the expected clusters was reached.
   440  //
   441  // For multi-network configurations, verifies the current (limited) Istio load balancing behavior when going through
   442  // a gateway. Ensures that all expected networks were reached, and that all clusters on the same network as the
   443  // client were reached.
   444  func ReachedClusters(allClusters cluster.Clusters, expectedClusters cluster.Clusters) echo.Checker {
   445  	expectedByNetwork := expectedClusters.ByNetwork()
   446  	return func(result echo.CallResult, err error) error {
   447  		return checkReachedClusters(result, allClusters, expectedByNetwork)
   448  	}
   449  }
   450  
   451  // ReachedSourceCluster is similar to ReachedClusters, except it only checks the reachability of source cluster only
   452  func ReachedSourceCluster(allClusters cluster.Clusters) echo.Checker {
   453  	return func(result echo.CallResult, err error) error {
   454  		return checkReachedSourceClusterOnly(result, allClusters)
   455  	}
   456  }
   457  
   458  // checkReachedSourceClusterOnly verifies that the only cluster that was reached is the cluster where
   459  // the source workload resides.
   460  func checkReachedSourceClusterOnly(result echo.CallResult, allClusters cluster.Clusters) error {
   461  	from := result.From
   462  	to := result.Opts.To
   463  	if from == nil || to == nil {
   464  		return nil
   465  	}
   466  
   467  	fromCluster := clusterFor(from)
   468  	if fromCluster == nil {
   469  		return nil
   470  	}
   471  
   472  	if !to.Clusters().Contains(fromCluster) {
   473  		// The target is not deployed in the same cluster as the source. Skip this check.
   474  		return nil
   475  	}
   476  
   477  	return checkReachedClusters(result, allClusters, cluster.ClustersByNetwork{
   478  		// Use the source network of the caller.
   479  		fromCluster.NetworkName(): cluster.Clusters{fromCluster},
   480  	})
   481  }
   482  
   483  // checkReachedSourceNetworkOnly verifies that the only network that was reached is the network where
   484  // the source workload resides.
   485  func checkReachedSourceNetworkOnly(result echo.CallResult, allClusters cluster.Clusters) error {
   486  	fromCluster := clusterFor(result.From)
   487  	if fromCluster == nil {
   488  		return nil
   489  	}
   490  
   491  	toClusters := result.Opts.To.Clusters()
   492  	expectedByNetwork := toClusters.ForNetworks(fromCluster.NetworkName()).ByNetwork()
   493  	return checkReachedClusters(result, allClusters, expectedByNetwork)
   494  }
   495  
   496  func checkReachedClusters(result echo.CallResult, allClusters cluster.Clusters, expectedByNetwork cluster.ClustersByNetwork) error {
   497  	if err := checkReachedNetworks(result, allClusters, expectedByNetwork); err != nil {
   498  		return err
   499  	}
   500  	return checkReachedClustersInNetwork(result, allClusters, expectedByNetwork)
   501  }
   502  
   503  func checkReachedNetworks(result echo.CallResult, allClusters cluster.Clusters, expectedByNetwork cluster.ClustersByNetwork) error {
   504  	// Gather the networks that were reached.
   505  	networkHits := make(map[string]int)
   506  	for _, rr := range result.Responses {
   507  		c := allClusters.GetByName(rr.Cluster)
   508  		if c != nil {
   509  			networkHits[c.NetworkName()]++
   510  		}
   511  	}
   512  
   513  	// Verify that all expected networks were reached.
   514  	for network := range expectedByNetwork {
   515  		if networkHits[network] == 0 {
   516  			return fmt.Errorf("did not reach network %v, got %v", network, networkHits)
   517  		}
   518  	}
   519  
   520  	// Verify that no unexpected networks were reached.
   521  	for network := range networkHits {
   522  		if expectedByNetwork[network] == nil {
   523  			return fmt.Errorf("reached network not in %v, got %v", expectedByNetwork.Networks(), networkHits)
   524  		}
   525  	}
   526  	return nil
   527  }
   528  
   529  func isNaked(c echo.Caller) bool {
   530  	if c != nil {
   531  		if inst, ok := c.(echo.Instance); ok {
   532  			return inst.Config().IsNaked()
   533  		}
   534  	}
   535  	return false
   536  }
   537  
   538  func clusterFor(c echo.Caller) cluster.Cluster {
   539  	if c != nil {
   540  		// Determine the source network of the caller.
   541  		switch from := c.(type) {
   542  		case echo.Instance:
   543  			return from.Config().Cluster
   544  		case ingress.Instance:
   545  			return from.Cluster()
   546  		}
   547  	}
   548  
   549  	// Unable to determine the source network of the caller. Skip this check.
   550  	return nil
   551  }
   552  
   553  func checkReachedClustersInNetwork(result echo.CallResult, allClusters cluster.Clusters, expectedByNetwork cluster.ClustersByNetwork) error {
   554  	fromCluster := clusterFor(result.From)
   555  	if fromCluster == nil {
   556  		return nil
   557  	}
   558  	fromNetwork := fromCluster.NetworkName()
   559  
   560  	// Lookup only the expected clusters in the same network as the caller.
   561  	expectedClustersInSourceNetwork := expectedByNetwork[fromNetwork]
   562  
   563  	clusterHits := make(map[string]int)
   564  	for _, rr := range result.Responses {
   565  		clusterHits[rr.Cluster]++
   566  	}
   567  
   568  	for _, c := range expectedClustersInSourceNetwork {
   569  		if clusterHits[c.Name()] == 0 {
   570  			return fmt.Errorf("did not reach all of %v in source network %v, got %v",
   571  				expectedClustersInSourceNetwork, fromNetwork, clusterHits)
   572  		}
   573  	}
   574  
   575  	// Verify that no unexpected clusters were reached.
   576  	for clusterName := range clusterHits {
   577  		reachedCluster := allClusters.GetByName(clusterName)
   578  		if reachedCluster == nil || reachedCluster.NetworkName() != fromNetwork {
   579  			// Ignore clusters on a different network from the source.
   580  			continue
   581  		}
   582  
   583  		if expectedClustersInSourceNetwork.GetByName(clusterName) == nil {
   584  			return fmt.Errorf("reached cluster %v in source network %v not in %v, got %v",
   585  				clusterName, fromNetwork, expectedClustersInSourceNetwork, clusterHits)
   586  		}
   587  	}
   588  	return nil
   589  }