github.com/aporeto-inc/trireme-lib@v10.358.0+incompatible/controller/internal/enforcer/utils/rpcwrapper/rpc_handle.go (about) 1 package rpcwrapper 2 3 import ( 4 "context" 5 "crypto/hmac" 6 "crypto/sha256" 7 "encoding/binary" 8 "encoding/gob" 9 "fmt" 10 "net" 11 "net/http" 12 "net/rpc" 13 "os" 14 "strconv" 15 "sync" 16 "time" 17 18 "github.com/mitchellh/hashstructure" 19 "go.aporeto.io/enforcerd/trireme-lib/collector" 20 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/secrets" 21 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/usertokens/oidc" 22 "go.aporeto.io/enforcerd/trireme-lib/controller/pkg/usertokens/pkitokens" 23 "go.aporeto.io/enforcerd/trireme-lib/utils/cache" 24 "go.uber.org/zap" 25 ) 26 27 // RPCHdl is a per client handle 28 type RPCHdl struct { 29 Client *rpc.Client 30 Channel string 31 Secret string 32 } 33 34 // RPCWrapper is a struct which holds stats for all rpc sesions 35 type RPCWrapper struct { 36 rpcClientMap *cache.Cache 37 server *rpc.Server 38 sync.Mutex 39 } 40 41 // NewRPCWrapper creates a new rpcwrapper 42 func NewRPCWrapper() *RPCWrapper { 43 44 RegisterTypes() 45 46 return &RPCWrapper{ 47 rpcClientMap: cache.NewCache("RPCWrapper"), 48 } 49 } 50 51 const ( 52 maxRetries = 10000 53 envRetryString = "REMOTE_RPCRETRIES" 54 ) 55 56 // NewRPCClient exported 57 func (r *RPCWrapper) NewRPCClient(contextID string, channel string, sharedsecret string) error { 58 59 r.Lock() 60 defer r.Unlock() 61 62 max := maxRetries 63 retries := os.Getenv(envRetryString) 64 if len(retries) > 0 { 65 max, _ = strconv.Atoi(retries) 66 } 67 68 numRetries := 0 69 client, err := rpc.DialHTTP("unix", channel) 70 for err != nil { 71 numRetries++ 72 if numRetries >= max { 73 return err 74 } 75 76 time.Sleep(5 * time.Millisecond) 77 client, err = rpc.DialHTTP("unix", channel) 78 } 79 80 r.rpcClientMap.AddOrUpdate(contextID, &RPCHdl{Client: client, Channel: channel, Secret: sharedsecret}) 81 82 return nil 83 84 } 85 86 // GetRPCClient gets a handle to the rpc client for the contextID( enforcer in the container) 87 func (r *RPCWrapper) GetRPCClient(contextID string) (*RPCHdl, error) { 88 89 r.Lock() 90 defer r.Unlock() 91 92 val, err := r.rpcClientMap.Get(contextID) 93 if err != nil { 94 return nil, err 95 } 96 97 return val.(*RPCHdl), nil 98 } 99 100 // RemoteCall is a wrapper around rpc.Call and also ensure message integrity by adding a hmac 101 func (r *RPCWrapper) RemoteCall(contextID string, methodName string, req *Request, resp *Response) error { 102 103 rpcClient, err := r.GetRPCClient(contextID) 104 if err != nil { 105 return err 106 } 107 108 digest := hmac.New(sha256.New, []byte(rpcClient.Secret)) 109 hash, err := payloadHash(req.Payload) 110 if err != nil { 111 return err 112 } 113 114 if _, err := digest.Write(hash); err != nil { 115 return err 116 } 117 118 req.HashAuth = digest.Sum(nil) 119 120 return rpcClient.Client.Call(methodName, req, resp) 121 } 122 123 // CheckValidity checks if the received message is valid 124 func (r *RPCWrapper) CheckValidity(req *Request, secret string) bool { 125 126 digest := hmac.New(sha256.New, []byte(secret)) 127 128 hash, err := payloadHash(req.Payload) 129 if err != nil { 130 return false 131 } 132 133 if _, err := digest.Write(hash); err != nil { 134 return false 135 } 136 137 return hmac.Equal(req.HashAuth, digest.Sum(nil)) 138 } 139 140 //NewRPCServer returns an interface RPCServer 141 func NewRPCServer() RPCServer { 142 143 return &RPCWrapper{ 144 server: rpc.NewServer(), 145 } 146 } 147 148 // StartServer Starts a server and waits for new connections this function never returns 149 func (r *RPCWrapper) StartServer(ctx context.Context, protocol string, path string, handler interface{}) error { 150 151 if len(path) == 0 { 152 zap.L().Fatal("Sock param not passed in environment") 153 } 154 155 // Register RPC Type 156 RegisterTypes() 157 158 // Register handlers 159 if err := r.server.Register(handler); err != nil { 160 return err 161 } 162 163 r.server.HandleHTTP(rpc.DefaultRPCPath, rpc.DefaultDebugPath) 164 165 // removing old path in case it exists already - error if we can't remove it 166 if _, err := os.Stat(path); err == nil { 167 168 zap.L().Debug("Socket path already exists: removing", zap.String("path", path)) 169 170 if rerr := os.Remove(path); rerr != nil { 171 return fmt.Errorf("unable to delete existing socket path %s: %s", path, rerr) 172 } 173 } 174 175 // Get listener 176 listen, err := net.Listen(protocol, path) 177 if err != nil { 178 return err 179 } 180 181 go http.Serve(listen, nil) // nolint 182 183 <-ctx.Done() 184 185 if merr := listen.Close(); merr != nil { 186 zap.L().Warn("Connection already closed", zap.Error(merr)) 187 } 188 189 if _, err = os.Stat(path); !os.IsNotExist(err) { 190 if err := os.Remove(path); err != nil { 191 zap.L().Warn("failed to remove old path", zap.Error(err)) 192 } 193 } 194 195 return nil 196 } 197 198 // DestroyRPCClient calls close on the rpc and cleans up the connection 199 func (r *RPCWrapper) DestroyRPCClient(contextID string) { 200 r.Lock() 201 defer r.Unlock() 202 203 rpcHdl, err := r.rpcClientMap.Get(contextID) 204 if err != nil { 205 return 206 } 207 208 if err = rpcHdl.(*RPCHdl).Client.Close(); err != nil { 209 zap.L().Warn("Failed to close channel", 210 zap.String("contextID", contextID), 211 zap.Error(err), 212 ) 213 } 214 215 if err = os.Remove(rpcHdl.(*RPCHdl).Channel); err != nil { 216 zap.L().Debug("Failed to remove channel - already closed", 217 zap.String("contextID", contextID), 218 zap.Error(err), 219 ) 220 } 221 222 if err = r.rpcClientMap.Remove(contextID); err != nil { 223 zap.L().Warn("Failed to remove item from cache", 224 zap.String("contextID", contextID), 225 zap.Error(err), 226 ) 227 } 228 } 229 230 // ContextList returns the list of active context managed by the rpcwrapper 231 func (r *RPCWrapper) ContextList() []string { 232 keylist := r.rpcClientMap.KeyList() 233 contextArray := []string{} 234 for _, key := range keylist { 235 if kstring, ok := key.(string); ok { 236 contextArray = append(contextArray, kstring) 237 } 238 } 239 return contextArray 240 } 241 242 // ProcessMessage checks if the given request is valid 243 func (r *RPCWrapper) ProcessMessage(req *Request, secret string) bool { 244 245 return r.CheckValidity(req, secret) 246 } 247 248 // payloadHash returns the has of the payload 249 func payloadHash(payload interface{}) ([]byte, error) { 250 hash, err := hashstructure.Hash(payload, nil) 251 if err != nil { 252 return []byte{}, err 253 } 254 255 buf := make([]byte, 8) 256 binary.BigEndian.PutUint64(buf, hash) 257 return buf, nil 258 } 259 260 // RegisterTypes registers types that are exchanged between the controller and remoteenforcer 261 func RegisterTypes() { 262 // TODO: figure out why the RegisterName() calls are written as *(&x{}) when the Register() calls are just &x{} 263 gob.Register(&secrets.RPCSecrets{}) 264 gob.Register(&pkitokens.PKIJWTVerifier{}) 265 gob.Register(&oidc.TokenVerifier{}) 266 gob.Register(&collector.CounterReport{}) 267 gob.Register(&collector.PingReport{}) 268 gob.Register(&collector.DNSRequestReport{}) 269 gob.Register(&collector.PacketReport{}) 270 gob.Register(&collector.ConnectionExceptionReport{}) 271 gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.Init_Request_Payload", *(&InitRequestPayload{})) // nolint:staticcheck // SA4001: *&x will be simplified to x. It will not copy x. 272 gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.Enforce_Payload", *(&EnforcePayload{})) // nolint:staticcheck 273 gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.UnEnforce_Payload", *(&UnEnforcePayload{})) // nolint:staticcheck 274 gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.Stats_Payload", *(&StatsPayload{})) // nolint:staticcheck 275 gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.UpdateSecrets_Payload", *(&UpdateSecretsPayload{})) // nolint:staticcheck 276 gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.SetTargetNetworks_Payload", *(&SetTargetNetworksPayload{})) // nolint:staticcheck 277 gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.EnableIPTablesPacketTracing_PayLoad", *(&EnableIPTablesPacketTracingPayLoad{})) // nolint:staticcheck 278 gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.EnableDatapathPacketTracing_PayLoad", *(&EnableDatapathPacketTracingPayLoad{})) // nolint:staticcheck 279 gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.Report_Payload", *(&ReportPayload{})) // nolint:staticcheck 280 gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.SetLogLevel_Payload", *(&SetLogLevelPayload{})) // nolint:staticcheck 281 gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.TokenRequest_Payload", *(&TokenRequestPayload{})) // nolint:staticcheck 282 gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.TokenResponse_Payload", *(&TokenResponsePayload{})) // nolint:staticcheck 283 gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.Ping_Payload", *(&PingPayload{})) // nolint:staticcheck 284 gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.DebugCollect_Payload", *(&DebugCollectPayload{})) // nolint:staticcheck 285 gob.RegisterName("go.aporeto.io/enforcerd/trireme-lib/controller/internal/enforcer/utils/rpcwrapper.DebugCollectResponse_Payload", *(&DebugCollectResponsePayload{})) // nolint:staticcheck 286 }