github.com/nsqio/nsq@v1.3.0/internal/auth/authorizations.go (about)

     1  package auth
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"fmt"
     7  	"math/rand"
     8  	"net/url"
     9  	"regexp"
    10  	"strings"
    11  	"time"
    12  
    13  	"github.com/nsqio/nsq/internal/http_api"
    14  )
    15  
    16  type Authorization struct {
    17  	Topic       string   `json:"topic"`
    18  	Channels    []string `json:"channels"`
    19  	Permissions []string `json:"permissions"`
    20  }
    21  
    22  type State struct {
    23  	TTL            int             `json:"ttl"`
    24  	Authorizations []Authorization `json:"authorizations"`
    25  	Identity       string          `json:"identity"`
    26  	IdentityURL    string          `json:"identity_url"`
    27  	Expires        time.Time
    28  }
    29  
    30  func (a *Authorization) HasPermission(permission string) bool {
    31  	for _, p := range a.Permissions {
    32  		if permission == p {
    33  			return true
    34  		}
    35  	}
    36  	return false
    37  }
    38  
    39  func (a *Authorization) IsAllowed(topic, channel string) bool {
    40  	if channel != "" {
    41  		if !a.HasPermission("subscribe") {
    42  			return false
    43  		}
    44  	} else {
    45  		if !a.HasPermission("publish") {
    46  			return false
    47  		}
    48  	}
    49  
    50  	topicRegex := regexp.MustCompile(a.Topic)
    51  
    52  	if !topicRegex.MatchString(topic) {
    53  		return false
    54  	}
    55  
    56  	for _, c := range a.Channels {
    57  		channelRegex := regexp.MustCompile(c)
    58  		if channelRegex.MatchString(channel) {
    59  			return true
    60  		}
    61  	}
    62  	return false
    63  }
    64  
    65  func (a *State) IsAllowed(topic, channel string) bool {
    66  	for _, aa := range a.Authorizations {
    67  		if aa.IsAllowed(topic, channel) {
    68  			return true
    69  		}
    70  	}
    71  	return false
    72  }
    73  
    74  func (a *State) IsExpired() bool {
    75  	return a.Expires.Before(time.Now())
    76  }
    77  
    78  func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
    79  	clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
    80  	var retErr error
    81  	start := rand.Int()
    82  	n := len(authd)
    83  	for i := 0; i < n; i++ {
    84  		a := authd[(i+start)%n]
    85  		authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout)
    86  		if err != nil {
    87  			es := fmt.Sprintf("failed to auth against %s - %s", a, err)
    88  			if retErr != nil {
    89  				es = fmt.Sprintf("%s; %s", retErr, es)
    90  			}
    91  			retErr = errors.New(es)
    92  			continue
    93  		}
    94  		return authState, nil
    95  	}
    96  	return nil, retErr
    97  }
    98  
    99  func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName string, authSecret string,
   100  	clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) {
   101  	v := url.Values{}
   102  	v.Set("remote_ip", remoteIP)
   103  	if tlsEnabled {
   104  		v.Set("tls", "true")
   105  	} else {
   106  		v.Set("tls", "false")
   107  	}
   108  	v.Set("secret", authSecret)
   109  	v.Set("common_name", commonName)
   110  
   111  	var endpoint string
   112  	if strings.Contains(authd, "://") {
   113  		endpoint = fmt.Sprintf("%s?%s", authd, v.Encode())
   114  	} else {
   115  		endpoint = fmt.Sprintf("http://%s/auth?%s", authd, v.Encode())
   116  	}
   117  
   118  	var authState State
   119  	client := http_api.NewClient(clientTLSConfig, connectTimeout, requestTimeout)
   120  	if err := client.GETV1(endpoint, &authState); err != nil {
   121  		return nil, err
   122  	}
   123  
   124  	// validation on response
   125  	for _, auth := range authState.Authorizations {
   126  		for _, p := range auth.Permissions {
   127  			switch p {
   128  			case "subscribe", "publish":
   129  			default:
   130  				return nil, fmt.Errorf("unknown permission %s", p)
   131  			}
   132  		}
   133  
   134  		if _, err := regexp.Compile(auth.Topic); err != nil {
   135  			return nil, fmt.Errorf("unable to compile topic %q %s", auth.Topic, err)
   136  		}
   137  
   138  		for _, channel := range auth.Channels {
   139  			if _, err := regexp.Compile(channel); err != nil {
   140  				return nil, fmt.Errorf("unable to compile channel %q %s", channel, err)
   141  			}
   142  		}
   143  	}
   144  
   145  	if authState.TTL <= 0 {
   146  		return nil, fmt.Errorf("invalid TTL %d (must be >0)", authState.TTL)
   147  	}
   148  
   149  	authState.Expires = time.Now().Add(time.Duration(authState.TTL) * time.Second)
   150  	return &authState, nil
   151  }