github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/utils/rpcwrapper/rpc_handle.go (about)

     1  package rpcwrapper
     2  
     3  import (
     4  	"context"
     5  	"crypto/hmac"
     6  	"crypto/sha256"
     7  	"encoding/binary"
     8  	"encoding/gob"
     9  	"fmt"
    10  	"net"
    11  	"net/http"
    12  	"net/rpc"
    13  	"os"
    14  	"strconv"
    15  	"sync"
    16  	"time"
    17  
    18  	"github.com/mitchellh/hashstructure"
    19  	"go.aporeto.io/enforcerd/trireme-lib/collector"
    20  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/secrets"
    21  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/usertokens/oidc"
    22  	"go.aporeto.io/enforcerd/trireme-lib/controller/pkg/usertokens/pkitokens"
    23  	"go.aporeto.io/enforcerd/trireme-lib/utils/cache"
    24  	"go.uber.org/zap"
    25  )
    26  
// RPCHdl is a per-client handle: it bundles the RPC client connected to one
// remote enforcer with the socket path it was dialed on and the shared secret
// used to sign requests sent over it.
type RPCHdl struct {
	// Client is the connected net/rpc client (created by rpc.DialHTTP in NewRPCClient).
	Client  *rpc.Client
	// Channel is the unix socket path the client was dialed on.
	Channel string
	// Secret is the shared HMAC key RemoteCall uses to sign request payloads.
	Secret  string
}
    33  
