github.com/argoproj/argo-cd/v3@v3.2.1/pkg/apiclient/grpcproxy.go (about)

     1  package apiclient
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"encoding/binary"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"net"
    11  	"net/http"
    12  	"os"
    13  	"strconv"
    14  	"strings"
    15  
    16  	"google.golang.org/grpc"
    17  	"google.golang.org/grpc/codes"
    18  	"google.golang.org/grpc/keepalive"
    19  	"google.golang.org/grpc/metadata"
    20  	"google.golang.org/grpc/status"
    21  
    22  	"github.com/argoproj/argo-cd/v3/common"
    23  	argocderrors "github.com/argoproj/argo-cd/v3/util/errors"
    24  	utilio "github.com/argoproj/argo-cd/v3/util/io"
    25  	"github.com/argoproj/argo-cd/v3/util/rand"
    26  )
    27  
    28  const (
    29  	frameHeaderLength = 5
    30  	endOfStreamFlag   = 128
    31  )
    32  
    33  type noopCodec struct{}
    34  
    35  func (noopCodec) Marshal(v any) ([]byte, error) {
    36  	return v.([]byte), nil
    37  }
    38  
    39  func (noopCodec) Unmarshal(data []byte, v any) error {
    40  	pointer := v.(*[]byte)
    41  	*pointer = data
    42  	return nil
    43  }
    44  
    45  func (noopCodec) Name() string {
    46  	return "proto"
    47  }
    48  
    49  func toFrame(msg []byte) []byte {
    50  	frame := append([]byte{0, 0, 0, 0}, msg...)
    51  	binary.BigEndian.PutUint32(frame, uint32(len(msg)))
    52  	frame = append([]byte{0}, frame...)
    53  	return frame
    54  }
    55  
    56  func (c *client) executeRequest(ctx context.Context, fullMethodName string, msg []byte, md metadata.MD) (*http.Response, error) {
    57  	schema := "https"
    58  	if c.PlainText {
    59  		schema = "http"
    60  	}
    61  	rootPath := strings.TrimRight(strings.TrimLeft(c.GRPCWebRootPath, "/"), "/")
    62  
    63  	var requestURL string
    64  	if rootPath != "" {
    65  		requestURL = fmt.Sprintf("%s://%s/%s%s", schema, c.ServerAddr, rootPath, fullMethodName)
    66  	} else {
    67  		requestURL = fmt.Sprintf("%s://%s%s", schema, c.ServerAddr, fullMethodName)
    68  	}
    69  	// Use context in the HTTP request
    70  	req, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL, bytes.NewReader(toFrame(msg)))
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  	for k, v := range md {
    75  		if strings.HasPrefix(k, ":") {
    76  			continue
    77  		}
    78  		for i := range v {
    79  			req.Header.Set(k, v[i])
    80  		}
    81  	}
    82  	req.Header.Set("content-type", "application/grpc-web+proto")
    83  
    84  	resp, err := c.httpClient.Do(req)
    85  	if err != nil {
    86  		return nil, err
    87  	}
    88  	if resp.StatusCode != http.StatusOK {
    89  		return nil, fmt.Errorf("%s %s failed with status code %d", req.Method, req.URL, resp.StatusCode)
    90  	}
    91  	var code codes.Code
    92  	if statusStr := resp.Header.Get("Grpc-Status"); statusStr != "" {
    93  		statusInt, err := strconv.ParseUint(statusStr, 10, 32)
    94  		if err != nil {
    95  			code = codes.Unknown
    96  		} else {
    97  			code = codes.Code(statusInt)
    98  		}
    99  		if code != codes.OK {
   100  			return nil, status.Error(code, resp.Header.Get("Grpc-Message"))
   101  		}
   102  	}
   103  	return resp, nil
   104  }
   105  
   106  func (c *client) startGRPCProxy() (*grpc.Server, net.Listener, error) {
   107  	randSuffix, err := rand.String(16)
   108  	if err != nil {
   109  		return nil, nil, fmt.Errorf("failed to generate random socket filename: %w", err)
   110  	}
   111  	serverAddr := fmt.Sprintf("%s/argocd-%s.sock", os.TempDir(), randSuffix)
   112  	ln, err := net.Listen("unix", serverAddr)
   113  	if err != nil {
   114  		return nil, nil, err
   115  	}
   116  	proxySrv := grpc.NewServer(
   117  		grpc.ForceServerCodec(&noopCodec{}),
   118  		grpc.KeepaliveEnforcementPolicy(
   119  			keepalive.EnforcementPolicy{
   120  				MinTime: common.GetGRPCKeepAliveEnforcementMinimum(),
   121  			},
   122  		),
   123  		grpc.UnknownServiceHandler(func(_ any, stream grpc.ServerStream) error {
   124  			fullMethodName, ok := grpc.MethodFromServerStream(stream)
   125  			if !ok {
   126  				return errors.New("unable to get method name from stream context")
   127  			}
   128  			msg := make([]byte, 0)
   129  			err := stream.RecvMsg(&msg)
   130  			if err != nil {
   131  				return err
   132  			}
   133  
   134  			md, _ := metadata.FromIncomingContext(stream.Context())
   135  			headersMD, err := parseGRPCHeaders(c.Headers)
   136  			if err != nil {
   137  				return err
   138  			}
   139  
   140  			md = metadata.Join(md, headersMD)
   141  
   142  			resp, err := c.executeRequest(stream.Context(), fullMethodName, msg, md)
   143  			if err != nil {
   144  				return err
   145  			}
   146  
   147  			go func() {
   148  				<-stream.Context().Done()
   149  				utilio.Close(resp.Body)
   150  			}()
   151  			defer utilio.Close(resp.Body)
   152  			c.httpClient.CloseIdleConnections()
   153  
   154  			for {
   155  				header := make([]byte, frameHeaderLength)
   156  				if _, err := io.ReadAtLeast(resp.Body, header, frameHeaderLength); err != nil {
   157  					if errors.Is(err, io.EOF) {
   158  						err = io.ErrUnexpectedEOF
   159  					}
   160  					return err
   161  				}
   162  
   163  				if header[0] == endOfStreamFlag {
   164  					return nil
   165  				}
   166  				length := int(binary.BigEndian.Uint32(header[1:frameHeaderLength]))
   167  				data := make([]byte, length)
   168  
   169  				if read, err := io.ReadAtLeast(resp.Body, data, length); err != nil {
   170  					if !errors.Is(err, io.EOF) {
   171  						return err
   172  					} else if read < length {
   173  						return io.ErrUnexpectedEOF
   174  					}
   175  					return nil
   176  				}
   177  
   178  				if err := stream.SendMsg(data); err != nil {
   179  					return err
   180  				}
   181  			}
   182  		}))
   183  	go func() {
   184  		err := proxySrv.Serve(ln)
   185  		argocderrors.CheckError(err)
   186  	}()
   187  	return proxySrv, ln, nil
   188  }
   189  
   190  // useGRPCProxy ensures that grpc proxy server is started and return closer which stops server when no one uses it
   191  func (c *client) useGRPCProxy() (net.Addr, io.Closer, error) {
   192  	c.proxyMutex.Lock()
   193  	defer c.proxyMutex.Unlock()
   194  
   195  	if c.proxyListener == nil {
   196  		var err error
   197  		c.proxyServer, c.proxyListener, err = c.startGRPCProxy()
   198  		if err != nil {
   199  			return nil, nil, err
   200  		}
   201  	}
   202  	c.proxyUsersCount = c.proxyUsersCount + 1
   203  
   204  	return c.proxyListener.Addr(), utilio.NewCloser(func() error {
   205  		c.proxyMutex.Lock()
   206  		defer c.proxyMutex.Unlock()
   207  		c.proxyUsersCount = c.proxyUsersCount - 1
   208  		if c.proxyUsersCount == 0 {
   209  			c.proxyServer.Stop()
   210  			c.proxyListener = nil
   211  			c.proxyServer = nil
   212  			return nil
   213  		}
   214  		return nil
   215  	}), nil
   216  }
   217  
   218  func parseGRPCHeaders(headerStrings []string) (metadata.MD, error) {
   219  	md := metadata.New(map[string]string{})
   220  	for _, kv := range headerStrings {
   221  		i := strings.IndexByte(kv, ':')
   222  		// zero means meaningless empty header name
   223  		if i <= 0 {
   224  			return nil, fmt.Errorf("additional headers must be colon(:)-separated: %s", kv)
   225  		}
   226  		md.Append(kv[0:i], kv[i+1:])
   227  	}
   228  	return md, nil
   229  }