github.com/Axway/agent-sdk@v1.1.101/pkg/watchmanager/client.go (about) 1 package watchmanager 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "math/big" 8 "sync" 9 "time" 10 11 "google.golang.org/grpc" 12 13 "github.com/Axway/agent-sdk/pkg/watchmanager/proto" 14 "github.com/golang-jwt/jwt" 15 ) 16 17 type clientConfig struct { 18 errors chan error 19 events chan *proto.Event 20 tokenGetter TokenGetter 21 topicSelfLink string 22 } 23 24 type watchClient struct { 25 cancelStreamCtx context.CancelFunc 26 cfg clientConfig 27 getTokenExpirationTime getTokenExpFunc 28 isRunning bool 29 stream proto.Watch_SubscribeClient 30 streamCtx context.Context 31 timer *time.Timer 32 mutex sync.Mutex 33 } 34 35 // newWatchClientFunc func signature to create a watch client 36 type newWatchClientFunc func(cc grpc.ClientConnInterface) proto.WatchClient 37 38 type getTokenExpFunc func(token string) (time.Duration, error) 39 40 func newWatchClient(cc grpc.ClientConnInterface, clientCfg clientConfig, newClient newWatchClientFunc) (*watchClient, error) { 41 svcClient := newClient(cc) 42 43 streamCtx, streamCancel := context.WithCancel(context.Background()) 44 stream, err := svcClient.Subscribe(streamCtx) 45 if err != nil { 46 streamCancel() 47 return nil, err 48 } 49 50 client := &watchClient{ 51 cancelStreamCtx: streamCancel, 52 cfg: clientCfg, 53 getTokenExpirationTime: getTokenExpirationTime, 54 isRunning: true, 55 stream: stream, 56 streamCtx: streamCtx, 57 timer: time.NewTimer(0), 58 } 59 60 return client, nil 61 } 62 63 // processEvents process incoming chimera events 64 func (c *watchClient) processEvents() { 65 for { 66 err := c.recv() 67 if err != nil { 68 c.handleError(err) 69 return 70 } 71 } 72 } 73 74 // recv blocks until an event is received 75 func (c *watchClient) recv() error { 76 event, err := c.stream.Recv() 77 if err != nil { 78 return err 79 } 80 c.cfg.events <- event 81 return nil 82 } 83 84 // processRequest sends a message to the client when the timer expires, and handles when the stream is closed. 85 func (c *watchClient) processRequest() error { 86 var err error 87 wg := sync.WaitGroup{} 88 wg.Add(1) 89 wait := true 90 go func() { 91 for { 92 select { 93 case <-c.streamCtx.Done(): 94 c.handleError(c.streamCtx.Err()) 95 return 96 case <-c.stream.Context().Done(): 97 c.handleError(c.stream.Context().Err()) 98 return 99 case <-c.timer.C: 100 err = c.send() 101 if wait { 102 wg.Done() 103 wait = false 104 } 105 if err != nil { 106 c.handleError(err) 107 return 108 } 109 } 110 } 111 }() 112 113 wg.Wait() 114 return err 115 } 116 117 // send a message with a new token to the grpc server and returns the expiration time 118 func (c *watchClient) send() error { 119 c.timer.Stop() 120 121 token, err := c.cfg.tokenGetter() 122 if err != nil { 123 return err 124 } 125 126 exp, err := c.getTokenExpirationTime(token) 127 if err != nil { 128 return err 129 } 130 131 req := createWatchRequest(c.cfg.topicSelfLink, token) 132 err = c.stream.Send(req) 133 if err != nil { 134 return err 135 } 136 c.timer.Reset(exp) 137 return nil 138 } 139 140 // handleError stop the running timer, send to the error channel, and close the open stream. 141 func (c *watchClient) handleError(err error) { 142 c.mutex.Lock() 143 defer c.mutex.Unlock() 144 145 if c.isRunning { 146 c.isRunning = false 147 c.timer.Stop() 148 c.cfg.errors <- err 149 c.cancelStreamCtx() 150 } 151 } 152 153 func createWatchRequest(watchTopicSelfLink, token string) *proto.Request { 154 return &proto.Request{ 155 SelfLink: watchTopicSelfLink, 156 Token: "Bearer " + token, 157 } 158 } 159 160 func getTokenExpirationTime(token string) (time.Duration, error) { 161 parser := new(jwt.Parser) 162 parser.SkipClaimsValidation = true 163 164 claims := jwt.MapClaims{} 165 _, _, err := parser.ParseUnverified(token, claims) 166 if err != nil { 167 return time.Duration(0), fmt.Errorf("getTokenExpirationTime failed to parse token: %s", err) 168 } 169 170 var tm time.Time 171 switch exp := claims["exp"].(type) { 172 case float64: 173 tm = time.Unix(int64(exp), 0) 174 case json.Number: 175 v, _ := exp.Int64() 176 tm = time.Unix(v, 0) 177 } 178 179 exp := time.Until(tm) 180 // use big.NewInt to avoid an int overflow 181 i := big.NewInt(int64(exp)) 182 i = i.Mul(i, big.NewInt(4)) 183 i = i.Div(i, big.NewInt(5)) 184 d := time.Duration(i.Int64()) 185 186 if d.Milliseconds() < 0 { 187 return time.Duration(0), fmt.Errorf("token is expired") 188 } 189 return d, nil 190 }