sigs.k8s.io/cluster-api@v1.7.1/exp/runtime/server/server.go (about)

     1  /*
     2  Copyright 2021 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  // Package server contains the implementation of a RuntimeSDK webhook server.
    18  package server
    19  
import (
	"context"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"reflect"
	"sort"

	"github.com/pkg/errors"
	"k8s.io/apimachinery/pkg/runtime"
	ctrl "sigs.k8s.io/controller-runtime"
	"sigs.k8s.io/controller-runtime/pkg/log"
	"sigs.k8s.io/controller-runtime/pkg/webhook"

	runtimecatalog "sigs.k8s.io/cluster-api/exp/runtime/catalog"
	runtimehooksv1 "sigs.k8s.io/cluster-api/exp/runtime/hooks/api/v1alpha1"
)
    40  
    41  // DefaultPort is the default port that the webhook server serves.
    42  var DefaultPort = 9443
    43  
// Server is a runtime webhook server.
//
// It embeds the controller-runtime webhook.Server, which does the actual
// serving. It also keeps the catalog and the set of registered Runtime SDK
// extension handlers.
type Server struct {
	webhook.Server
	// catalog resolves hooks to their GroupVersionHook and request/response types.
	catalog *runtimecatalog.Catalog
	// handlers maps the HTTP path of each registered handler to the handler.
	// The path is computed by runtimecatalog.GVHToPath from the hook and the
	// handler name.
	handlers map[string]ExtensionHandler
}
    50  
// Options are the options for the Server.
type Options struct {
	// Catalog is the catalog used to handle requests.
	// It is required: New returns an error if it is nil.
	Catalog *runtimecatalog.Catalog

	// Host is the address that the server will listen on.
	// Defaults to "" - all addresses.
	// It is used to set webhook.Options.Host.
	Host string

	// Port is the port number that the server will serve.
	// If it is zero or negative, it is defaulted to DefaultPort (9443 unless
	// DefaultPort has been overridden).
	// It is used to set webhook.Options.Port.
	Port int

	// CertDir is the directory that contains the server key and certificate.
	// If not set, it defaults to {TempDir}/k8s-webhook-server/serving-certs,
	// where {TempDir} is os.TempDir(). The server key and certificate must be
	// named tls.key and tls.crt, respectively.
	// It is used to set webhook.Options.CertDir.
	CertDir string

	// TLSOpts is used to allow configuring the TLS config used for the server.
	// This also allows providing a certificate via GetCertificate.
	// It is used to set webhook.Options.TLSOpts.
	TLSOpts []func(*tls.Config)
}
    77  
    78  // New creates a new runtime webhook server based on the given Options.
    79  func New(options Options) (*Server, error) {
    80  	if options.Catalog == nil {
    81  		return nil, errors.Errorf("catalog is required")
    82  	}
    83  	if options.Port <= 0 {
    84  		options.Port = DefaultPort
    85  	}
    86  	if options.CertDir == "" {
    87  		options.CertDir = filepath.Join(os.TempDir(), "k8s-webhook-server", "serving-certs")
    88  	}
    89  
    90  	webhookServer := webhook.NewServer(
    91  		webhook.Options{
    92  			Port:       options.Port,
    93  			Host:       options.Host,
    94  			CertDir:    options.CertDir,
    95  			CertName:   "tls.crt",
    96  			KeyName:    "tls.key",
    97  			TLSOpts:    options.TLSOpts,
    98  			WebhookMux: http.NewServeMux(),
    99  		},
   100  	)
   101  
   102  	return &Server{
   103  		Server:   webhookServer,
   104  		catalog:  options.Catalog,
   105  		handlers: map[string]ExtensionHandler{},
   106  	}, nil
   107  }
   108  
// ExtensionHandler represents an extension handler.
//
// Callers set Hook, Name and HandlerFunc, and optionally TimeoutSeconds and
// FailurePolicy. The unexported fields are filled in by
// Server.AddExtensionHandler.
type ExtensionHandler struct {
	// gvh is the gvh of the hook corresponding to the extension handler.
	// It is resolved from the server's catalog.
	gvh runtimecatalog.GroupVersionHook
	// requestObject is a runtime object that the handler expects to receive.
	// A deep copy of it is unmarshalled from each incoming request body.
	requestObject runtime.Object
	// responseObject is a runtime object that the handler expects to return.
	// A deep copy of it is passed to HandlerFunc and marshalled as the response.
	responseObject runtime.Object

	// Hook is the corresponding hook of the handler.
	// It must be registered in the server's catalog.
	Hook runtimecatalog.Hook

	// Name is the name of the extension handler.
	// An extension handler name must be valid in line with RFC 1123 Label Names.
	// Together with the hook, it determines the path the handler is served on.
	// NOTE: the name is not validated by AddExtensionHandler.
	Name string

	// HandlerFunc is the handler function.
	// It must be a function with three parameters and no return values. The first
	// parameter is a Context. The second and third are exactly the request and
	// response pointer types of Hook, e.g.
	// func(context.Context, *runtimehooksv1.DiscoveryRequest, *runtimehooksv1.DiscoveryResponse).
	HandlerFunc runtimecatalog.Hook

	// TimeoutSeconds is the timeout of the extension handler.
	// If left undefined, this will be defaulted to 10s when processing the answer to the discovery
	// call for this server. This server advertises the value unchanged in its discovery response.
	TimeoutSeconds *int32

	// FailurePolicy is the failure policy of the extension handler.
	// If left undefined, this will be defaulted to FailurePolicyFail when processing the answer to the discovery
	// call for this server. This server advertises the value unchanged in its discovery response.
	FailurePolicy *runtimehooksv1.FailurePolicy
}
   138  
   139  // AddExtensionHandler adds an extension handler to the server.
   140  func (s *Server) AddExtensionHandler(handler ExtensionHandler) error {
   141  	gvh, err := s.catalog.GroupVersionHook(handler.Hook)
   142  	if err != nil {
   143  		return errors.Wrapf(err, "hook %q does not exist in catalog", runtimecatalog.HookName(handler.Hook))
   144  	}
   145  	handler.gvh = gvh
   146  
   147  	requestObject, err := s.catalog.NewRequest(handler.gvh)
   148  	if err != nil {
   149  		return err
   150  	}
   151  	handler.requestObject = requestObject
   152  
   153  	responseObject, err := s.catalog.NewResponse(handler.gvh)
   154  	if err != nil {
   155  		return err
   156  	}
   157  	handler.responseObject = responseObject
   158  
   159  	if err := s.validateHandler(handler); err != nil {
   160  		return err
   161  	}
   162  
   163  	handlerPath := runtimecatalog.GVHToPath(handler.gvh, handler.Name)
   164  	if _, ok := s.handlers[handlerPath]; ok {
   165  		return errors.Errorf("there is already a handler registered for path %q", handlerPath)
   166  	}
   167  
   168  	s.handlers[handlerPath] = handler
   169  	return nil
   170  }
   171  
   172  // validateHandler validates a handler.
   173  func (s *Server) validateHandler(handler ExtensionHandler) error {
   174  	// Get hook and handler type.
   175  	hookFuncType := reflect.TypeOf(handler.Hook)
   176  	handlerFuncType := reflect.TypeOf(handler.HandlerFunc)
   177  
   178  	// Validate handler function signature.
   179  	if handlerFuncType.Kind() != reflect.Func {
   180  		return errors.Errorf("HandlerFunc must be a func")
   181  	}
   182  	if handlerFuncType.NumIn() != 3 {
   183  		return errors.Errorf("HandlerFunc must have three input parameter")
   184  	}
   185  	if handlerFuncType.NumOut() != 0 {
   186  		return errors.Errorf("HandlerFunc must have no output parameter")
   187  	}
   188  
   189  	// Get hook and handler request and response types.
   190  	hookRequestType := hookFuncType.In(0)
   191  	hookResponseType := hookFuncType.In(1)
   192  	handlerContextType := handlerFuncType.In(0)
   193  	handlerRequestType := handlerFuncType.In(1)
   194  	handlerResponseType := handlerFuncType.In(2)
   195  
   196  	// Validate handler request and response are pointers.
   197  	if handlerRequestType.Kind() != reflect.Ptr {
   198  		return errors.Errorf("HandlerFunc request type must be a pointer")
   199  	}
   200  	if handlerResponseType.Kind() != reflect.Ptr {
   201  		return errors.Errorf("HandlerFunc response type must be a pointer")
   202  	}
   203  
   204  	// Validate first handler parameter is a context
   205  	// TODO: improve check, how to check if param is a specific interface?
   206  	if handlerContextType.Name() != "Context" {
   207  		return errors.Errorf("HandlerFunc first parameter must be Context but is %s", handlerContextType.Name())
   208  	}
   209  
   210  	// Validate hook and handler request and response types are equal.
   211  	if hookRequestType != handlerRequestType {
   212  		return errors.Errorf("HandlerFunc request type must be *%s but is *%s", hookRequestType.Elem().Name(), handlerRequestType.Elem().Name())
   213  	}
   214  	if hookResponseType != handlerResponseType {
   215  		return errors.Errorf("HandlerFunc response type must be *%s but is *%s", hookResponseType.Elem().Name(), handlerResponseType.Elem().Name())
   216  	}
   217  
   218  	return nil
   219  }
   220  
   221  // Start starts the server.
   222  func (s *Server) Start(ctx context.Context) error {
   223  	// Add discovery handler.
   224  	err := s.AddExtensionHandler(ExtensionHandler{
   225  		Hook:        runtimehooksv1.Discovery,
   226  		HandlerFunc: discoveryHandler(s.handlers),
   227  	})
   228  	if err != nil {
   229  		return err
   230  	}
   231  
   232  	// Add handlers to router.
   233  	for handlerPath, h := range s.handlers {
   234  		handler := h
   235  
   236  		wrappedHandler := s.wrapHandler(handler)
   237  		s.Server.Register(handlerPath, http.HandlerFunc(wrappedHandler))
   238  	}
   239  
   240  	return s.Server.Start(ctx)
   241  }
   242  
   243  // discoveryHandler generates a discovery handler based on a list of handlers.
   244  func discoveryHandler(handlers map[string]ExtensionHandler) func(context.Context, *runtimehooksv1.DiscoveryRequest, *runtimehooksv1.DiscoveryResponse) {
   245  	cachedHandlers := []runtimehooksv1.ExtensionHandler{}
   246  	for _, handler := range handlers {
   247  		cachedHandlers = append(cachedHandlers, runtimehooksv1.ExtensionHandler{
   248  			Name: handler.Name,
   249  			RequestHook: runtimehooksv1.GroupVersionHook{
   250  				APIVersion: handler.gvh.GroupVersion().String(),
   251  				Hook:       handler.gvh.Hook,
   252  			},
   253  			TimeoutSeconds: handler.TimeoutSeconds,
   254  			FailurePolicy:  handler.FailurePolicy,
   255  		})
   256  	}
   257  
   258  	return func(_ context.Context, _ *runtimehooksv1.DiscoveryRequest, response *runtimehooksv1.DiscoveryResponse) {
   259  		response.SetStatus(runtimehooksv1.ResponseStatusSuccess)
   260  		response.Handlers = cachedHandlers
   261  	}
   262  }
   263  
   264  func (s *Server) wrapHandler(handler ExtensionHandler) func(w http.ResponseWriter, r *http.Request) {
   265  	return func(w http.ResponseWriter, r *http.Request) {
   266  		response := s.callHandler(handler, r)
   267  
   268  		responseBody, err := json.Marshal(response)
   269  		if err != nil {
   270  			w.WriteHeader(http.StatusInternalServerError)
   271  			_, _ = fmt.Fprintf(w, "unable to marshal response: %v", err)
   272  			return
   273  		}
   274  
   275  		w.WriteHeader(http.StatusOK)
   276  		_, _ = w.Write(responseBody)
   277  	}
   278  }
   279  
   280  func (s *Server) callHandler(handler ExtensionHandler, r *http.Request) runtimehooksv1.ResponseObject {
   281  	request := handler.requestObject.DeepCopyObject()
   282  	response := handler.responseObject.DeepCopyObject().(runtimehooksv1.ResponseObject)
   283  
   284  	requestBody, err := io.ReadAll(r.Body)
   285  	if err != nil {
   286  		response.SetStatus(runtimehooksv1.ResponseStatusFailure)
   287  		response.SetMessage(fmt.Sprintf("error reading request: %v", err))
   288  		return response
   289  	}
   290  
   291  	if err := json.Unmarshal(requestBody, request); err != nil {
   292  		response.SetStatus(runtimehooksv1.ResponseStatusFailure)
   293  		response.SetMessage(fmt.Sprintf("error unmarshalling request: %v", err))
   294  		return response
   295  	}
   296  
   297  	// log.Log is the logger previously set via ctrl.SetLogger.
   298  	// This implemented analog to the logger in the controller-runtime manager.
   299  	ctx := ctrl.LoggerInto(r.Context(), log.Log)
   300  
   301  	reflect.ValueOf(handler.HandlerFunc).Call([]reflect.Value{
   302  		reflect.ValueOf(ctx),
   303  		reflect.ValueOf(request),
   304  		reflect.ValueOf(response),
   305  	})
   306  
   307  	return response
   308  }