github.com/twilio/twilio-go@v1.20.1/client/jwt/taskrouter/capability_token.go (about)

     1  package taskrouter
     2  
     3  import (
     4  	"encoding/json"
     5  	"fmt"
     6  	"reflect"
     7  	"strings"
     8  	"time"
     9  
    10  	baseJwt "github.com/twilio/twilio-go/client/jwt"
    11  	. "github.com/twilio/twilio-go/client/jwt/util"
    12  )
    13  
// CapabilityToken is a TaskRouter capability token: a JWT signed with the
// account's auth token whose payload carries the account, workspace and
// channel identifiers together with a list of policies.
//
// Fields:
//   - baseJwt holds the signing configuration (secret key, issuer, algorithm,
//     nbf/ttl/validUntil) and any decoded headers/payload.
//   - AccountSid is the Twilio account SID; it is also used as the JWT issuer.
//   - AuthToken is the Twilio auth token used as the signing secret.
//   - WorkspaceSid is emitted as the "workspace_sid" claim when non-empty.
//   - ChannelID is emitted as the "channel" and "friendly_name" claims when
//     non-empty, and additionally as "worker_sid" / "taskqueue_sid" when it
//     starts with "WK" / "WQ".
//   - Policies holds the policies added through AddPolicy (or decoded by
//     FromJwt).
type CapabilityToken struct {
	baseJwt      *baseJwt.Jwt
	AccountSid   string
	AuthToken    string
	WorkspaceSid string
	ChannelID    string
	Policies     []Policy
}
    22  
