
     1  // Copyright Istio Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    15  package nodeagent
    17  import (
    18  	"context"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  	"net"
    23  	"os"
    24  	"sync"
    25  	"time"
    27  	""
    28  	""
    29  	v1 ""
    31  	""
    32  	""
    33  	""
    34  )
var (
	// ztunnelKeepAliveCheckInterval is how long handleConn waits for an update
	// before probing the connection with a short read.
	ztunnelKeepAliveCheckInterval = 5 * time.Second
	// readWriteDeadline bounds each write to ztunnel as well as each wait for
	// the hello message or an ack.
	readWriteDeadline             = 5 * time.Second
)
// ztunnelConnected is a gauge holding the number of ztunnel connections.
// connMgr records the size of its connection set into it whenever a
// connection is added or removed.
var ztunnelConnected = monitoring.NewGauge("ztunnel_connected",
	"number of connections to ztunnel")
// ZtunnelServer is the node agent's view of the local ztunnel: it serves
// ztunnel connections and relays pod additions and removals to them.
// The descriptions below reflect the ztunnelServer implementation in this file.
type ZtunnelServer interface {
	// Run accepts and serves ztunnel connections until the listener is
	// closed; cancelling ctx closes the listener.
	Run(ctx context.Context)
	// PodDeleted sends a delete request for the workload with the given pod
	// uid to the connected ztunnels.
	PodDeleted(ctx context.Context, uid string) error
	// PodAdded sends the pod, together with the file descriptor of its
	// network namespace, to the most recently connected ztunnel.
	PodAdded(ctx context.Context, pod *v1.Pod, netns Netns) error
	// Close closes the server's listener.
	Close() error
}
    51  /*
    52  To clean up stale ztunnels
	we may need ztunnel to send its (uid, bootid / boot time) to us
    55  	so that we can remove stale entries when the ztunnel pod is deleted
    56  	or when the ztunnel pod is restarted in the same pod (remove old entries when the same uid connects again, but with different boot id?)
    58  	save a queue of what needs to be sent to the ztunnel pod and send it one by one when it connects.
    60  	when a new ztunnel connects with different uid, only propagate deletes to older ztunnels.
    61  */
// connMgr tracks the set of currently registered ztunnel connections and
// remembers which one was added most recently.
type connMgr struct {
	// connectionSet holds every connection registered through addConn and not
	// yet removed by deleteConn.
	connectionSet map[*ZtunnelConnection]struct{}
	// latestConn is the connection most recently passed to addConn; it is
	// reset to nil when that same connection is removed.
	latestConn    *ZtunnelConnection
	// mu presumably guards connectionSet and latestConn.
	// NOTE(review): confirm that every accessor actually takes it.
	mu            sync.Mutex
}
    69  func (c *connMgr) addConn(conn *ZtunnelConnection) {
    70  	log.Debug("ztunnel connected")
    72  	defer
    73  	c.connectionSet[conn] = struct{}{}
    74  	c.latestConn = conn
    75  	ztunnelConnected.RecordInt(int64(len(c.connectionSet)))
    76  }
    78  func (c *connMgr) LatestConn() *ZtunnelConnection {
    80  	defer
    81  	return c.latestConn
    82  }
    84  func (c *connMgr) deleteConn(conn *ZtunnelConnection) {
    85  	log.Debug("ztunnel disconnected")
    87  	defer
    88  	delete(c.connectionSet, conn)
    89  	if c.latestConn == conn {
    90  		c.latestConn = nil
    91  	}
    92  	ztunnelConnected.RecordInt(int64(len(c.connectionSet)))
    93  }
    95  // this is used in tests
    96  // nolint: unused
    97  func (c *connMgr) len() int {
    99  	defer
   100  	return len(c.connectionSet)
   101  }
