github.com/wanliu/go-oauth2-server@v0.0.0-20180817021415-f928fa1580df/oauth/client.go (about)

     1  package oauth
     2  
     3  import (
     4  	"errors"
     5  	"strings"
     6  	"time"
     7  
     8  	"github.com/RichardKnop/uuid"
     9  	"github.com/jinzhu/gorm"
    10  	"github.com/wanliu/go-oauth2-server/models"
    11  	"github.com/wanliu/go-oauth2-server/util"
    12  	"github.com/wanliu/go-oauth2-server/util/password"
    13  )
    14  
    15  var (
    16  	// ErrClientNotFound ...
    17  	ErrClientNotFound = errors.New("Client not found")
    18  	// ErrInvalidClientSecret ...
    19  	ErrInvalidClientSecret = errors.New("Invalid client secret")
    20  	// ErrClientIDTaken ...
    21  	ErrClientIDTaken = errors.New("Client ID taken")
    22  )
    23  
    24  // ClientExists returns true if client exists
    25  func (s *Service) ClientExists(clientID string) bool {
    26  	_, err := s.FindClientByClientID(clientID)
    27  	return err == nil
    28  }
    29  
    30  // FindClientByClientID looks up a client by client ID
    31  func (s *Service) FindClientByClientID(clientID string) (*models.OauthClient, error) {
    32  	// Client IDs are case insensitive
    33  	client := new(models.OauthClient)
    34  	notFound := s.db.Where("key = LOWER(?)", clientID).
    35  		First(client).RecordNotFound()
    36  
    37  	// Not found
    38  	if notFound {
    39  		return nil, ErrClientNotFound
    40  	}
    41  
    42  	return client, nil
    43  }
    44  
    45  // CreateClient saves a new client to database
    46  func (s *Service) CreateClient(clientID, secret, redirectURI string) (*models.OauthClient, error) {
    47  	return s.createClientCommon(s.db, clientID, secret, redirectURI, "", "")
    48  }
    49  
    50  // CreateClient saves a new client to database
    51  func (s *Service) CreateClientByUserID(userId, name, clientID, secret, redirectURI string) (*models.OauthClient, error) {
    52  	return s.createClientCommon(s.db, clientID, secret, redirectURI, userId, name)
    53  }
    54  
    55  // CreateClientTx saves a new client to database using injected db object
    56  func (s *Service) CreateClientTx(tx *gorm.DB, clientID, secret, redirectURI string) (*models.OauthClient, error) {
    57  	return s.createClientCommon(tx, clientID, secret, redirectURI, "", "")
    58  }
    59  
    60  // AuthClient authenticates client
    61  func (s *Service) AuthClient(clientID, secret string) (*models.OauthClient, error) {
    62  	// Fetch the client
    63  	client, err := s.FindClientByClientID(clientID)
    64  	if err != nil {
    65  		return nil, ErrClientNotFound
    66  	}
    67  
    68  	// Verify the secret
    69  	if password.VerifyPassword(client.Secret, secret) != nil {
    70  		return nil, ErrInvalidClientSecret
    71  	}
    72  
    73  	return client, nil
    74  }
    75  
    76  func (s *Service) createClientCommon(db *gorm.DB, clientID, pass, redirectURI, userId, name string) (*models.OauthClient, error) {
    77  	// Check client ID
    78  	if s.ClientExists(clientID) {
    79  		return nil, ErrClientIDTaken
    80  	}
    81  
    82  	// Hash password
    83  	secretHash, err := password.HashPassword(pass)
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  
    88  	client := &models.OauthClient{
    89  		MyGormModel: models.MyGormModel{
    90  			ID:        uuid.New(),
    91  			CreatedAt: time.Now().UTC(),
    92  		},
    93  		Name:        util.StringOrNull(name),
    94  		UserID:      util.StringOrNull(userId),
    95  		Key:         strings.ToLower(clientID),
    96  		Secret:      string(secretHash),
    97  		Password:    util.StringOrNull(pass),
    98  		RedirectURI: util.StringOrNull(redirectURI),
    99  	}
   100  	if err := db.Create(client).Error; err != nil {
   101  		return nil, err
   102  	}
   103  	return client, nil
   104  }
   105  
   106  func (s *Service) ListClientByUserID(userId string, offset, count int) ([]models.OauthClient, error) {
   107  	var clients []models.OauthClient
   108  	if err := s.db.Find(&clients, "user_id = ?", userId).Offset(offset).Limit(count).Error; err != nil {
   109  		return nil, err
   110  	}
   111  
   112  	return clients, nil
   113  }