github.com/smartcontractkit/chainlink-testing-framework/libs@v0.0.0-20240227141906-ec710b4eb1a3/k8s/client/forwarder.go (about)

     1  package client
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"net/http"
     7  	"net/url"
     8  	"strings"
     9  	"sync"
    10  
    11  	"github.com/rs/zerolog/log"
    12  	"golang.org/x/sync/errgroup"
    13  	v1 "k8s.io/api/core/v1"
    14  	"k8s.io/client-go/tools/portforward"
    15  	"k8s.io/client-go/transport/spdy"
    16  )
    17  
// Forwarder gathers connection details for the pods selected by Connect.
// It either opens local port-forwards or, when running inside the cluster,
// records pod IPs and container ports directly.
// Create it with NewForwarder: in the zero value mu is nil, so the methods
// that lock mu would panic.
type Forwarder struct {
	// Client is used to list pods and to build the port-forward dialer from its RESTConfig.
	Client         *K8sClient
	// mu guards writes to Info made by the per-pod goroutines started in Connect.
	mu             *sync.Mutex
	// KeepConnection is set by NewForwarder but is not read anywhere in this file.
	// NOTE(review): confirm whether code elsewhere depends on it.
	KeepConnection bool
	// Info maps "<AppLabel value>:<instance label value>" to a map[string]interface{}
	// keyed by container name. Each of those values is a map[string]interface{}
	// keyed by port name, holding a ConnectionInfo.
	Info           map[string]interface{}
}
    24  