// ztunnelServer is the ZtunnelServer implementation that listens on a unix
// socket for ztunnel connections.
type ztunnelServer struct {
	// listener is the unix socket on which ztunnel connections are accepted.
	listener *net.UnixListener

	// conns tracks the active ztunnel connections:
	// an added pod goes to the newest connection,
	// a deleted pod goes to all connections.
	conns *connMgr
	// pods supplies the snapshot of current pods that is replayed to every
	// newly connected ztunnel.
	pods  PodNetnsCache
}
// Compile-time check that *ztunnelServer implements ZtunnelServer.
var _ ZtunnelServer = &ztunnelServer{}
   115  func newZtunnelServer(addr string, pods PodNetnsCache) (*ztunnelServer, error) {
   116  	if addr == "" {
   117  		return nil, fmt.Errorf("addr cannot be empty")
   118  	}
   120  	resolvedAddr, err := net.ResolveUnixAddr("unixpacket", addr)
   121  	if err != nil {
   122  		return nil, fmt.Errorf("failed to resolve unix addr: %w", err)
   123  	}
   124  	// remove potentially existing address
   125  	// Remove unix socket before use, if one is leftover from previous CNI restart
   126  	if err := os.Remove(addr); err != nil && !os.IsNotExist(err) {
   127  		// Anything other than "file not found" is an error.
   128  		return nil, fmt.Errorf("failed to remove unix://%s: %w", addr, err)
   129  	}
   131  	l, err := net.ListenUnix("unixpacket", resolvedAddr)
   132  	if err != nil {
   133  		return nil, fmt.Errorf("failed to listen unix: %w", err)
   134  	}
   136  	return &ztunnelServer{
   137  		listener: l,
   138  		conns: &connMgr{
   139  			connectionSet: map[*ZtunnelConnection]struct{}{},
   140  		},
   141  		pods: pods,
   142  	}, nil
   143  }
