istio.io/istio@v0.0.0-20240520182934-d79c90f27776/pkg/hbone/dialer.go (about)

     1  // Copyright Istio Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package hbone
    16  
    17  import (
    18  	"context"
    19  	"crypto/tls"
    20  	"fmt"
    21  	"io"
    22  	"net"
    23  	"net/http"
    24  	"strings"
    25  	"sync"
    26  	"time"
    27  
    28  	"golang.org/x/net/http2"
    29  	"golang.org/x/net/proxy"
    30  
    31  	istiolog "istio.io/istio/pkg/log"
    32  	"istio.io/istio/security/pkg/pki/util"
    33  )
    34  
    35  var log = istiolog.RegisterScope("hbone", "")
    36  
// Config defines the configuration for a given dialer. All fields other than ProxyAddress are optional.
type Config struct {
	// ProxyAddress defines the address of the HBONE proxy we are connecting to.
	// The dialer built by NewDialer uses it as the host of the CONNECT request
	// URL, with an "https://" scheme when TLS is set and "http://" otherwise.
	ProxyAddress string

	// Headers holds optional HTTP headers associated with the CONNECT request.
	// NOTE(review): confirm the dialer copies these onto the outgoing request.
	Headers http.Header

	// TLS, when non-nil, is the client TLS configuration used for the HTTP/2
	// connection to the proxy. When nil, the proxy is reached over cleartext
	// HTTP/2 (h2c).
	TLS *tls.Config

	// Timeout, when non-nil, is the dial timeout for the cleartext (h2c)
	// connection to the proxy.
	// NOTE(review): as written it is not consulted when TLS is set — confirm intent.
	Timeout *time.Duration
}
    45  
    46  type Dialer interface {
    47  	proxy.Dialer
    48  	proxy.ContextDialer
    49  }
    50  
    51  // NewDialer creates a Dialer that proxies connections over HBONE to the configured proxy.
    52  func NewDialer(cfg Config) Dialer {
    53  	var transport *http2.Transport
    54  
    55  	if cfg.TLS != nil {
    56  		transport = &http2.Transport{
    57  			TLSClientConfig: cfg.TLS,
    58  		}
    59  	} else {
    60  		transport = &http2.Transport{
    61  			// For h2c
    62  			AllowHTTP: true,
    63  			DialTLSContext: func(ctx context.Context, network, addr string, tlsCfg *tls.Config) (net.Conn, error) {
    64  				d := net.Dialer{}
    65  				if cfg.Timeout != nil {
    66  					d.Timeout = *cfg.Timeout
    67  				}
    68  				return d.Dial(network, addr)
    69  			},
    70  		}
    71  	}
    72  	return &dialer{
    73  		cfg:       cfg,
    74  		transport: transport,
    75  	}
    76  }
    77  
// dialer implements Dialer by tunneling "tcp" connections through an HBONE
// proxy using HTTP/2 CONNECT requests.
type dialer struct {
	cfg       Config           // configuration the dialer was created with
	transport *http2.Transport // HTTP/2 transport that carries CONNECT requests to the proxy
}
    82  
    83  // DialContext connects to `address` via the HBONE proxy.
    84  func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
    85  	if network != "tcp" {
    86  		return net.Dial(network, address)
    87  	}
    88  	// TODO: use context
    89  	c, s := net.Pipe()
    90  	err := d.proxyTo(s, d.cfg, address)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  	return c, nil
    95  }
    96  
    97  func (d dialer) Dial(network, address string) (c net.Conn, err error) {
    98  	return d.DialContext(context.Background(), network, address)
    99  }
   100  
   101  func (d *dialer) proxyTo(conn io.ReadWriteCloser, req Config, address string) error {
   102  	t0 := time.Now()
   103  
   104  	url := "http://" + req.ProxyAddress
   105  	if req.TLS != nil {
   106  		url = "https://" + req.ProxyAddress
   107  	}
   108  	// Setup a pipe. We could just pass `conn` to `http.NewRequest`, but this has a few issues:
   109  	// * Less visibility into i/o
   110  	// * http will call conn.Close, which will close before we want to (finished writing response).
   111  	pr, pw := io.Pipe()
   112  	r, err := http.NewRequest(http.MethodConnect, url, pr)
   113  	if err != nil {
   114  		return fmt.Errorf("new request: %v", err)
   115  	}
   116  	r.Host = address
   117  
   118  	// Initiate CONNECT.
   119  	log.Infof("initiate CONNECT to %v via %v", r.Host, url)
   120  
   121  	resp, err := d.transport.RoundTrip(r)
   122  	if err != nil {
   123  		return fmt.Errorf("round trip: %v", err)
   124  	}
   125  	var remoteID string
   126  	if resp.TLS != nil && len(resp.TLS.PeerCertificates) > 0 {
   127  		ids, _ := util.ExtractIDs(resp.TLS.PeerCertificates[0].Extensions)
   128  		if len(ids) > 0 {
   129  			remoteID = ids[0]
   130  		}
   131  	}
   132  	if resp.StatusCode != http.StatusOK {
   133  		return fmt.Errorf("round trip failed: %v", resp.Status)
   134  	}
   135  	log.WithLabels("host", r.Host, "remote", remoteID).Info("CONNECT established")
   136  	go func() {
   137  		defer conn.Close()
   138  		defer resp.Body.Close()
   139  
   140  		wg := sync.WaitGroup{}
   141  		wg.Add(1)
   142  		go func() {
   143  			// handle upstream (hbone server) --> downstream (app)
   144  			copyBuffered(conn, resp.Body, log.WithLabels("name", "body to conn"))
   145  			wg.Done()
   146  		}()
   147  		// Copy from conn into the pipe, which will then be sent as part of the request
   148  		// handle upstream (hbone server) <-- downstream (app)
   149  		copyBuffered(pw, conn, log.WithLabels("name", "conn to pipe"))
   150  
   151  		wg.Wait()
   152  		log.Infof("stream closed in %v", time.Since(t0))
   153  	}()
   154  
   155  	return nil
   156  }
   157  
   158  // TLSDialWithDialer is an implementation of tls.DialWithDialer that accepts a generic Dialer
   159  func TLSDialWithDialer(dialer Dialer, network, addr string, config *tls.Config) (*tls.Conn, error) {
   160  	return tlsDial(context.Background(), dialer, network, addr, config)
   161  }
   162  
   163  func tlsDial(ctx context.Context, netDialer Dialer, network, addr string, config *tls.Config) (*tls.Conn, error) {
   164  	rawConn, err := netDialer.DialContext(ctx, network, addr)
   165  	if err != nil {
   166  		return nil, err
   167  	}
   168  
   169  	colonPos := strings.LastIndex(addr, ":")
   170  	if colonPos == -1 {
   171  		colonPos = len(addr)
   172  	}
   173  	hostname := addr[:colonPos]
   174  
   175  	if config == nil {
   176  		config = &tls.Config{MinVersion: tls.VersionTLS12}
   177  	}
   178  	// If no ServerName is set, infer the ServerName
   179  	// from the hostname we're connecting to.
   180  	if config.ServerName == "" {
   181  		// Make a copy to avoid polluting argument or default.
   182  		c := config.Clone()
   183  		c.ServerName = hostname
   184  		config = c
   185  	}
   186  
   187  	conn := tls.Client(rawConn, config)
   188  	if err := conn.HandshakeContext(ctx); err != nil {
   189  		_ = rawConn.Close()
   190  		return nil, err
   191  	}
   192  	return conn, nil
   193  }