k8s.io/kubernetes@v1.29.3/test/e2e/storage/drivers/csi-test/driver/driver.go (about)

     1  /*
     2  Copyright 2021 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  //go:generate mockgen -package=driver -destination=driver.mock.go -build_flags=-mod=mod github.com/container-storage-interface/spec/lib/go/csi IdentityServer,ControllerServer,NodeServer
    18  
    19  package driver
    20  
    21  import (
    22  	"context"
    23  	"encoding/json"
    24  	"errors"
    25  	"net"
    26  	"sync"
    27  
    28  	"google.golang.org/grpc/codes"
    29  	"google.golang.org/grpc/status"
    30  	"k8s.io/klog/v2"
    31  
    32  	"github.com/container-storage-interface/spec/lib/go/csi"
    33  	"google.golang.org/grpc"
    34  )
    35  
var (
	// ErrNoCredentials is the error when a secret is enabled but not passed in the request.
	// credsCheck returns it when the secrets map of a request is empty, and
	// authInterceptor reports it to the client as codes.InvalidArgument.
	ErrNoCredentials = errors.New("secret must be provided")
	// ErrAuthFailed is the error when the secret is incorrect.
	// credsCheck returns it when the secretField entry of a request differs
	// from the expected value, and authInterceptor reports it to the client
	// as codes.Unauthenticated.
	ErrAuthFailed = errors.New("authentication failed")
)
    42  
// CSIDriverServers is a unified driver component with both Controller and Node
// services.
//
// Any of the fields may be left nil: CSIDriver.Start registers only the
// servers that are non-nil with the gRPC server it creates.
type CSIDriverServers struct {
	Controller csi.ControllerServer
	Identity   csi.IdentityServer
	Node       csi.NodeServer
}
    50  
// secretField is the key name in all the CSI secret objects. credsCheck looks
// up this key in the secrets map of a request and compares the stored value
// with the secret expected for that request type.
const secretField = "secretKey"
    53  
// CSICreds is a driver specific secret type. Drivers can have a key-val pair of
// secrets. This mock driver has a single string secret with secretField as the
// key.
//
// Each field holds the value that the secretField entry of the matching CSI
// request must carry; isAuthenticated selects the field by request type and
// credsCheck performs the comparison.
type CSICreds struct {
	CreateVolumeSecret                         string
	DeleteVolumeSecret                         string
	ControllerPublishVolumeSecret              string
	ControllerUnpublishVolumeSecret            string
	NodeStageVolumeSecret                      string
	NodePublishVolumeSecret                    string
	CreateSnapshotSecret                       string
	DeleteSnapshotSecret                       string
	ControllerValidateVolumeCapabilitiesSecret string
}
    68  
// CSIDriver serves a set of CSI services (see CSIDriverServers) over gRPC on
// a listener handed to Start.
type CSIDriver struct {
	// listener is the listener passed to Start; Address reports its address.
	listener net.Listener
	// server is the gRPC server created by Start.
	server *grpc.Server
	// servers holds the CSI service implementations registered by Start.
	servers *CSIDriverServers
	// wg tracks the goroutine in which the gRPC server is serving.
	wg sync.WaitGroup
	// running is set to true by Start once the serving goroutine has begun.
	running bool
	// lock is held by Start and IsRunning while they access the driver state.
	lock sync.Mutex
	// creds are the expected secrets; when nil, authInterceptor performs no
	// credential check.
	creds *CSICreds
	// logGRPC is an optional hook that callInterceptor invokes after the
	// handler of a call has returned.
	logGRPC LogGRPC
}
    79  
// LogGRPC is a callback that receives the full method name, the request, the
// reply and the error of a gRPC call. CSIDriver.callInterceptor invokes it,
// when set, after the call's handler has returned.
type LogGRPC func(method string, request, reply interface{}, err error)
    81  
    82  func NewCSIDriver(servers *CSIDriverServers) *CSIDriver {
    83  	return &CSIDriver{
    84  		servers: servers,
    85  	}
    86  }
    87  
    88  func (c *CSIDriver) goServe(started chan<- bool) {
    89  	goServe(c.server, &c.wg, c.listener, started)
    90  }
    91  
    92  func (c *CSIDriver) Address() string {
    93  	return c.listener.Addr().String()
    94  }
    95  
    96  // Start runs a gRPC server with all enabled services. If an interceptor
    97  // is give, then it will be used. Otherwise, an interceptor which
    98  // handles simple credential checks and logs gRPC calls in JSON format
    99  // will be used.
   100  func (c *CSIDriver) Start(l net.Listener, interceptor grpc.UnaryServerInterceptor) error {
   101  	c.lock.Lock()
   102  	defer c.lock.Unlock()
   103  
   104  	// Set listener
   105  	c.listener = l
   106  
   107  	// Create a new grpc server
   108  	if interceptor == nil {
   109  		interceptor = c.callInterceptor
   110  	}
   111  	c.server = grpc.NewServer(grpc.UnaryInterceptor(interceptor))
   112  
   113  	// Register Mock servers
   114  	if c.servers.Controller != nil {
   115  		csi.RegisterControllerServer(c.server, c.servers.Controller)
   116  	}
   117  	if c.servers.Identity != nil {
   118  		csi.RegisterIdentityServer(c.server, c.servers.Identity)
   119  	}
   120  	if c.servers.Node != nil {
   121  		csi.RegisterNodeServer(c.server, c.servers.Node)
   122  	}
   123  
   124  	// Start listening for requests
   125  	waitForServer := make(chan bool)
   126  	c.goServe(waitForServer)
   127  	<-waitForServer
   128  	c.running = true
   129  	return nil
   130  }
   131  
   132  func (c *CSIDriver) Stop() {
   133  	stop(&c.lock, &c.wg, c.server, c.running)
   134  }
   135  
   136  func (c *CSIDriver) Close() {
   137  	c.server.Stop()
   138  }
   139  
   140  func (c *CSIDriver) IsRunning() bool {
   141  	c.lock.Lock()
   142  	defer c.lock.Unlock()
   143  
   144  	return c.running
   145  }
   146  
   147  // SetDefaultCreds sets the default secrets for CSI creds.
   148  func (c *CSIDriver) SetDefaultCreds() {
   149  	setDefaultCreds(c.creds)
   150  }
   151  
   152  // goServe starts a grpc server.
   153  func goServe(server *grpc.Server, wg *sync.WaitGroup, listener net.Listener, started chan<- bool) {
   154  	wg.Add(1)
   155  	go func() {
   156  		defer wg.Done()
   157  		started <- true
   158  		err := server.Serve(listener)
   159  		if err != nil {
   160  			klog.Infof("gRPC server for CSI driver stopped: %v", err)
   161  		}
   162  	}()
   163  }
   164  
   165  // stop stops a grpc server.
   166  func stop(lock *sync.Mutex, wg *sync.WaitGroup, server *grpc.Server, running bool) {
   167  	lock.Lock()
   168  	defer lock.Unlock()
   169  
   170  	if !running {
   171  		return
   172  	}
   173  
   174  	server.Stop()
   175  	wg.Wait()
   176  }
   177  
   178  // setDefaultCreds sets the default credentials, given a CSICreds instance.
   179  func setDefaultCreds(creds *CSICreds) {
   180  	*creds = CSICreds{
   181  		CreateVolumeSecret:                         "secretval1",
   182  		DeleteVolumeSecret:                         "secretval2",
   183  		ControllerPublishVolumeSecret:              "secretval3",
   184  		ControllerUnpublishVolumeSecret:            "secretval4",
   185  		NodeStageVolumeSecret:                      "secretval5",
   186  		NodePublishVolumeSecret:                    "secretval6",
   187  		CreateSnapshotSecret:                       "secretval7",
   188  		DeleteSnapshotSecret:                       "secretval8",
   189  		ControllerValidateVolumeCapabilitiesSecret: "secretval9",
   190  	}
   191  }
   192  
   193  func (c *CSIDriver) callInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
   194  	err := authInterceptor(c.creds, req)
   195  	if err != nil {
   196  		logGRPC(info.FullMethod, req, nil, err)
   197  		return nil, err
   198  	}
   199  	rsp, err := handler(ctx, req)
   200  	logGRPC(info.FullMethod, req, rsp, err)
   201  	if c.logGRPC != nil {
   202  		c.logGRPC(info.FullMethod, req, rsp, err)
   203  	}
   204  	return rsp, err
   205  }
   206  
   207  func authInterceptor(creds *CSICreds, req interface{}) error {
   208  	if creds != nil {
   209  		authenticated, authErr := isAuthenticated(req, creds)
   210  		if !authenticated {
   211  			if authErr == ErrNoCredentials {
   212  				return status.Error(codes.InvalidArgument, authErr.Error())
   213  			}
   214  			if authErr == ErrAuthFailed {
   215  				return status.Error(codes.Unauthenticated, authErr.Error())
   216  			}
   217  		}
   218  	}
   219  	return nil
   220  }
   221  
   222  func logGRPC(method string, request, reply interface{}, err error) {
   223  	// Log JSON with the request and response for easier parsing
   224  	logMessage := struct {
   225  		Method   string
   226  		Request  interface{}
   227  		Response interface{}
   228  		// Error as string, for backward compatibility.
   229  		// "" on no error.
   230  		Error string
   231  		// Full error dump, to be able to parse out full gRPC error code and message separately in a test.
   232  		FullError error
   233  	}{
   234  		Method:    method,
   235  		Request:   request,
   236  		Response:  reply,
   237  		FullError: err,
   238  	}
   239  
   240  	if err != nil {
   241  		logMessage.Error = err.Error()
   242  	}
   243  
   244  	msg, _ := json.Marshal(logMessage)
   245  	klog.V(3).Infof("gRPCCall: %s\n", msg)
   246  }
   247  
   248  func isAuthenticated(req interface{}, creds *CSICreds) (bool, error) {
   249  	switch r := req.(type) {
   250  	case *csi.CreateVolumeRequest:
   251  		return authenticateCreateVolume(r, creds)
   252  	case *csi.DeleteVolumeRequest:
   253  		return authenticateDeleteVolume(r, creds)
   254  	case *csi.ControllerPublishVolumeRequest:
   255  		return authenticateControllerPublishVolume(r, creds)
   256  	case *csi.ControllerUnpublishVolumeRequest:
   257  		return authenticateControllerUnpublishVolume(r, creds)
   258  	case *csi.NodeStageVolumeRequest:
   259  		return authenticateNodeStageVolume(r, creds)
   260  	case *csi.NodePublishVolumeRequest:
   261  		return authenticateNodePublishVolume(r, creds)
   262  	case *csi.CreateSnapshotRequest:
   263  		return authenticateCreateSnapshot(r, creds)
   264  	case *csi.DeleteSnapshotRequest:
   265  		return authenticateDeleteSnapshot(r, creds)
   266  	case *csi.ValidateVolumeCapabilitiesRequest:
   267  		return authenticateControllerValidateVolumeCapabilities(r, creds)
   268  	default:
   269  		return true, nil
   270  	}
   271  }
   272  
   273  func authenticateCreateVolume(req *csi.CreateVolumeRequest, creds *CSICreds) (bool, error) {
   274  	return credsCheck(req.GetSecrets(), creds.CreateVolumeSecret)
   275  }
   276  
   277  func authenticateDeleteVolume(req *csi.DeleteVolumeRequest, creds *CSICreds) (bool, error) {
   278  	return credsCheck(req.GetSecrets(), creds.DeleteVolumeSecret)
   279  }
   280  
   281  func authenticateControllerPublishVolume(req *csi.ControllerPublishVolumeRequest, creds *CSICreds) (bool, error) {
   282  	return credsCheck(req.GetSecrets(), creds.ControllerPublishVolumeSecret)
   283  }
   284  
   285  func authenticateControllerUnpublishVolume(req *csi.ControllerUnpublishVolumeRequest, creds *CSICreds) (bool, error) {
   286  	return credsCheck(req.GetSecrets(), creds.ControllerUnpublishVolumeSecret)
   287  }
   288  
   289  func authenticateNodeStageVolume(req *csi.NodeStageVolumeRequest, creds *CSICreds) (bool, error) {
   290  	return credsCheck(req.GetSecrets(), creds.NodeStageVolumeSecret)
   291  }
   292  
   293  func authenticateNodePublishVolume(req *csi.NodePublishVolumeRequest, creds *CSICreds) (bool, error) {
   294  	return credsCheck(req.GetSecrets(), creds.NodePublishVolumeSecret)
   295  }
   296  
   297  func authenticateCreateSnapshot(req *csi.CreateSnapshotRequest, creds *CSICreds) (bool, error) {
   298  	return credsCheck(req.GetSecrets(), creds.CreateSnapshotSecret)
   299  }
   300  
   301  func authenticateDeleteSnapshot(req *csi.DeleteSnapshotRequest, creds *CSICreds) (bool, error) {
   302  	return credsCheck(req.GetSecrets(), creds.DeleteSnapshotSecret)
   303  }
   304  
   305  func authenticateControllerValidateVolumeCapabilities(req *csi.ValidateVolumeCapabilitiesRequest, creds *CSICreds) (bool, error) {
   306  	return credsCheck(req.GetSecrets(), creds.ControllerValidateVolumeCapabilitiesSecret)
   307  }
   308  
   309  func credsCheck(secrets map[string]string, secretVal string) (bool, error) {
   310  	if len(secrets) == 0 {
   311  		return false, ErrNoCredentials
   312  	}
   313  
   314  	if secrets[secretField] != secretVal {
   315  		return false, ErrAuthFailed
   316  	}
   317  	return true, nil
   318  }