// RPCWrapper is a struct which holds state for all RPC sessions. Client-side
// code uses the per-context handle cache (built by NewRPCWrapper); server-side
// code uses the rpc.Server (built by NewRPCServer).
type RPCWrapper struct {
	// rpcClientMap maps a context ID to its *RPCHdl; populated by NewRPCClient.
	// NOTE(review): NewRPCServer leaves it nil, so the client-side methods are
	// presumably only meant for wrappers built with NewRPCWrapper.
	rpcClientMap *cache.Cache
	// server is the rpc.Server handlers are registered on by StartServer; in
	// this file it is only set by NewRPCServer.
	server       *rpc.Server
	// Mutex serializes access to rpcClientMap in NewRPCClient, GetRPCClient and
	// DestroyRPCClient (ContextList reads the cache without taking it).
	sync.Mutex
}
    40  
    41  // NewRPCWrapper creates a new rpcwrapper
    42  func NewRPCWrapper() *RPCWrapper {
    43  
    44  	RegisterTypes()
    45  
    46  	return &RPCWrapper{
    47  		rpcClientMap: cache.NewCache("RPCWrapper"),
    48  	}
    49  }
    50  
    51  const (
    52  	maxRetries     = 10000
    53  	envRetryString = "REMOTE_RPCRETRIES"
    54  )
    55  
    56  // NewRPCClient exported
    57  func (r *RPCWrapper) NewRPCClient(contextID string, channel string, sharedsecret string) error {
    58  
    59  	r.Lock()
    60  	defer r.Unlock()
    61  
    62  	max := maxRetries
    63  	retries := os.Getenv(envRetryString)
    64  	if len(retries) > 0 {
    65  		max, _ = strconv.Atoi(retries)
    66  	}
    67  
    68  	numRetries := 0
    69  	client, err := rpc.DialHTTP("unix", channel)
    70  	for err != nil {
    71  		numRetries++
    72  		if numRetries >= max {
    73  			return err
    74  		}
    75  
    76  		time.Sleep(5 * time.Millisecond)
    77  		client, err = rpc.DialHTTP("unix", channel)
    78  	}
    79  
    80  	r.rpcClientMap.AddOrUpdate(contextID, &RPCHdl{Client: client, Channel: channel, Secret: sharedsecret})
    81  
    82  	return nil
    83  
    84  }
    85  
    86  // GetRPCClient gets a handle to the rpc client for the contextID( enforcer in the container)
    87  func (r *RPCWrapper) GetRPCClient(contextID string) (*RPCHdl, error) {
    88  
    89  	r.Lock()
    90  	defer r.Unlock()
    91  
    92  	val, err := r.rpcClientMap.Get(contextID)
    93  	if err != nil {
    94  		return nil, err
    95  	}
    96  
    97  	return val.(*RPCHdl), nil
    98  }
    99  
   100  // RemoteCall is a wrapper around rpc.Call and also ensure message integrity by adding a hmac
   101  func (r *RPCWrapper) RemoteCall(contextID string, methodName string, req *Request, resp *Response) error {
   102  
   103  	rpcClient, err := r.GetRPCClient(contextID)
   104  	if err != nil {
   105  		return err
   106  	}
   107  
   108  	digest := hmac.New(sha256.New, []byte(rpcClient.Secret))
   109  	hash, err := payloadHash(req.Payload)
   110  	if err != nil {
   111  		return err
   112  	}
   113  
   114  	if _, err := digest.Write(hash); err != nil {
   115  		return err
   116  	}
   117  
   118  	req.HashAuth = digest.Sum(nil)
   119  
   120  	return rpcClient.Client.Call(methodName, req, resp)
   121  }
   122  
   123  // CheckValidity checks if the received message is valid
   124  func (r *RPCWrapper) CheckValidity(req *Request, secret string) bool {
   125  
   126  	digest := hmac.New(sha256.New, []byte(secret))
   127  
   128  	hash, err := payloadHash(req.Payload)
   129  	if err != nil {
   130  		return false
   131  	}
   132  
   133  	if _, err := digest.Write(hash); err != nil {
   134  		return false
   135  	}
   136  
   137  	return hmac.Equal(req.HashAuth, digest.Sum(nil))
   138  }
   139  
   140  //NewRPCServer returns an interface RPCServer
   141  func NewRPCServer() RPCServer {
   142  
   143  	return &RPCWrapper{
   144  		server: rpc.NewServer(),
   145  	}
   146  }
   147  
   148  // StartServer Starts a server and waits for new connections this function never returns
   149  func (r *RPCWrapper) StartServer(ctx context.Context, protocol string, path string, handler interface{}) error {
   150  
   151  	if len(path) == 0 {
   152  		zap.L().Fatal("Sock param not passed in environment")
   153  	}
   154  
   155  	// Register RPC Type
   156  	RegisterTypes()
   157  
   158  	// Register handlers
   159  	if err := r.server.Register(handler); err != nil {
   160  		return err
   161  	}
   162  
   163  	r.server.HandleHTTP(rpc.DefaultRPCPath, rpc.DefaultDebugPath)
   164  
   165  	// removing old path in case it exists already - error if we can't remove it
   166  	if _, err := os.Stat(path); err == nil {
   167  
   168  		zap.L().Debug("Socket path already exists: removing", zap.String("path", path))
   169  
   170  		if rerr := os.Remove(path); rerr != nil {
   171  			return fmt.Errorf("unable to delete existing socket path %s: %s", path, rerr)
   172  		}
   173  	}
   174  
   175  	// Get listener
   176  	listen, err := net.Listen(protocol, path)
   177  	if err != nil {
   178  		return err
   179  	}
   180  
   181  	go http.Serve(listen, nil) // nolint
   182  
   183  	<-ctx.Done()
   184  
   185  	if merr := listen.Close(); merr != nil {
   186  		zap.L().Warn("Connection already closed", zap.Error(merr))
   187  	}
   188  
   189  	if _, err = os.Stat(path); !os.IsNotExist(err) {
   190  		if err := os.Remove(path); err != nil {
   191  			zap.L().Warn("failed to remove old path", zap.Error(err))
   192  		}
   193  	}
   194  
   195  	return nil
   196  }
   197  
   198  // DestroyRPCClient calls close on the rpc and cleans up the connection
   199  func (r *RPCWrapper) DestroyRPCClient(contextID string) {
   200  	r.Lock()
   201  	defer r.Unlock()
   202  
   203  	rpcHdl, err := r.rpcClientMap.Get(contextID)
   204  	if err != nil {
   205  		return
   206  	}
   207  
   208  	if err = rpcHdl.(*RPCHdl).Client.Close(); err != nil {
   209  		zap.L().Warn("Failed to close channel",
   210  			zap.String("contextID", contextID),
   211  			zap.Error(err),
   212  		)
   213  	}
   214  
   215  	if err = os.Remove(rpcHdl.(*RPCHdl).Channel); err != nil {
   216  		zap.L().Debug("Failed to remove channel - already closed",
   217  			zap.String("contextID", contextID),
   218  			zap.Error(err),
   219  		)
   220  	}
   221  
   222  	if err = r.rpcClientMap.Remove(contextID); err != nil {
   223  		zap.L().Warn("Failed to remove item from cache",
   224  			zap.String("contextID", contextID),
   225  			zap.Error(err),
   226  		)
   227  	}
   228  }
   229  
   230  // ContextList returns the list of active context managed by the rpcwrapper
   231  func (r *RPCWrapper) ContextList() []string {
   232  	keylist := r.rpcClientMap.KeyList()
   233  	contextArray := []string{}
   234  	for _, key := range keylist {
   235  		if kstring, ok := key.(string); ok {
   236  			contextArray = append(contextArray, kstring)
   237  		}
   238  	}
   239  	return contextArray
   240  }
   241  
   242  // ProcessMessage checks if the given request is valid
   243  func (r *RPCWrapper) ProcessMessage(req *Request, secret string) bool {
   244  
   245  	return r.CheckValidity(req, secret)
   246  }
   247  
   248  // payloadHash returns the has of the payload
   249  func payloadHash(payload interface{}) ([]byte, error) {
   250  	hash, err := hashstructure.Hash(payload, nil)
   251  	if err != nil {
   252  		return []byte{}, err
   253  	}
   254  
   255  	buf := make([]byte, 8)
   256  	binary.BigEndian.PutUint64(buf, hash)
   257  	return buf, nil
   258  }
   259  
   260  // RegisterTypes  registers types that are exchanged between the controller and remoteenforcer
   261  func RegisterTypes() {
   262  	// TODO: figure out why the RegisterName() calls are written as *(&x{}) when the Register() calls are just &x{}
   263  	gob.Register(&secrets.RPCSecrets{})
   264  	gob.Register(&pkitokens.PKIJWTVerifier{})
   265  	gob.Register(&oidc.TokenVerifier{})
   266  	gob.Register(&collector.CounterReport{})
   267  	gob.Register(&collector.PingReport{})
   268  	gob.Register(&collector.DNSRequestReport{})
   269  	gob.Register(&collector.PacketReport{})
   270  	gob.Register(&collector.ConnectionExceptionReport{})
   271  	gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.Init_Request_Payload", *(&InitRequestPayload{}))                                // nolint:staticcheck // SA4001: *&x will be simplified to x. It will not copy x.
   272  	gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.Enforce_Payload", *(&EnforcePayload{}))                                         // nolint:staticcheck
   273  	gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.UnEnforce_Payload", *(&UnEnforcePayload{}))                                     // nolint:staticcheck
   274  	gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.Stats_Payload", *(&StatsPayload{}))                                             // nolint:staticcheck
   275  	gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.UpdateSecrets_Payload", *(&UpdateSecretsPayload{}))                             // nolint:staticcheck
   276  	gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.SetTargetNetworks_Payload", *(&SetTargetNetworksPayload{}))                     // nolint:staticcheck
   277  	gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.EnableIPTablesPacketTracing_PayLoad", *(&EnableIPTablesPacketTracingPayLoad{})) // nolint:staticcheck
   278  	gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.EnableDatapathPacketTracing_PayLoad", *(&EnableDatapathPacketTracingPayLoad{})) // nolint:staticcheck
   279  	gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.Report_Payload", *(&ReportPayload{}))                                           // nolint:staticcheck
   280  	gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.SetLogLevel_Payload", *(&SetLogLevelPayload{}))                                 // nolint:staticcheck
   281  	gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.TokenRequest_Payload", *(&TokenRequestPayload{}))                               // nolint:staticcheck
   282  	gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.TokenResponse_Payload", *(&TokenResponsePayload{}))                             // nolint:staticcheck
   283  	gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.Ping_Payload", *(&PingPayload{}))                                               // nolint:staticcheck
   284  	gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.DebugCollect_Payload", *(&DebugCollectPayload{}))                               // nolint:staticcheck
   285  	gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.DebugCollectResponse_Payload", *(&DebugCollectResponsePayload{}))               // nolint:staticcheck
   286  }