// CapabilityTokenParams configures a token built by CreateCapabilityToken.
type CapabilityTokenParams struct {
	// Twilio account SID; used as the JWT issuer and the "account_sid" claim.
	AccountSid string
	// Twilio auth token used to sign the JWT.
	AuthToken string
	// TaskRouter Workspace SID ("workspace_sid" claim).
	WorkspaceSid string
	// TaskRouter channel SID ("channel" claim). A SID starting with "WK" is
	// also emitted as "worker_sid", one starting with "WQ" as "taskqueue_sid".
	ChannelID string
	// Time in seconds since the epoch before which this JWT is invalid.
	// Zero means "now" (evaluated when the payload is generated).
	Nbf float64
	// Time to live of the JWT in seconds. Defaults to 1 hour when zero.
	Ttl float64
	// Time in seconds since the epoch until which this JWT is valid.
	// When non-zero it overrides Ttl.
	ValidUntil float64
}
    39  
    40  // Create Capability Token for TaskRouter
    41  func CreateCapabilityToken(params CapabilityTokenParams) CapabilityToken {
    42  	return CapabilityToken{
    43  		baseJwt: &baseJwt.Jwt{
    44  			SecretKey:  params.AuthToken,
    45  			Issuer:     params.AccountSid,
    46  			Subject:    "",
    47  			Algorithm:  HS256,
    48  			Nbf:        params.Nbf,
    49  			Ttl:        Max(params.Ttl, 3600),
    50  			ValidUntil: params.ValidUntil,
    51  		},
    52  		AccountSid:   params.AccountSid,
    53  		AuthToken:    params.AuthToken,
    54  		WorkspaceSid: params.WorkspaceSid,
    55  		ChannelID:    params.ChannelID,
    56  		Policies:     make([]Policy, 0),
    57  	}
    58  }
    59  
    60  func (token *CapabilityToken) AddPolicy(policy Policy) {
    61  	token.Policies = append(token.Policies, policy)
    62  }
    63  
    64  func (token *CapabilityToken) generatePayload() map[string]interface{} {
    65  	now := float64(time.Now().Unix())
    66  
    67  	// These are required since we want to authenticate and authorize the opening of a websocket in the first place.
    68  	// Subsequent events to GET, POST or DELETE to other APIs will utilize this websocket.
    69  	defaultPolicies := WebSocketPolicies(token.AccountSid, token.ChannelID)
    70  	token.Policies = append(token.Policies, defaultPolicies...)
    71  
    72  	payload := map[string]interface{}{
    73  		"version": Version,
    74  	}
    75  	if token.AccountSid != "" {
    76  		payload["account_sid"] = token.AccountSid
    77  	}
    78  	if token.WorkspaceSid != "" {
    79  		payload["workspace_sid"] = token.WorkspaceSid
    80  	}
    81  	if token.ChannelID != "" {
    82  		payload["channel"] = token.ChannelID
    83  		payload["friendly_name"] = token.ChannelID
    84  	}
    85  
    86  	var policies []map[string]interface{}
    87  	for _, policy := range token.Policies {
    88  		policyPayload := policy.Payload()
    89  		policies = append(policies, policyPayload)
    90  	}
    91  
    92  	if len(policies) > 0 {
    93  		payload["policies"] = policies
    94  	}
    95  
    96  	payload["iss"] = token.baseJwt.Issuer
    97  	payload["exp"] = now + token.baseJwt.Ttl
    98  	if token.baseJwt.Nbf != 0 {
    99  		payload["nbf"] = token.baseJwt.Nbf
   100  	} else {
   101  		payload["nbf"] = now
   102  	}
   103  	if token.baseJwt.ValidUntil != 0 {
   104  		payload["exp"] = token.baseJwt.ValidUntil
   105  	}
   106  	if strings.HasPrefix(token.ChannelID, "WK") {
   107  		payload["worker_sid"] = token.ChannelID
   108  	} else if strings.HasPrefix(token.ChannelID, "WQ") {
   109  		payload["taskqueue_sid"] = token.ChannelID
   110  	}
   111  
   112  	return payload
   113  }
   114  
   115  func (token *CapabilityToken) ToString() string {
   116  	signedStr, err := token.ToJwt()
   117  	if err != nil {
   118  		return ""
   119  	}
   120  	return fmt.Sprintf("<TaskRouterCapabilityToken %s>", signedStr)
   121  }
   122  
   123  // Encode the JWT struct into a string.
   124  func (token *CapabilityToken) ToJwt() (string, error) {
   125  	signedToken, err := token.baseJwt.ToJwt(token.baseJwt.Headers, token.generatePayload)
   126  	if err != nil {
   127  		return "", err
   128  	}
   129  
   130  	return signedToken, nil
   131  }
   132  
   133  // Get the decoded token back from the jwt String
   134  func (token *CapabilityToken) FromJwt(jwtStr string, key string) (*CapabilityToken, error) {
   135  	baseToken, err := token.baseJwt.FromJwt(jwtStr, key)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  
   140  	return &CapabilityToken{
   141  		baseJwt:      baseToken,
   142  		AccountSid:   baseToken.Issuer,
   143  		AuthToken:    baseToken.SecretKey,
   144  		WorkspaceSid: baseToken.Payload()["workspace_sid"].(string),
   145  		ChannelID:    baseToken.Payload()["channel"].(string),
   146  		Policies:     decodePolicies(baseToken.Payload()["policies"]),
   147  	}, nil
   148  }
   149  
   150  func decodePolicies(policies interface{}) []Policy {
   151  	var decodedPolicies []Policy
   152  	switch reflect.TypeOf(policies).Kind() {
   153  	case reflect.Slice:
   154  		s := reflect.ValueOf(policies)
   155  
   156  		for i := 0; i < s.Len(); i++ {
   157  			var pol Policy
   158  			val := s.Index(i).Interface().(map[string]interface{})
   159  			if data, err := json.Marshal(val); err == nil {
   160  				if errJson := json.Unmarshal(data, &pol); errJson == nil {
   161  					decodedPolicies = append(decodedPolicies, pol)
   162  				}
   163  			}
   164  		}
   165  	}
   166  
   167  	return decodedPolicies
   168  }
   169  
   170  func (token *CapabilityToken) Headers() map[string]interface{} {
   171  	if token.baseJwt.DecodedHeaders == nil {
   172  		token.baseJwt.DecodedHeaders = token.baseJwt.Headers()
   173  	}
   174  
   175  	return token.baseJwt.DecodedHeaders
   176  }
   177  
   178  func (token *CapabilityToken) Payload() map[string]interface{} {
   179  	if token.baseJwt.DecodedPayload == nil {
   180  		token.baseJwt.DecodedPayload = token.generatePayload()
   181  	}
   182  
   183  	return token.baseJwt.DecodedPayload
   184  }