// Close closes the listening socket and returns the listener's close error.
// Run treats the resulting net.ErrClosed from accept as its signal to return.
func (z *ztunnelServer) Close() error {
	return z.listener.Close()
}
   149  func (z *ztunnelServer) Run(ctx context.Context) {
   150  	context.AfterFunc(ctx, func() { _ = z.Close() })
   152  	for {
   153  		log.Debug("accepting conn")
   154  		conn, err := z.accept()
   155  		if err != nil {
   156  			if errors.Is(err, net.ErrClosed) {
   157  				log.Debug("listener closed - returning")
   158  				return
   159  			}
   161  			log.Errorf("failed to accept conn: %v", err)
   162  			continue
   163  		}
   164  		log.Debug("connection accepted")
   165  		go func() {
   166  			log.Debug("handling conn")
   167  			if err := z.handleConn(ctx, conn); err != nil {
   168  				log.Errorf("failed to handle conn: %v", err)
   169  			}
   170  		}()
   171  	}
   172  }
   174  // ZDS protocol is very simple, for every message sent, and ack is sent.
   175  // the ack only has temporal correlation (i.e. it is the first and only ack msg after the message was sent)
   176  // All this to say, that we want to make sure that message to ztunnel are sent from a single goroutine
   177  // so we don't mix messages and acks.
   178  // nolint: unparam
   179  func (z *ztunnelServer) handleConn(ctx context.Context, conn *ZtunnelConnection) error {
   180  	defer conn.Close()
   182  	context.AfterFunc(ctx, func() {
   183  		log.Debug("context cancelled - closing conn")
   184  		conn.Close()
   185  	})
   187  	// before doing anything, add the connection to the list of active connections
   188  	z.conns.addConn(conn)
   189  	defer z.conns.deleteConn(conn)
   191  	// get hello message from ztunnel
   192  	m, _, err := readProto[zdsapi.ZdsHello](conn.u, readWriteDeadline, nil)
   193  	if err != nil {
   194  		return err
   195  	}
   196  	log.Infof("received hello from ztunnel. %v", m.Version)
   197  	log.Debug("sending snapshot to ztunnel")
   198  	if err := z.sendSnapshot(ctx, conn); err != nil {
   199  		return err
   200  	}
   201  	for {
   202  		// listen for updates:
   203  		select {
   204  		case update, ok := <-conn.Updates:
   205  			if !ok {
   206  				log.Debug("update channel closed - returning")
   207  				return nil
   208  			}
   209  			log.Debugf("got update to send to ztunnel")
   210  			resp, err := conn.sendDataAndWaitForAck(update.Update, update.Fd)
   211  			if err != nil {
   212  				log.Errorf("ztunnel acked error: err %v ackErr %s", err, resp.GetAck().GetError())
   213  			}
   214  			log.Debugf("ztunnel acked")
   215  			// Safety: Resp is buffered, so this will not block
   216  			update.Resp <- updateResponse{
   217  				err:  err,
   218  				resp: resp,
   219  			}
   221  		case <-time.After(ztunnelKeepAliveCheckInterval):
   222  			// do a short read, just to see if the connection to ztunnel is
   223  			// still alive. As ztunnel shouldn't send anything unless we send
   224  			// something first, we expect to get an os.ErrDeadlineExceeded error
   225  			// here if the connection is still alive.
   226  			// note that unlike tcp connections, reading is a good enough test here.
   227  			_, err := conn.readMessage(time.Second / 100)
   228  			switch {
   229  			case !errors.Is(err, os.ErrDeadlineExceeded):
   230  				log.Debugf("ztunnel keepalive failed: %v", err)
   231  				if errors.Is(err, io.EOF) {
   232  					log.Debug("ztunnel EOF")
   233  					return nil
   234  				}
   235  				return err
   236  			case err == nil:
   237  				log.Warn("ztunnel protocol error, unexpected message")
   238  				return fmt.Errorf("ztunnel protocol error, unexpected message")
   239  			default:
   240  				// we get here if error is deadline exceeded, which means ztunnel is alive.
   241  			}
   243  		case <-ctx.Done():
   244  			return nil
   245  		}
   246  	}
   247  }
   249  func (z *ztunnelServer) PodDeleted(ctx context.Context, uid string) error {
   250  	r := &zdsapi.WorkloadRequest{
   251  		Payload: &zdsapi.WorkloadRequest_Del{
   252  			Del: &zdsapi.DelWorkload{
   253  				Uid: uid,
   254  			},
   255  		},
   256  	}
   257  	data, err := proto.Marshal(r)
   258  	if err != nil {
   259  		return err
   260  	}
   262  	log.Debugf("sending delete pod to ztunnel: %s %v", uid, r)
   264  	var delErr []error
   267  	defer
   268  	for conn := range z.conns.connectionSet {
   269  		_, err := conn.send(ctx, data, nil)
   270  		if err != nil {
   271  			delErr = append(delErr, err)
   272  		}
   273  	}
   274  	return errors.Join(delErr...)
   275  }
   277  func podToWorkload(pod *v1.Pod) *zdsapi.WorkloadInfo {
   278  	namespace := pod.ObjectMeta.Namespace
   279  	name := pod.ObjectMeta.Name
   280  	svcAccount := pod.Spec.ServiceAccountName
   281  	trustDomain := spiffe.GetTrustDomain()
   282  	return &zdsapi.WorkloadInfo{
   283  		Namespace:      namespace,
   284  		Name:           name,
   285  		ServiceAccount: svcAccount,
   286  		TrustDomain:    trustDomain,
   287  	}
   288  }
   290  func (z *ztunnelServer) PodAdded(ctx context.Context, pod *v1.Pod, netns Netns) error {
   291  	latestConn := z.conns.LatestConn()
   292  	if latestConn == nil {
   293  		return fmt.Errorf("no ztunnel connection")
   294  	}
   295  	uid := string(pod.ObjectMeta.UID)
   297  	r := &zdsapi.WorkloadRequest{
   298  		Payload: &zdsapi.WorkloadRequest_Add{
   299  			Add: &zdsapi.AddWorkload{
   300  				WorkloadInfo: podToWorkload(pod),
   301  				Uid:          uid,
   302  			},
   303  		},
   304  	}
   305  	log.Infof("About to send added pod: %s to ztunnel: %+v", uid, r)
   306  	data, err := proto.Marshal(r)
   307  	if err != nil {
   308  		return err
   309  	}
   311  	fd := int(netns.Fd())
   312  	resp, err := latestConn.send(ctx, data, &fd)
   313  	if err != nil {
   314  		return err
   315  	}
   317  	if resp.GetAck().GetError() != "" {
   318  		log.Errorf("add-workload: got ack error: %s", resp.GetAck().GetError())
   319  		return fmt.Errorf("got ack error: %s", resp.GetAck().GetError())
   320  	}
   321  	return nil
   322  }
   324  // TODO ctx is unused here
   325  // nolint: unparam
   326  func (z *ztunnelServer) sendSnapshot(ctx context.Context, conn *ZtunnelConnection) error {
   327  	snap := z.pods.ReadCurrentPodSnapshot()
   328  	for uid, wl := range snap {
   329  		var resp *zdsapi.WorkloadResponse
   330  		var err error
   331  		if wl.Netns != nil {
   332  			fd := int(wl.Netns.Fd())
   333  			log.Infof("Sending local pod %s ztunnel", uid)
   334  			resp, err = conn.sendMsgAndWaitForAck(&zdsapi.WorkloadRequest{
   335  				Payload: &zdsapi.WorkloadRequest_Add{
   336  					Add: &zdsapi.AddWorkload{
   337  						Uid:          uid,
   338  						WorkloadInfo: wl.Workload,
   339  					},
   340  				},
   341  			}, &fd)
   342  		} else {
   343  			log.Infof("netns not available for local pod %s. sending keep to ztunnel", uid)
   344  			resp, err = conn.sendMsgAndWaitForAck(&zdsapi.WorkloadRequest{
   345  				Payload: &zdsapi.WorkloadRequest_Keep{
   346  					Keep: &zdsapi.KeepWorkload{
   347  						Uid: uid,
   348  					},
   349  				},
   350  			}, nil)
   351  		}
   352  		if err != nil {
   353  			return err
   354  		}
   355  		if resp.GetAck().GetError() != "" {
   356  			log.Errorf("add-workload: got ack error: %s", resp.GetAck().GetError())
   357  		}
   358  	}
   359  	resp, err := conn.sendMsgAndWaitForAck(&zdsapi.WorkloadRequest{
   360  		Payload: &zdsapi.WorkloadRequest_SnapshotSent{
   361  			SnapshotSent: &zdsapi.SnapshotSent{},
   362  		},
   363  	}, nil)
   364  	if err != nil {
   365  		return err
   366  	}
   367  	log.Debugf("snaptshot sent to ztunnel")
   368  	if resp.GetAck().GetError() != "" {
   369  		log.Errorf("snap-sent: got ack error: %s", resp.GetAck().GetError())
   370  	}
   372  	return nil
   373  }
   375  func (z *ztunnelServer) accept() (*ZtunnelConnection, error) {
   376  	log.Debug("accepting unix conn")
   377  	conn, err := z.listener.AcceptUnix()
   378  	if err != nil {
   379  		return nil, fmt.Errorf("failed to accept unix: %w", err)
   380  	}
   381  	log.Debug("accepted conn")
   382  	return newZtunnelConnection(conn), nil
   383  }
