github.com/nsqio/nsq@v1.3.0/internal/auth/authorizations.go (about) 1 package auth 2 3 import ( 4 "crypto/tls" 5 "errors" 6 "fmt" 7 "math/rand" 8 "net/url" 9 "regexp" 10 "strings" 11 "time" 12 13 "github.com/nsqio/nsq/internal/http_api" 14 ) 15 16 type Authorization struct { 17 Topic string `json:"topic"` 18 Channels []string `json:"channels"` 19 Permissions []string `json:"permissions"` 20 } 21 22 type State struct { 23 TTL int `json:"ttl"` 24 Authorizations []Authorization `json:"authorizations"` 25 Identity string `json:"identity"` 26 IdentityURL string `json:"identity_url"` 27 Expires time.Time 28 } 29 30 func (a *Authorization) HasPermission(permission string) bool { 31 for _, p := range a.Permissions { 32 if permission == p { 33 return true 34 } 35 } 36 return false 37 } 38 39 func (a *Authorization) IsAllowed(topic, channel string) bool { 40 if channel != "" { 41 if !a.HasPermission("subscribe") { 42 return false 43 } 44 } else { 45 if !a.HasPermission("publish") { 46 return false 47 } 48 } 49 50 topicRegex := regexp.MustCompile(a.Topic) 51 52 if !topicRegex.MatchString(topic) { 53 return false 54 } 55 56 for _, c := range a.Channels { 57 channelRegex := regexp.MustCompile(c) 58 if channelRegex.MatchString(channel) { 59 return true 60 } 61 } 62 return false 63 } 64 65 func (a *State) IsAllowed(topic, channel string) bool { 66 for _, aa := range a.Authorizations { 67 if aa.IsAllowed(topic, channel) { 68 return true 69 } 70 } 71 return false 72 } 73 74 func (a *State) IsExpired() bool { 75 return a.Expires.Before(time.Now()) 76 } 77 78 func QueryAnyAuthd(authd []string, remoteIP string, tlsEnabled bool, commonName string, authSecret string, 79 clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) { 80 var retErr error 81 start := rand.Int() 82 n := len(authd) 83 for i := 0; i < n; i++ { 84 a := authd[(i+start)%n] 85 authState, err := QueryAuthd(a, remoteIP, tlsEnabled, commonName, authSecret, clientTLSConfig, connectTimeout, requestTimeout) 86 if err != nil { 87 es := fmt.Sprintf("failed to auth against %s - %s", a, err) 88 if retErr != nil { 89 es = fmt.Sprintf("%s; %s", retErr, es) 90 } 91 retErr = errors.New(es) 92 continue 93 } 94 return authState, nil 95 } 96 return nil, retErr 97 } 98 99 func QueryAuthd(authd string, remoteIP string, tlsEnabled bool, commonName string, authSecret string, 100 clientTLSConfig *tls.Config, connectTimeout time.Duration, requestTimeout time.Duration) (*State, error) { 101 v := url.Values{} 102 v.Set("remote_ip", remoteIP) 103 if tlsEnabled { 104 v.Set("tls", "true") 105 } else { 106 v.Set("tls", "false") 107 } 108 v.Set("secret", authSecret) 109 v.Set("common_name", commonName) 110 111 var endpoint string 112 if strings.Contains(authd, "://") { 113 endpoint = fmt.Sprintf("%s?%s", authd, v.Encode()) 114 } else { 115 endpoint = fmt.Sprintf("http://%s/auth?%s", authd, v.Encode()) 116 } 117 118 var authState State 119 client := http_api.NewClient(clientTLSConfig, connectTimeout, requestTimeout) 120 if err := client.GETV1(endpoint, &authState); err != nil { 121 return nil, err 122 } 123 124 // validation on response 125 for _, auth := range authState.Authorizations { 126 for _, p := range auth.Permissions { 127 switch p { 128 case "subscribe", "publish": 129 default: 130 return nil, fmt.Errorf("unknown permission %s", p) 131 } 132 } 133 134 if _, err := regexp.Compile(auth.Topic); err != nil { 135 return nil, fmt.Errorf("unable to compile topic %q %s", auth.Topic, err) 136 } 137 138 for _, channel := range auth.Channels { 139 if _, err := regexp.Compile(channel); err != nil { 140 return nil, fmt.Errorf("unable to compile channel %q %s", channel, err) 141 } 142 } 143 } 144 145 if authState.TTL <= 0 { 146 return nil, fmt.Errorf("invalid TTL %d (must be >0)", authState.TTL) 147 } 148 149 authState.Expires = time.Now().Add(time.Duration(authState.TTL) * time.Second) 150 return &authState, nil 151 }