github.com/Axway/agent-sdk@v1.1.101/pkg/watchmanager/client.go (about)

     1  package watchmanager
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"math/big"
     8  	"sync"
     9  	"time"
    10  
    11  	"google.golang.org/grpc"
    12  
    13  	"github.com/Axway/agent-sdk/pkg/watchmanager/proto"
    14  	"github.com/golang-jwt/jwt"
    15  )
    16  
// clientConfig holds the channels and settings a watchClient is created with.
type clientConfig struct {
	// errors receives the error that stops the client; handleError sends at
	// most one value per client.
	errors        chan error
	// events receives every event read from the stream by recv.
	events        chan *proto.Event
	// tokenGetter supplies the token attached to each watch request in send.
	tokenGetter   TokenGetter
	// topicSelfLink is the self link of the watch topic placed in each request.
	topicSelfLink string
}
    23  
// watchClient owns a single Subscribe stream. It periodically sends watch
// requests carrying a fresh token (processRequest/send), forwards received
// events (processEvents/recv), and shuts itself down once on the first error
// (handleError).
type watchClient struct {
	// cancelStreamCtx cancels streamCtx; invoked by handleError.
	cancelStreamCtx        context.CancelFunc
	// cfg is the configuration the client was created with.
	cfg                    clientConfig
	// getTokenExpirationTime returns the delay before the next token send.
	// newWatchClient sets it to the package-level getTokenExpirationTime;
	// presumably a field so tests can substitute it.
	getTokenExpirationTime getTokenExpFunc
	// isRunning is true until the first handleError call; in this file it is
	// only changed while holding mutex.
	isRunning              bool
	// stream is the open Subscribe stream.
	stream                 proto.Watch_SubscribeClient
	// streamCtx is the context the stream was opened with.
	streamCtx              context.Context
	// timer fires when the next watch request (with a fresh token) is due.
	timer                  *time.Timer
	// mutex serializes handleError and protects isRunning.
	mutex                  sync.Mutex
}
    34  
    35  // newWatchClientFunc func signature to create a watch client
    36  type newWatchClientFunc func(cc grpc.ClientConnInterface) proto.WatchClient
    37  
    38  type getTokenExpFunc func(token string) (time.Duration, error)
    39  
    40  func newWatchClient(cc grpc.ClientConnInterface, clientCfg clientConfig, newClient newWatchClientFunc) (*watchClient, error) {
    41  	svcClient := newClient(cc)
    42  
    43  	streamCtx, streamCancel := context.WithCancel(context.Background())
    44  	stream, err := svcClient.Subscribe(streamCtx)
    45  	if err != nil {
    46  		streamCancel()
    47  		return nil, err
    48  	}
    49  
    50  	client := &watchClient{
    51  		cancelStreamCtx:        streamCancel,
    52  		cfg:                    clientCfg,
    53  		getTokenExpirationTime: getTokenExpirationTime,
    54  		isRunning:              true,
    55  		stream:                 stream,
    56  		streamCtx:              streamCtx,
    57  		timer:                  time.NewTimer(0),
    58  	}
    59  
    60  	return client, nil
    61  }
    62  
    63  // processEvents process incoming chimera events
    64  func (c *watchClient) processEvents() {
    65  	for {
    66  		err := c.recv()
    67  		if err != nil {
    68  			c.handleError(err)
    69  			return
    70  		}
    71  	}
    72  }
    73  
    74  // recv blocks until an event is received
    75  func (c *watchClient) recv() error {
    76  	event, err := c.stream.Recv()
    77  	if err != nil {
    78  		return err
    79  	}
    80  	c.cfg.events <- event
    81  	return nil
    82  }
    83  
    84  // processRequest sends a message to the client when the timer expires, and handles when the stream is closed.
    85  func (c *watchClient) processRequest() error {
    86  	var err error
    87  	wg := sync.WaitGroup{}
    88  	wg.Add(1)
    89  	wait := true
    90  	go func() {
    91  		for {
    92  			select {
    93  			case <-c.streamCtx.Done():
    94  				c.handleError(c.streamCtx.Err())
    95  				return
    96  			case <-c.stream.Context().Done():
    97  				c.handleError(c.stream.Context().Err())
    98  				return
    99  			case <-c.timer.C:
   100  				err = c.send()
   101  				if wait {
   102  					wg.Done()
   103  					wait = false
   104  				}
   105  				if err != nil {
   106  					c.handleError(err)
   107  					return
   108  				}
   109  			}
   110  		}
   111  	}()
   112  
   113  	wg.Wait()
   114  	return err
   115  }
   116  
   117  // send a message with a new token to the grpc server and returns the expiration time
   118  func (c *watchClient) send() error {
   119  	c.timer.Stop()
   120  
   121  	token, err := c.cfg.tokenGetter()
   122  	if err != nil {
   123  		return err
   124  	}
   125  
   126  	exp, err := c.getTokenExpirationTime(token)
   127  	if err != nil {
   128  		return err
   129  	}
   130  
   131  	req := createWatchRequest(c.cfg.topicSelfLink, token)
   132  	err = c.stream.Send(req)
   133  	if err != nil {
   134  		return err
   135  	}
   136  	c.timer.Reset(exp)
   137  	return nil
   138  }
   139  
   140  // handleError stop the running timer, send to the error channel, and close the open stream.
   141  func (c *watchClient) handleError(err error) {
   142  	c.mutex.Lock()
   143  	defer c.mutex.Unlock()
   144  
   145  	if c.isRunning {
   146  		c.isRunning = false
   147  		c.timer.Stop()
   148  		c.cfg.errors <- err
   149  		c.cancelStreamCtx()
   150  	}
   151  }
   152  
   153  func createWatchRequest(watchTopicSelfLink, token string) *proto.Request {
   154  	return &proto.Request{
   155  		SelfLink: watchTopicSelfLink,
   156  		Token:    "Bearer " + token,
   157  	}
   158  }
   159  
   160  func getTokenExpirationTime(token string) (time.Duration, error) {
   161  	parser := new(jwt.Parser)
   162  	parser.SkipClaimsValidation = true
   163  
   164  	claims := jwt.MapClaims{}
   165  	_, _, err := parser.ParseUnverified(token, claims)
   166  	if err != nil {
   167  		return time.Duration(0), fmt.Errorf("getTokenExpirationTime failed to parse token: %s", err)
   168  	}
   169  
   170  	var tm time.Time
   171  	switch exp := claims["exp"].(type) {
   172  	case float64:
   173  		tm = time.Unix(int64(exp), 0)
   174  	case json.Number:
   175  		v, _ := exp.Int64()
   176  		tm = time.Unix(v, 0)
   177  	}
   178  
   179  	exp := time.Until(tm)
   180  	// use big.NewInt to avoid an int overflow
   181  	i := big.NewInt(int64(exp))
   182  	i = i.Mul(i, big.NewInt(4))
   183  	i = i.Div(i, big.NewInt(5))
   184  	d := time.Duration(i.Int64())
   185  
   186  	if d.Milliseconds() < 0 {
   187  		return time.Duration(0), fmt.Errorf("token is expired")
   188  	}
   189  	return d, nil
   190  }