// ConnectionInfo describes how to reach one named container port of a pod.
type ConnectionInfo struct {
	// Ports is the forwarded port pair; Local is the port on this machine.
	// When collected without forwarding (Connect with insideK8s), only
	// Remote, the container port, is set and Local stays zero.
	Ports portforward.ForwardedPort
	// Host is the pod IP taken from the pod status.
	Host  string
}
    29  
    30  func NewForwarder(client *K8sClient, keepConnection bool) *Forwarder {
    31  	return &Forwarder{
    32  		Client:         client,
    33  		mu:             &sync.Mutex{},
    34  		KeepConnection: keepConnection,
    35  		Info:           make(map[string]interface{}),
    36  	}
    37  }
    38  
    39  func (m *Forwarder) forwardPodPorts(pod v1.Pod, namespaceName string) error {
    40  	if pod.Status.Phase != v1.PodRunning {
    41  		log.Debug().Str("Pod", pod.Name).Interface("Phase", pod.Status.Phase).Msg("Skipping pod for port forwarding")
    42  		return nil
    43  	}
    44  	roundTripper, upgrader, err := spdy.RoundTripperFor(m.Client.RESTConfig)
    45  	if err != nil {
    46  		return err
    47  	}
    48  	httpPath := fmt.Sprintf("/api/v1/namespaces/%s/pods/%s/portforward", namespaceName, pod.Name)
    49  	hostIP := strings.TrimLeft(m.Client.RESTConfig.Host, "htps:/")
    50  	serverURL := url.URL{Scheme: "https", Path: httpPath, Host: hostIP}
    51  
    52  	dialer := spdy.NewDialer(upgrader, &http.Client{Transport: roundTripper}, http.MethodPost, &serverURL)
    53  
    54  	portRules := m.portRulesForPod(pod)
    55  	if len(portRules) == 0 {
    56  		return nil
    57  	}
    58  
    59  	// porforward is not thread safe for using multiple rules in the same forwarder,
    60  	// at least not until this pr is merged: https://github.com/kubernetes/kubernetes/pull/114342
    61  	forwardedPorts := []portforward.ForwardedPort{}
    62  	for _, portRule := range portRules {
    63  		stopChan, readyChan := make(chan struct{}, 1), make(chan struct{}, 1)
    64  		out, errOut := new(bytes.Buffer), new(bytes.Buffer)
    65  
    66  		log.Debug().
    67  			Str("Pod", pod.Name).
    68  			Msg("Attempting to forward ports")
    69  
    70  		forwarder, err := portforward.New(dialer, []string{portRule}, stopChan, readyChan, out, errOut)
    71  		if err != nil {
    72  			return err
    73  		}
    74  		go func() {
    75  			if err := forwarder.ForwardPorts(); err != nil {
    76  				log.Error().Str("Pod", pod.Name).Err(err)
    77  			}
    78  		}()
    79  
    80  		<-readyChan
    81  		if len(errOut.String()) > 0 {
    82  			return fmt.Errorf("error on forwarding k8s port: %v", errOut.String())
    83  		}
    84  		fP, err := forwarder.GetPorts()
    85  		if err != nil {
    86  			return err
    87  		}
    88  		forwardedPorts = append(forwardedPorts, fP...)
    89  	}
    90  	m.mu.Lock()
    91  	defer m.mu.Unlock()
    92  	namedPorts := m.podPortsByName(pod, forwardedPorts)
    93  	if pod.Labels[AppLabel] != "" {
    94  		m.Info[fmt.Sprintf("%s:%s", pod.Labels[AppLabel], pod.Labels["instance"])] = namedPorts
    95  	}
    96  	return nil
    97  }
    98  
    99  func (m *Forwarder) collectPodPorts(pod v1.Pod) error {
   100  	namedPorts := make(map[string]interface{})
   101  	for _, c := range pod.Spec.Containers {
   102  		for _, cp := range c.Ports {
   103  			if namedPorts[c.Name] == nil {
   104  				namedPorts[c.Name] = make(map[string]interface{})
   105  			}
   106  			namedPorts[c.Name].(map[string]interface{})[cp.Name] = ConnectionInfo{
   107  				Host:  pod.Status.PodIP,
   108  				Ports: portforward.ForwardedPort{Remote: uint16(cp.ContainerPort)},
   109  			}
   110  		}
   111  	}
   112  	m.mu.Lock()
   113  	defer m.mu.Unlock()
   114  	if pod.Labels[AppLabel] != "" {
   115  		m.Info[fmt.Sprintf("%s:%s", pod.Labels[AppLabel], pod.Labels["instance"])] = namedPorts
   116  	}
   117  	return nil
   118  }
   119  
   120  func (m *Forwarder) podPortsByName(pod v1.Pod, fp []portforward.ForwardedPort) map[string]interface{} {
   121  	ports := make(map[string]interface{})
   122  	for _, forwardedPort := range fp {
   123  		for _, c := range pod.Spec.Containers {
   124  			for _, cp := range c.Ports {
   125  				if cp.ContainerPort == int32(forwardedPort.Remote) {
   126  					if ports[c.Name] == nil {
   127  						ports[c.Name] = make(map[string]interface{})
   128  					}
   129  					ports[c.Name].(map[string]interface{})[cp.Name] = ConnectionInfo{
   130  						Host:  pod.Status.PodIP,
   131  						Ports: forwardedPort,
   132  					}
   133  				}
   134  			}
   135  		}
   136  	}
   137  	return ports
   138  }
   139  
   140  func (m *Forwarder) portRulesForPod(pod v1.Pod) []string {
   141  	rules := make([]string, 0)
   142  	for _, c := range pod.Spec.Containers {
   143  		for _, port := range c.Ports {
   144  			rules = append(rules, fmt.Sprintf(":%d", port.ContainerPort))
   145  		}
   146  	}
   147  	return rules
   148  }
   149  
   150  func (m *Forwarder) Connect(namespaceName string, selector string, insideK8s bool) error {
   151  	m.Info = make(map[string]interface{})
   152  	pods, err := m.Client.ListPods(namespaceName, selector)
   153  	if err != nil {
   154  		return err
   155  	}
   156  	eg := &errgroup.Group{}
   157  	for _, p := range pods.Items {
   158  		p := p
   159  		if insideK8s {
   160  			eg.Go(func() error {
   161  				return m.collectPodPorts(p)
   162  			})
   163  		} else {
   164  			eg.Go(func() error {
   165  				return m.forwardPodPorts(p, namespaceName)
   166  			})
   167  		}
   168  	}
   169  	return eg.Wait()
   170  }
   171  
   172  // PrintLocalPorts prints all local forwarded ports
   173  func (m *Forwarder) PrintLocalPorts() {
   174  	for labeledAppPodName, labeledAppPod := range m.Info {
   175  		for containerName, container := range labeledAppPod.(map[string]interface{}) {
   176  			for fpName, portsData := range container.(map[string]interface{}) {
   177  				log.Info().
   178  					Str("Label", labeledAppPodName).
   179  					Str("Container", containerName).
   180  					Str("PortNames", fpName).
   181  					Uint16("Port", portsData.(ConnectionInfo).Ports.Local).
   182  					Msg("Local ports")
   183  			}
   184  		}
   185  	}
   186  }
   187  
   188  func (m *Forwarder) FindPort(ks ...string) *URLConverter {
   189  	d, err := lookupMap(m.Info, ks...)
   190  	return NewURLConverter(d.(ConnectionInfo), err)
   191  }
   192  
   193  func lookupMap(m map[string]interface{}, ks ...string) (rval interface{}, err error) {
   194  	var ok bool
   195  	if len(ks) == 0 {
   196  		return nil, fmt.Errorf("select port path like $app_name:$instance $container_name $port_name")
   197  	}
   198  	if rval, ok = m[ks[0]]; !ok {
   199  		return ConnectionInfo{}, fmt.Errorf("key not found: '%s' remaining keys: %s, provided map: %s", ks[0], ks, m)
   200  	} else if len(ks) == 1 {
   201  		return rval, nil
   202  	}
   203  	return lookupMap(m[ks[0]].(map[string]interface{}), ks[1:]...)
   204  }