// updateResponse is the outcome of one updateRequest, delivered on the
// request's Resp channel by the connection handler.
type updateResponse struct {
	// err is the error returned by sending the update and waiting for its ack.
	err  error
	// resp is the response read from ztunnel for the update.
	resp *zdsapi.WorkloadResponse
}
// updateRequest is one message queued on a ZtunnelConnection's Updates
// channel for the connection handler to send.
type updateRequest struct {
	// Update is the serialized message to write to ztunnel.
	Update []byte
	// Fd, when non-nil, is a file descriptor sent alongside the message as
	// unix rights.
	Fd     *int

	// Resp receives the outcome of the send. send creates it with a buffer of
	// one so the handler's reply does not block.
	Resp chan updateResponse
}
// ZtunnelConnection is one accepted connection from a ztunnel, together with
// the queue of updates waiting to be sent over it.
type ZtunnelConnection struct {
	// u is the underlying unix socket connection.
	u       *net.UnixConn
	// Updates queues requests submitted through send; handleConn drains it
	// and performs the actual writes.
	Updates chan updateRequest
}
   402  func newZtunnelConnection(u *net.UnixConn) *ZtunnelConnection {
   403  	return &ZtunnelConnection{u: u, Updates: make(chan updateRequest, 100)}
   404  }
   406  func (z *ZtunnelConnection) Close() {
   407  	z.u.Close()
   408  }
   410  func (z *ZtunnelConnection) send(ctx context.Context, data []byte, fd *int) (*zdsapi.WorkloadResponse, error) {
   411  	ret := make(chan updateResponse, 1)
   412  	req := updateRequest{
   413  		Update: data,
   414  		Fd:     fd,
   415  		Resp:   ret,
   416  	}
   417  	select {
   418  	case z.Updates <- req:
   419  	case <-ctx.Done():
   420  		return nil, ctx.Err()
   421  	}
   423  	select {
   424  	case r := <-ret:
   425  		return r.resp, r.err
   426  	case <-ctx.Done():
   427  		return nil, ctx.Err()
   428  	}
   429  }
   431  func (z *ZtunnelConnection) sendMsgAndWaitForAck(msg *zdsapi.WorkloadRequest, fd *int) (*zdsapi.WorkloadResponse, error) {
   432  	data, err := proto.Marshal(msg)
   433  	if err != nil {
   434  		return nil, err
   435  	}
   436  	return z.sendDataAndWaitForAck(data, fd)
   437  }
   439  func (z *ZtunnelConnection) sendDataAndWaitForAck(data []byte, fd *int) (*zdsapi.WorkloadResponse, error) {
   440  	var rights []byte
   441  	if fd != nil {
   442  		rights = unix.UnixRights(*fd)
   443  	}
   444  	err := z.u.SetWriteDeadline(time.Now().Add(readWriteDeadline))
   445  	if err != nil {
   446  		return nil, err
   447  	}
   449  	_, _, err = z.u.WriteMsgUnix(data, rights, nil)
   450  	if err != nil {
   451  		return nil, err
   452  	}
   454  	// wait for ack
   455  	return z.readMessage(readWriteDeadline)
   456  }
   458  func (z *ZtunnelConnection) readMessage(timeout time.Duration) (*zdsapi.WorkloadResponse, error) {
   459  	m, _, err := readProto[zdsapi.WorkloadResponse](z.u, timeout, nil)
   460  	return m, err
   461  }
   463  func readProto[T any, PT interface {
   464  	proto.Message
   465  	*T
   466  }](c *net.UnixConn, timeout time.Duration, oob []byte) (PT, int, error) {
   467  	var buf [1024]byte
   468  	err := c.SetReadDeadline(time.Now().Add(timeout))
   469  	if err != nil {
   470  		return nil, 0, err
   471  	}
   472  	n, oobn, flags, _, err := c.ReadMsgUnix(buf[:], oob)
   473  	if err != nil {
   474  		return nil, 0, err
   475  	}
   476  	if flags&unix.MSG_TRUNC != 0 {
   477  		return nil, 0, fmt.Errorf("truncated message")
   478  	}
   479  	if flags&unix.MSG_CTRUNC != 0 {
   480  		return nil, 0, fmt.Errorf("truncated control message")
   481  	}
   482  	var resp T
   483  	var respPtr PT = &resp
   484  	err = proto.Unmarshal(buf[:n], respPtr)
   485  	if err != nil {
   486  		return nil, 0, err
   487  	}
   488  	return respPtr, oobn, nil
   489  }