github.com/crowdsecurity/crowdsec@v1.6.1/pkg/apiserver/middlewares/v1/jwt.go (about)

     1  package v1
     2  
     3  import (
     4  	"crypto/rand"
     5  	"errors"
     6  	"fmt"
     7  	"net/http"
     8  	"os"
     9  	"strings"
    10  	"time"
    11  
    12  	jwt "github.com/appleboy/gin-jwt/v2"
    13  	"github.com/gin-gonic/gin"
    14  	"github.com/go-openapi/strfmt"
    15  	log "github.com/sirupsen/logrus"
    16  	"golang.org/x/crypto/bcrypt"
    17  
    18  	"github.com/crowdsecurity/crowdsec/pkg/database"
    19  	"github.com/crowdsecurity/crowdsec/pkg/database/ent"
    20  	"github.com/crowdsecurity/crowdsec/pkg/database/ent/machine"
    21  	"github.com/crowdsecurity/crowdsec/pkg/models"
    22  	"github.com/crowdsecurity/crowdsec/pkg/types"
    23  )
    24  
// MachineIDKey is the JWT claim name under which the authenticated machine ID
// is stored; NewJWT also uses it as the gin-jwt identity key.
const MachineIDKey = "id"
    26  
// JWT bundles the gin-jwt middleware with the dependencies its callbacks
// (Authenticator, authTLS, authPlain) need to authenticate machines.
type JWT struct {
	// Middleware is the gin-jwt middleware configured by NewJWT.
	Middleware *jwt.GinJWTMiddleware
	// DbClient is used to look up, create and update machine records.
	DbClient   *database.Client
	// TlsAuth validates client certificates for certificate-based logins;
	// when nil, authTLS refuses every certificate login.
	TlsAuth    *TLSAuth
}
    32  
    33  func PayloadFunc(data interface{}) jwt.MapClaims {
    34  	if value, ok := data.(*models.WatcherAuthRequest); ok {
    35  		return jwt.MapClaims{
    36  			MachineIDKey: &value.MachineID,
    37  		}
    38  	}
    39  
    40  	return jwt.MapClaims{}
    41  }
    42  
    43  func IdentityHandler(c *gin.Context) interface{} {
    44  	claims := jwt.ExtractClaims(c)
    45  	machineID := claims[MachineIDKey].(string)
    46  
    47  	return &models.WatcherAuthRequest{
    48  		MachineID: &machineID,
    49  	}
    50  }
    51  
    52  type authInput struct {
    53  	machineID      string
    54  	clientMachine  *ent.Machine
    55  	scenariosInput []string
    56  }
    57  
    58  func (j *JWT) authTLS(c *gin.Context) (*authInput, error) {
    59  	ret := authInput{}
    60  
    61  	if j.TlsAuth == nil {
    62  		c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
    63  		c.Abort()
    64  
    65  		return nil, errors.New("TLS auth is not configured")
    66  	}
    67  
    68  	validCert, extractedCN, err := j.TlsAuth.ValidateCert(c)
    69  	if err != nil {
    70  		log.Error(err)
    71  		c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
    72  		c.Abort()
    73  
    74  		return nil, fmt.Errorf("while trying to validate client cert: %w", err)
    75  	}
    76  
    77  	if !validCert {
    78  		c.JSON(http.StatusForbidden, gin.H{"message": "access forbidden"})
    79  		c.Abort()
    80  
    81  		return nil, errors.New("failed cert authentication")
    82  	}
    83  
    84  	ret.machineID = fmt.Sprintf("%s@%s", extractedCN, c.ClientIP())
    85  
    86  	ret.clientMachine, err = j.DbClient.Ent.Machine.Query().
    87  		Where(machine.MachineId(ret.machineID)).
    88  		First(j.DbClient.CTX)
    89  	if ent.IsNotFound(err) {
    90  		// Machine was not found, let's create it
    91  		log.Infof("machine %s not found, create it", ret.machineID)
    92  		// let's use an apikey as the password, doesn't matter in this case (generatePassword is only available in cscli)
    93  		pwd, err := GenerateAPIKey(dummyAPIKeySize)
    94  		if err != nil {
    95  			log.WithFields(log.Fields{
    96  				"ip": c.ClientIP(),
    97  				"cn": extractedCN,
    98  			}).Errorf("error generating password: %s", err)
    99  
   100  			return nil, errors.New("error generating password")
   101  		}
   102  
   103  		password := strfmt.Password(pwd)
   104  
   105  		ret.clientMachine, err = j.DbClient.CreateMachine(&ret.machineID, &password, "", true, true, types.TlsAuthType)
   106  		if err != nil {
   107  			return nil, fmt.Errorf("while creating machine entry for %s: %w", ret.machineID, err)
   108  		}
   109  	} else if err != nil {
   110  		return nil, fmt.Errorf("while selecting machine entry for %s: %w", ret.machineID, err)
   111  	} else {
   112  		if ret.clientMachine.AuthType != types.TlsAuthType {
   113  			return nil, fmt.Errorf("machine %s attempted to auth with TLS cert but it is configured to use %s", ret.machineID, ret.clientMachine.AuthType)
   114  		}
   115  
   116  		ret.machineID = ret.clientMachine.MachineId
   117  	}
   118  
   119  	loginInput := struct {
   120  		Scenarios []string `json:"scenarios"`
   121  	}{
   122  		Scenarios: []string{},
   123  	}
   124  
   125  	err = c.ShouldBindJSON(&loginInput)
   126  	if err != nil {
   127  		return nil, fmt.Errorf("missing scenarios list in login request for TLS auth: %w", err)
   128  	}
   129  
   130  	ret.scenariosInput = loginInput.Scenarios
   131  
   132  	return &ret, nil
   133  }
   134  
   135  func (j *JWT) authPlain(c *gin.Context) (*authInput, error) {
   136  	var (
   137  		loginInput models.WatcherAuthRequest
   138  		err        error
   139  	)
   140  
   141  	ret := authInput{}
   142  
   143  	if err = c.ShouldBindJSON(&loginInput); err != nil {
   144  		return nil, fmt.Errorf("missing: %w", err)
   145  	}
   146  
   147  	if err = loginInput.Validate(strfmt.Default); err != nil {
   148  		return nil, err
   149  	}
   150  
   151  	ret.machineID = *loginInput.MachineID
   152  	password := *loginInput.Password
   153  	ret.scenariosInput = loginInput.Scenarios
   154  
   155  	ret.clientMachine, err = j.DbClient.Ent.Machine.Query().
   156  		Where(machine.MachineId(ret.machineID)).
   157  		First(j.DbClient.CTX)
   158  	if err != nil {
   159  		log.Infof("Error machine login for %s : %+v ", ret.machineID, err)
   160  		return nil, err
   161  	}
   162  
   163  	if ret.clientMachine == nil {
   164  		log.Errorf("Nothing for '%s'", ret.machineID)
   165  		return nil, jwt.ErrFailedAuthentication
   166  	}
   167  
   168  	if ret.clientMachine.AuthType != types.PasswordAuthType {
   169  		return nil, fmt.Errorf("machine %s attempted to auth with password but it is configured to use %s", ret.machineID, ret.clientMachine.AuthType)
   170  	}
   171  
   172  	if !ret.clientMachine.IsValidated {
   173  		return nil, fmt.Errorf("machine %s not validated", ret.machineID)
   174  	}
   175  
   176  	if err := bcrypt.CompareHashAndPassword([]byte(ret.clientMachine.Password), []byte(password)); err != nil {
   177  		return nil, jwt.ErrFailedAuthentication
   178  	}
   179  
   180  	return &ret, nil
   181  }
   182  
   183  func (j *JWT) Authenticator(c *gin.Context) (interface{}, error) {
   184  	var (
   185  		err  error
   186  		auth *authInput
   187  	)
   188  
   189  	if c.Request.TLS != nil && len(c.Request.TLS.PeerCertificates) > 0 {
   190  		auth, err = j.authTLS(c)
   191  		if err != nil {
   192  			return nil, err
   193  		}
   194  	} else {
   195  		auth, err = j.authPlain(c)
   196  		if err != nil {
   197  			return nil, err
   198  		}
   199  	}
   200  
   201  	var scenarios string
   202  
   203  	if len(auth.scenariosInput) > 0 {
   204  		for _, scenario := range auth.scenariosInput {
   205  			if scenarios == "" {
   206  				scenarios = scenario
   207  			} else {
   208  				scenarios += "," + scenario
   209  			}
   210  		}
   211  
   212  		err = j.DbClient.UpdateMachineScenarios(scenarios, auth.clientMachine.ID)
   213  		if err != nil {
   214  			log.Errorf("Failed to update scenarios list for '%s': %s\n", auth.machineID, err)
   215  			return nil, jwt.ErrFailedAuthentication
   216  		}
   217  	}
   218  
   219  	clientIP := c.ClientIP()
   220  
   221  	if auth.clientMachine.IpAddress == "" {
   222  		err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID)
   223  		if err != nil {
   224  			log.Errorf("Failed to update ip address for '%s': %s\n", auth.machineID, err)
   225  			return nil, jwt.ErrFailedAuthentication
   226  		}
   227  	}
   228  
   229  	if auth.clientMachine.IpAddress != clientIP && auth.clientMachine.IpAddress != "" {
   230  		log.Warningf("new IP address detected for machine '%s': %s (old: %s)", auth.clientMachine.MachineId, clientIP, auth.clientMachine.IpAddress)
   231  
   232  		err = j.DbClient.UpdateMachineIP(clientIP, auth.clientMachine.ID)
   233  		if err != nil {
   234  			log.Errorf("Failed to update ip address for '%s': %s\n", auth.clientMachine.MachineId, err)
   235  			return nil, jwt.ErrFailedAuthentication
   236  		}
   237  	}
   238  
   239  	useragent := strings.Split(c.Request.UserAgent(), "/")
   240  	if len(useragent) != 2 {
   241  		log.Warningf("bad user agent '%s' from '%s'", c.Request.UserAgent(), clientIP)
   242  		return nil, jwt.ErrFailedAuthentication
   243  	}
   244  
   245  	if err := j.DbClient.UpdateMachineVersion(useragent[1], auth.clientMachine.ID); err != nil {
   246  		log.Errorf("unable to update machine '%s' version '%s': %s", auth.clientMachine.MachineId, useragent[1], err)
   247  		log.Errorf("bad user agent from : %s", clientIP)
   248  
   249  		return nil, jwt.ErrFailedAuthentication
   250  	}
   251  
   252  	return &models.WatcherAuthRequest{
   253  		MachineID: &auth.machineID,
   254  	}, nil
   255  }
   256  
   257  func Authorizator(data interface{}, c *gin.Context) bool {
   258  	return true
   259  }
   260  
   261  func Unauthorized(c *gin.Context, code int, message string) {
   262  	c.JSON(code, gin.H{
   263  		"code":    code,
   264  		"message": message,
   265  	})
   266  }
   267  
   268  func randomSecret() ([]byte, error) {
   269  	size := 64
   270  	secret := make([]byte, size)
   271  
   272  	n, err := rand.Read(secret)
   273  	if err != nil {
   274  		return nil, errors.New("unable to generate a new random seed for JWT generation")
   275  	}
   276  
   277  	if n != size {
   278  		return nil, errors.New("not enough entropy at random seed generation for JWT generation")
   279  	}
   280  
   281  	return secret, nil
   282  }
   283  
   284  func NewJWT(dbClient *database.Client) (*JWT, error) {
   285  	// Get secret from environment variable "SECRET"
   286  	var (
   287  		secret []byte
   288  		err    error
   289  	)
   290  
   291  	// Please be aware that brute force HS256 is possible.
   292  	// PLEASE choose a STRONG secret
   293  	secretString := os.Getenv("CS_LAPI_SECRET")
   294  	secret = []byte(secretString)
   295  
   296  	switch l := len(secret); {
   297  	case l == 0:
   298  		secret, err = randomSecret()
   299  		if err != nil {
   300  			return &JWT{}, err
   301  		}
   302  	case l < 64:
   303  		return &JWT{}, errors.New("CS_LAPI_SECRET not strong enough")
   304  	}
   305  
   306  	jwtMiddleware := &JWT{
   307  		DbClient: dbClient,
   308  		TlsAuth:  &TLSAuth{},
   309  	}
   310  
   311  	ret, err := jwt.New(&jwt.GinJWTMiddleware{
   312  		Realm:           "Crowdsec API local",
   313  		Key:             secret,
   314  		Timeout:         time.Hour,
   315  		MaxRefresh:      time.Hour,
   316  		IdentityKey:     MachineIDKey,
   317  		PayloadFunc:     PayloadFunc,
   318  		IdentityHandler: IdentityHandler,
   319  		Authenticator:   jwtMiddleware.Authenticator,
   320  		Authorizator:    Authorizator,
   321  		Unauthorized:    Unauthorized,
   322  		TokenLookup:     "header: Authorization, query: token, cookie: jwt",
   323  		TokenHeadName:   "Bearer",
   324  		TimeFunc:        time.Now,
   325  	})
   326  	if err != nil {
   327  		return &JWT{}, err
   328  	}
   329  
   330  	errInit := ret.MiddlewareInit()
   331  	if errInit != nil {
   332  		return &JWT{}, errors.New("authMiddleware.MiddlewareInit() Error:" + errInit.Error())
   333  	}
   334  
   335  	jwtMiddleware.Middleware = ret
   336  
   337  	return jwtMiddleware, nil
   338  }