code.vegaprotocol.io/vega@v0.79.0/wallet/service/v1/auth.go (about)

     1  // Copyright (C) 2023 Gobalsky Labs Limited
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU Affero General Public License as
     5  // published by the Free Software Foundation, either version 3 of the
     6  // License, or (at your option) any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU Affero General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU Affero General Public License
    14  // along with this program.  If not, see <http://www.gnu.org/licenses/>.
    15  
    16  package v1
    17  
    18  import (
    19  	"crypto/rsa"
    20  	"encoding/hex"
    21  	"encoding/json"
    22  	"errors"
    23  	"fmt"
    24  	"net/http"
    25  	"strings"
    26  	"sync"
    27  	"time"
    28  
    29  	vgcrypto "code.vegaprotocol.io/vega/libs/crypto"
    30  	vgrand "code.vegaprotocol.io/vega/libs/rand"
    31  
    32  	"github.com/dgrijalva/jwt-go/v4"
    33  	"go.uber.org/zap"
    34  )
    35  
    36  const (
    37  	LengthForSessionHashSeed = 10
    38  
    39  	jwtBearer = "Bearer "
    40  )
    41  
    42  var ErrSessionNotFound = errors.New("session not found")
    43  
    44  type auth struct {
    45  	log *zap.Logger
    46  	// sessionID -> wallet name
    47  	sessions    map[string]string
    48  	privKey     *rsa.PrivateKey
    49  	pubKey      *rsa.PublicKey
    50  	tokenExpiry time.Duration
    51  
    52  	mu sync.Mutex
    53  }
    54  
    55  func NewAuth(log *zap.Logger, cfgStore RSAStore, tokenExpiry time.Duration) (Auth, error) { //revive:disable:unexported-return
    56  	keys, err := cfgStore.GetRsaKeys()
    57  	if err != nil {
    58  		return nil, err
    59  	}
    60  	priv, err := jwt.ParseRSAPrivateKeyFromPEM(keys.Priv)
    61  	if err != nil {
    62  		return nil, fmt.Errorf("couldn't parse private RSA key: %w", err)
    63  	}
    64  	pub, err := jwt.ParseRSAPublicKeyFromPEM(keys.Pub)
    65  	if err != nil {
    66  		return nil, fmt.Errorf("couldn't parse public RSA key: %w", err)
    67  	}
    68  
    69  	return &auth{
    70  		sessions:    map[string]string{},
    71  		privKey:     priv,
    72  		pubKey:      pub,
    73  		log:         log,
    74  		tokenExpiry: tokenExpiry,
    75  	}, nil
    76  }
    77  
// Claims is the JWT payload issued by NewSession. It embeds the standard JWT
// claims (NewSession sets the expiry and the issuer) and adds the session ID
// and the wallet name the session was opened for.
type Claims struct {
	jwt.StandardClaims
	// Session is the session ID; it is the key looked up in auth.sessions.
	Session string
	// Wallet is the name of the wallet the session was opened for.
	Wallet  string
}
    83  
    84  func (a *auth) NewSession(walletName string) (string, error) {
    85  	a.mu.Lock()
    86  	defer a.mu.Unlock()
    87  
    88  	expiresAt := time.Now().Add(a.tokenExpiry)
    89  
    90  	session := genSession()
    91  
    92  	claims := &Claims{
    93  		Session: session,
    94  		Wallet:  walletName,
    95  		StandardClaims: jwt.StandardClaims{
    96  			// these are seconds
    97  			ExpiresAt: jwt.NewTime((float64)(expiresAt.Unix())),
    98  			Issuer:    "vega wallet",
    99  		},
   100  	}
   101  
   102  	token := jwt.NewWithClaims(jwt.SigningMethodPS256, claims)
   103  	ss, err := token.SignedString(a.privKey)
   104  	if err != nil {
   105  		a.log.Error("unable to sign token", zap.Error(err))
   106  		return "", err
   107  	}
   108  
   109  	a.sessions[session] = walletName
   110  	return ss, nil
   111  }
   112  
   113  // VerifyToken returns the wallet name associated for this session.
   114  func (a *auth) VerifyToken(token string) (string, error) {
   115  	a.mu.Lock()
   116  	defer a.mu.Unlock()
   117  
   118  	claims, err := a.parseToken(token)
   119  	if err != nil {
   120  		return "", err
   121  	}
   122  
   123  	walletName, ok := a.sessions[claims.Session]
   124  	if !ok {
   125  		return "", ErrSessionNotFound
   126  	}
   127  
   128  	return walletName, nil
   129  }
   130  
   131  func (a *auth) Revoke(token string) (string, error) {
   132  	a.mu.Lock()
   133  	defer a.mu.Unlock()
   134  
   135  	claims, err := a.parseToken(token)
   136  	if err != nil {
   137  		return "", err
   138  	}
   139  
   140  	w, ok := a.sessions[claims.Session]
   141  	if !ok {
   142  		return "", ErrSessionNotFound
   143  	}
   144  	delete(a.sessions, claims.Session)
   145  	return w, nil
   146  }
   147  
   148  func (a *auth) RevokeAllToken() {
   149  	a.mu.Lock()
   150  	defer a.mu.Unlock()
   151  
   152  	a.sessions = map[string]string{}
   153  }
   154  
   155  func (a *auth) parseToken(tokenStr string) (*Claims, error) {
   156  	token, err := jwt.ParseWithClaims(tokenStr, &Claims{}, func(t *jwt.Token) (interface{}, error) {
   157  		return a.pubKey, nil
   158  	})
   159  	if err != nil {
   160  		return nil, fmt.Errorf("couldn't parse JWT token: %w", err)
   161  	}
   162  	if !token.Valid {
   163  		return nil, ErrInvalidToken
   164  	}
   165  	if claims, ok := token.Claims.(*Claims); ok {
   166  		return claims, nil
   167  	}
   168  	return nil, ErrInvalidClaims
   169  }
   170  
   171  func extractToken(r *http.Request) (string, error) {
   172  	token := strings.TrimSpace(r.Header.Get("Authorization"))
   173  	if !strings.HasPrefix(token, jwtBearer) {
   174  		return "", ErrInvalidOrMissingToken
   175  	}
   176  	return strings.TrimSpace(token[len(jwtBearer):]), nil
   177  }
   178  
   179  func genSession() string {
   180  	return hex.EncodeToString(vgcrypto.Hash(vgrand.RandomBytes(LengthForSessionHashSeed)))
   181  }
   182  
   183  func writeError(w http.ResponseWriter, e error) {
   184  	w.Header().Set("Content-Type", "application/json")
   185  	w.WriteHeader(http.StatusBadRequest)
   186  
   187  	buf, err := json.Marshal(e)
   188  	if err != nil {
   189  		w.WriteHeader(http.StatusInternalServerError)
   190  		return
   191  	}
   192  	_, err = w.Write(buf)
   193  	if err != nil {
   194  		w.WriteHeader(http.StatusInternalServerError)
   195  	}
   196  }