github.com/inspektor-gadget/inspektor-gadget@v0.28.1/pkg/runtime/grpc/k8s-portfwd-dialer.go (about)

     1  // Copyright 2023 The Inspektor Gadget authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package grpcruntime
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"io"
    21  	"net"
    22  	"net/http"
    23  	"net/url"
    24  	"strconv"
    25  	"time"
    26  
    27  	log "github.com/sirupsen/logrus"
    28  	v1 "k8s.io/api/core/v1"
    29  	"k8s.io/apimachinery/pkg/util/httpstream"
    30  	"k8s.io/client-go/rest"
    31  	"k8s.io/client-go/tools/portforward"
    32  	"k8s.io/client-go/transport/spdy"
    33  
    34  	"github.com/inspektor-gadget/inspektor-gadget/pkg/factory"
    35  )
    36  
    37  type k8sPortFwdDialer struct {
    38  	io.Writer
    39  	io.Reader
    40  	conn    httpstream.Connection
    41  	stream  httpstream.Stream
    42  	podName string
    43  }
    44  
    45  // NewK8SPortFwdConn connects to a Pod using PortForwarding via the Kubernetes API Server
    46  func NewK8SPortFwdConn(ctx context.Context, config *rest.Config, namespace string, pod target, targetPort uint16, timeout time.Duration) (net.Conn, error) {
    47  	conn := &k8sPortFwdDialer{}
    48  
    49  	// set GroupVersion and NegotiatedSerializer for RESTClient
    50  	factory.SetKubernetesDefaults(config)
    51  
    52  	conn.podName = pod.addressOrPod
    53  
    54  	config.Timeout = timeout
    55  
    56  	transport, upgrader, err := spdy.RoundTripperFor(config)
    57  	if err != nil {
    58  		return nil, fmt.Errorf("creating roundtripper: %w", err)
    59  	}
    60  
    61  	targetURL, err := url.Parse(config.Host)
    62  	if err != nil {
    63  		return nil, fmt.Errorf("parsing restConfig.Host: %w", err)
    64  	}
    65  
    66  	targetURL.Path = fmt.Sprintf("api/v1/namespaces/%s/pods/%s/portforward", namespace, conn.podName)
    67  
    68  	dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, http.MethodPost, targetURL)
    69  
    70  	newConn, _, err := dialer.Dial(portforward.PortForwardProtocolV1Name)
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  
    75  	// create error stream
    76  	headers := http.Header{}
    77  	headers.Set(v1.StreamType, v1.StreamTypeError)
    78  	headers.Set(v1.PortHeader, fmt.Sprintf("%d", targetPort))
    79  	headers.Set(v1.PortForwardRequestIDHeader, strconv.Itoa(1))
    80  	errorStream, err := newConn.CreateStream(headers)
    81  	if err != nil {
    82  		newConn.Close()
    83  		return nil, fmt.Errorf("creating error stream for port forward: %w", err)
    84  	}
    85  	// we're not writing to this stream, but it is required for other streams to be able to connect
    86  	errorStream.Close()
    87  
    88  	go func() {
    89  		message, err := io.ReadAll(errorStream)
    90  		switch {
    91  		case err != nil:
    92  			log.Errorf("k8sPortFwd connection: reading from error stream: %v", err)
    93  		case len(message) > 0:
    94  			log.Errorf("k8sPortFwd tcp connection: forwarding port: %v", string(message))
    95  			log.Errorf("Please make sure the --connection-method value matches your installation.")
    96  		}
    97  	}()
    98  
    99  	// create data stream
   100  	headers.Set(v1.StreamType, v1.StreamTypeData)
   101  	dataStream, err := newConn.CreateStream(headers)
   102  	if err != nil {
   103  		newConn.Close()
   104  		return nil, fmt.Errorf("creating data stream for port forward: %w", err)
   105  	}
   106  
   107  	conn.conn = newConn
   108  	conn.stream = dataStream
   109  	return conn, nil
   110  }
   111  
   112  func (k *k8sPortFwdDialer) Close() error {
   113  	k.stream.Close()
   114  	return k.conn.Close()
   115  }
   116  
   117  func (k *k8sPortFwdDialer) Read(b []byte) (n int, err error) {
   118  	return k.stream.Read(b)
   119  }
   120  
   121  func (k *k8sPortFwdDialer) Write(b []byte) (n int, err error) {
   122  	return k.stream.Write(b)
   123  }
   124  
   125  func (k *k8sPortFwdDialer) LocalAddr() net.Addr {
   126  	return nil
   127  }
   128  
   129  func (k *k8sPortFwdDialer) RemoteAddr() net.Addr {
   130  	return &k8sAddress{podName: k.podName}
   131  }
   132  
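// The deadline setters below are no-ops: they ignore the given time and always
// return nil. They exist only to satisfy the net.Conn interface.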
   134  
   135  func (k *k8sPortFwdDialer) SetDeadline(_ time.Time) error {
   136  	return nil
   137  }
   138  
   139  func (k *k8sPortFwdDialer) SetReadDeadline(_ time.Time) error {
   140  	return nil
   141  }
   142  
   143  func (k *k8sPortFwdDialer) SetWriteDeadline(_ time.Time) error {
   144  	return nil
   145  }