gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/grpc/test/authority_test.go (about)

     1  //go:build linux
     2  // +build linux
     3  
     4  /*
     5   *
     6   * Copyright 2020 gRPC authors.
     7   *
     8   * Licensed under the Apache License, Version 2.0 (the "License");
     9   * you may not use this file except in compliance with the License.
    10   * You may obtain a copy of the License at
    11   *
    12   *     https://www.apache.org/licenses/LICENSE-2.0
    13   *
    14   * Unless required by applicable law or agreed to in writing, software
    15   * distributed under the License is distributed on an "AS IS" BASIS,
    16   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    17   * See the License for the specific language governing permissions and
    18   * limitations under the License.
    19   *
    20   */
    21  
    22  package test
    23  
    24  import (
    25  	"context"
    26  	"fmt"
    27  	"net"
    28  	"os"
    29  	"strings"
    30  	"sync"
    31  	"testing"
    32  	"time"
    33  
    34  	grpc "gitee.com/ks-custle/core-gm/grpc"
    35  	"gitee.com/ks-custle/core-gm/grpc/codes"
    36  	"gitee.com/ks-custle/core-gm/grpc/internal/stubserver"
    37  	"gitee.com/ks-custle/core-gm/grpc/metadata"
    38  	"gitee.com/ks-custle/core-gm/grpc/status"
    39  	testpb "gitee.com/ks-custle/core-gm/grpc/test/grpc_testing"
    40  )
    41  
    42  func authorityChecker(ctx context.Context, expectedAuthority string) (*testpb.Empty, error) {
    43  	md, ok := metadata.FromIncomingContext(ctx)
    44  	if !ok {
    45  		return nil, status.Error(codes.InvalidArgument, "failed to parse metadata")
    46  	}
    47  	auths, ok := md[":authority"]
    48  	if !ok {
    49  		return nil, status.Error(codes.InvalidArgument, "no authority header")
    50  	}
    51  	if len(auths) != 1 {
    52  		return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("no authority header, auths = %v", auths))
    53  	}
    54  	if auths[0] != expectedAuthority {
    55  		return nil, status.Error(codes.InvalidArgument, fmt.Sprintf("invalid authority header %v, expected %v", auths[0], expectedAuthority))
    56  	}
    57  	return &testpb.Empty{}, nil
    58  }
    59  
    60  func runUnixTest(t *testing.T, address, target, expectedAuthority string, dialer func(context.Context, string) (net.Conn, error)) {
    61  	if !strings.HasPrefix(target, "unix-abstract:") {
    62  		if err := os.RemoveAll(address); err != nil {
    63  			t.Fatalf("Error removing socket file %v: %v\n", address, err)
    64  		}
    65  	}
    66  	ss := &stubserver.StubServer{
    67  		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
    68  			return authorityChecker(ctx, expectedAuthority)
    69  		},
    70  		Network: "unix",
    71  		Address: address,
    72  		Target:  target,
    73  	}
    74  	opts := []grpc.DialOption{}
    75  	if dialer != nil {
    76  		opts = append(opts, grpc.WithContextDialer(dialer))
    77  	}
    78  	if err := ss.Start(nil, opts...); err != nil {
    79  		t.Fatalf("Error starting endpoint server: %v", err)
    80  	}
    81  	defer ss.Stop()
    82  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    83  	defer cancel()
    84  	_, err := ss.Client.EmptyCall(ctx, &testpb.Empty{})
    85  	if err != nil {
    86  		t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err)
    87  	}
    88  }
    89  
    90  type authorityTest struct {
    91  	name           string
    92  	address        string
    93  	target         string
    94  	authority      string
    95  	dialTargetWant string
    96  }
    97  
    98  var authorityTests = []authorityTest{
    99  	{
   100  		name:      "UnixRelative",
   101  		address:   "sock.sock",
   102  		target:    "unix:sock.sock",
   103  		authority: "localhost",
   104  	},
   105  	{
   106  		name:           "UnixAbsolute",
   107  		address:        "/tmp/sock.sock",
   108  		target:         "unix:/tmp/sock.sock",
   109  		authority:      "localhost",
   110  		dialTargetWant: "unix:///tmp/sock.sock",
   111  	},
   112  	{
   113  		name:      "UnixAbsoluteAlternate",
   114  		address:   "/tmp/sock.sock",
   115  		target:    "unix:///tmp/sock.sock",
   116  		authority: "localhost",
   117  	},
   118  	{
   119  		name:           "UnixPassthrough",
   120  		address:        "/tmp/sock.sock",
   121  		target:         "passthrough:///unix:///tmp/sock.sock",
   122  		authority:      "unix:///tmp/sock.sock",
   123  		dialTargetWant: "unix:///tmp/sock.sock",
   124  	},
   125  	{
   126  		name:           "UnixAbstract",
   127  		address:        "\x00abc efg",
   128  		target:         "unix-abstract:abc efg",
   129  		authority:      "localhost",
   130  		dialTargetWant: "\x00abc efg",
   131  	},
   132  }
   133  
   134  // TestUnix does end to end tests with the various supported unix target
   135  // formats, ensuring that the authority is set as expected.
   136  func (s) TestUnix(t *testing.T) {
   137  	for _, test := range authorityTests {
   138  		t.Run(test.name, func(t *testing.T) {
   139  			runUnixTest(t, test.address, test.target, test.authority, nil)
   140  		})
   141  	}
   142  }
   143  
   144  // TestUnixCustomDialer does end to end tests with various supported unix target
   145  // formats, ensuring that the target sent to the dialer does NOT have the
   146  // "unix:" prefix stripped.
   147  func (s) TestUnixCustomDialer(t *testing.T) {
   148  	for _, test := range authorityTests {
   149  		t.Run(test.name+"WithDialer", func(t *testing.T) {
   150  			if test.dialTargetWant == "" {
   151  				test.dialTargetWant = test.target
   152  			}
   153  			dialer := func(ctx context.Context, address string) (net.Conn, error) {
   154  				if address != test.dialTargetWant {
   155  					return nil, fmt.Errorf("expected target %v in custom dialer, instead got %v", test.dialTargetWant, address)
   156  				}
   157  				if !strings.HasPrefix(test.target, "unix-abstract:") {
   158  					address = address[len("unix:"):]
   159  				}
   160  				return (&net.Dialer{}).DialContext(ctx, "unix", address)
   161  			}
   162  			runUnixTest(t, test.address, test.target, test.authority, dialer)
   163  		})
   164  	}
   165  }
   166  
   167  // TestColonPortAuthority does an end to end test with the target for grpc.Dial
   168  // being ":[port]". Ensures authority is "localhost:[port]".
   169  func (s) TestColonPortAuthority(t *testing.T) {
   170  	expectedAuthority := ""
   171  	var authorityMu sync.Mutex
   172  	ss := &stubserver.StubServer{
   173  		EmptyCallF: func(ctx context.Context, _ *testpb.Empty) (*testpb.Empty, error) {
   174  			authorityMu.Lock()
   175  			defer authorityMu.Unlock()
   176  			return authorityChecker(ctx, expectedAuthority)
   177  		},
   178  		Network: "tcp",
   179  	}
   180  	if err := ss.Start(nil); err != nil {
   181  		t.Fatalf("Error starting endpoint server: %v", err)
   182  	}
   183  	defer ss.Stop()
   184  	_, port, err := net.SplitHostPort(ss.Address)
   185  	if err != nil {
   186  		t.Fatalf("Failed splitting host from post: %v", err)
   187  	}
   188  	authorityMu.Lock()
   189  	expectedAuthority = "localhost:" + port
   190  	authorityMu.Unlock()
   191  	// ss.Start dials, but not the ":[port]" target that is being tested here.
   192  	// Dial again, with ":[port]" as the target.
   193  	//
   194  	// Append "localhost" before calling net.Dial, in case net.Dial on certain
   195  	// platforms doesn't work well for address without the IP.
   196  	cc, err := grpc.Dial(":"+port, grpc.WithInsecure(), grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
   197  		return (&net.Dialer{}).DialContext(ctx, "tcp", "localhost"+addr)
   198  	}))
   199  	if err != nil {
   200  		t.Fatalf("grpc.Dial(%q) = %v", ss.Target, err)
   201  	}
   202  	defer cc.Close()
   203  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   204  	defer cancel()
   205  	_, err = testpb.NewTestServiceClient(cc).EmptyCall(ctx, &testpb.Empty{})
   206  	if err != nil {
   207  		t.Errorf("us.client.EmptyCall(_, _) = _, %v; want _, nil", err)
   208  	}
   209  }