code.vegaprotocol.io/vega@v0.79.0/wallet/service/v1/auth.go (about) 1 // Copyright (C) 2023 Gobalsky Labs Limited 2 // 3 // This program is free software: you can redistribute it and/or modify 4 // it under the terms of the GNU Affero General Public License as 5 // published by the Free Software Foundation, either version 3 of the 6 // License, or (at your option) any later version. 7 // 8 // This program is distributed in the hope that it will be useful, 9 // but WITHOUT ANY WARRANTY; without even the implied warranty of 10 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 11 // GNU Affero General Public License for more details. 12 // 13 // You should have received a copy of the GNU Affero General Public License 14 // along with this program. If not, see <http://www.gnu.org/licenses/>. 15 16 package v1 17 18 import ( 19 "crypto/rsa" 20 "encoding/hex" 21 "encoding/json" 22 "errors" 23 "fmt" 24 "net/http" 25 "strings" 26 "sync" 27 "time" 28 29 vgcrypto "code.vegaprotocol.io/vega/libs/crypto" 30 vgrand "code.vegaprotocol.io/vega/libs/rand" 31 32 "github.com/dgrijalva/jwt-go/v4" 33 "go.uber.org/zap" 34 ) 35 36 const ( 37 LengthForSessionHashSeed = 10 38 39 jwtBearer = "Bearer " 40 ) 41 42 var ErrSessionNotFound = errors.New("session not found") 43 44 type auth struct { 45 log *zap.Logger 46 // sessionID -> wallet name 47 sessions map[string]string 48 privKey *rsa.PrivateKey 49 pubKey *rsa.PublicKey 50 tokenExpiry time.Duration 51 52 mu sync.Mutex 53 } 54 55 func NewAuth(log *zap.Logger, cfgStore RSAStore, tokenExpiry time.Duration) (Auth, error) { //revive:disable:unexported-return 56 keys, err := cfgStore.GetRsaKeys() 57 if err != nil { 58 return nil, err 59 } 60 priv, err := jwt.ParseRSAPrivateKeyFromPEM(keys.Priv) 61 if err != nil { 62 return nil, fmt.Errorf("couldn't parse private RSA key: %w", err) 63 } 64 pub, err := jwt.ParseRSAPublicKeyFromPEM(keys.Pub) 65 if err != nil { 66 return nil, fmt.Errorf("couldn't parse public RSA key: %w", err) 67 } 68 69 return &auth{ 70 sessions: map[string]string{}, 71 privKey: priv, 72 pubKey: pub, 73 log: log, 74 tokenExpiry: tokenExpiry, 75 }, nil 76 } 77 78 type Claims struct { 79 jwt.StandardClaims 80 Session string 81 Wallet string 82 } 83 84 func (a *auth) NewSession(walletName string) (string, error) { 85 a.mu.Lock() 86 defer a.mu.Unlock() 87 88 expiresAt := time.Now().Add(a.tokenExpiry) 89 90 session := genSession() 91 92 claims := &Claims{ 93 Session: session, 94 Wallet: walletName, 95 StandardClaims: jwt.StandardClaims{ 96 // these are seconds 97 ExpiresAt: jwt.NewTime((float64)(expiresAt.Unix())), 98 Issuer: "vega wallet", 99 }, 100 } 101 102 token := jwt.NewWithClaims(jwt.SigningMethodPS256, claims) 103 ss, err := token.SignedString(a.privKey) 104 if err != nil { 105 a.log.Error("unable to sign token", zap.Error(err)) 106 return "", err 107 } 108 109 a.sessions[session] = walletName 110 return ss, nil 111 } 112 113 // VerifyToken returns the wallet name associated for this session. 114 func (a *auth) VerifyToken(token string) (string, error) { 115 a.mu.Lock() 116 defer a.mu.Unlock() 117 118 claims, err := a.parseToken(token) 119 if err != nil { 120 return "", err 121 } 122 123 walletName, ok := a.sessions[claims.Session] 124 if !ok { 125 return "", ErrSessionNotFound 126 } 127 128 return walletName, nil 129 } 130 131 func (a *auth) Revoke(token string) (string, error) { 132 a.mu.Lock() 133 defer a.mu.Unlock() 134 135 claims, err := a.parseToken(token) 136 if err != nil { 137 return "", err 138 } 139 140 w, ok := a.sessions[claims.Session] 141 if !ok { 142 return "", ErrSessionNotFound 143 } 144 delete(a.sessions, claims.Session) 145 return w, nil 146 } 147 148 func (a *auth) RevokeAllToken() { 149 a.mu.Lock() 150 defer a.mu.Unlock() 151 152 a.sessions = map[string]string{} 153 } 154 155 func (a *auth) parseToken(tokenStr string) (*Claims, error) { 156 token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (interface{}, error) { 157 return a.pubKey, nil 158 }) 159 if err != nil { 160 return nil, fmt.Errorf("couldn't parse JWT token: %w", err) 161 } 162 if !token.Valid { 163 return nil, ErrInvalidToken 164 } 165 if claims, ok := token.Claims.(*Claims); ok { 166 return claims, nil 167 } 168 return nil, ErrInvalidClaims 169 } 170 171 func extractToken(r *http.Request) (string, error) { 172 token := strings.TrimSpace(r.Header.Get("Authorization")) 173 if !strings.HasPrefix(token, jwtBearer) { 174 return "", ErrInvalidOrMissingToken 175 } 176 return strings.TrimSpace(token[len(jwtBearer):]), nil 177 } 178 179 func genSession() string { 180 return hex.EncodeToString(vgcrypto.Hash(vgrand.RandomBytes(LengthForSessionHashSeed))) 181 } 182 183 func writeError(w http.ResponseWriter, e error) { 184 w.Header().Set("Content-Type", "application/json") 185 w.WriteHeader(http.StatusBadRequest) 186 187 buf, err := json.Marshal(e) 188 if err != nil { 189 w.WriteHeader(http.StatusInternalServerError) 190 return 191 } 192 _, err = w.Write(buf) 193 if err != nil { 194 w.WriteHeader(http.StatusInternalServerError) 195 } 196 }