github.com/uvalib/orcid-access-ws@v0.0.0-20250612130209-7d062dbabf9d/orcidaccessws/dao/db-datastore.go (about)

     1  package dao
     2  
     3  import (
     4  	"database/sql"
     5  	"fmt"
     6  	"time"
     7  
     8  	// needed
     9  	_ "github.com/go-sql-driver/mysql"
    10  	"github.com/uvalib/orcid-access-ws/orcidaccessws/api"
    11  	"github.com/uvalib/orcid-access-ws/orcidaccessws/config"
    12  	"github.com/uvalib/orcid-access-ws/orcidaccessws/logger"
    13  )
    14  
    15  // this is our DB implementation
    16  type storage struct {
    17  	*sql.DB
    18  }
    19  
    20  // newDBStore -- create a DB version of the storage singleton
    21  func newDBStore() (Storage, error) {
    22  
    23  	dataSourceName := fmt.Sprintf("%s:%s@tcp(%s)/%s?allowOldPasswords=1&tls=%s&sql_notes=false&timeout=%s&readTimeout=%s&writeTimeout=%s",
    24  		config.Configuration.DbUser,
    25  		config.Configuration.DbPassphrase,
    26  		config.Configuration.DbHost,
    27  		config.Configuration.DbName,
    28  		config.Configuration.DbSecure,
    29  		config.Configuration.DbTimeout,
    30  		config.Configuration.DbTimeout,
    31  		config.Configuration.DbTimeout)
    32  
    33  	db, err := sql.Open("mysql", dataSourceName)
    34  	if err != nil {
    35  		return nil, err
    36  	}
    37  	if err = db.Ping(); err != nil {
    38  		return nil, err
    39  	}
    40  
    41  	//taken from https://github.com/go-sql-driver/mysql/issues/461
    42  	db.SetConnMaxLifetime(time.Minute * 5)
    43  	db.SetMaxIdleConns(2)
    44  	db.SetMaxOpenConns(2)
    45  
    46  	return &storage{db}, nil
    47  }
    48  
    49  // CheckDB -- check our database health
    50  func (s *storage) Check() error {
    51  	return s.Ping()
    52  }
    53  
    54  // GetAllOrcidAttributes -- get all orcid records
    55  func (s *storage) GetAllOrcidAttributes() ([]*api.OrcidAttributes, error) {
    56  
    57  	rows, err := s.Query("SELECT * FROM orcid_attributes ORDER BY id ASC")
    58  	if err != nil {
    59  		return nil, err
    60  	}
    61  	defer rows.Close()
    62  
    63  	return orcidResults(rows)
    64  }
    65  
    66  // GetOrcidAttributesByCid -- get all by ID (should only be 1)
    67  func (s *storage) GetOrcidAttributesByCid(id string) ([]*api.OrcidAttributes, error) {
    68  	return (getOrcidAttributesByCid(s, id))
    69  }
    70  
    71  // DelOrcidAttributesByCid -- delete by ID (should only be 1)
    72  func (s *storage) DelOrcidAttributesByCid(id string) error {
    73  
    74  	stmt, err := s.Prepare("DELETE FROM orcid_attributes WHERE cid = ? LIMIT 1")
    75  	if err != nil {
    76  		return err
    77  	}
    78  
    79  	_, err = stmt.Exec(id)
    80  
    81  	return err
    82  }
    83  
    84  // SetOrcidAttributesByCid -- set orcid attributes by ID
    85  func (s *storage) SetOrcidAttributesByCid(id string, attributes api.OrcidAttributes) error {
    86  
    87  	existing, err := getOrcidAttributesByCid(s, id)
    88  	if err != nil {
    89  		return err
    90  	}
    91  
    92  	// if we did not find a record, create a new one
    93  	if len(existing) == 0 {
    94  
    95  		stmt, err := s.Prepare("INSERT INTO orcid_attributes( cid, orcid, oauth_access, oauth_refresh, oauth_scope ) VALUES( ?,?,?,?,? )")
    96  		if err != nil {
    97  			return err
    98  		}
    99  
   100  		_, err = stmt.Exec(
   101  			id,
   102  			attributes.Orcid,
   103  			attributes.OauthAccessToken,
   104  			attributes.OauthRefreshToken,
   105  			attributes.OauthScope)
   106  
   107  	} else {
   108  
   109  		// a special case where we preserve the existing ORCID if none provided
   110  		newOrcid := existing[0].Orcid
   111  		if len(attributes.Orcid) != 0 {
   112  			newOrcid = attributes.Orcid
   113  		}
   114  
   115  		stmt, err := s.Prepare("UPDATE orcid_attributes SET orcid = ?, oauth_access = ?, oauth_refresh = ?, oauth_scope = ?, updated_at = NOW( ) WHERE cid = ? LIMIT 1")
   116  		if err != nil {
   117  			return err
   118  		}
   119  		_, err = stmt.Exec(
   120  			newOrcid,
   121  			attributes.OauthAccessToken,
   122  			attributes.OauthRefreshToken,
   123  			attributes.OauthScope,
   124  			id)
   125  	}
   126  
   127  	return err
   128  }
   129  
   130  //
   131  // private implementation methods
   132  //
   133  
   134  func getOrcidAttributesByCid(s *storage, id string) ([]*api.OrcidAttributes, error) {
   135  
   136  	rows, err := s.Query("SELECT * FROM orcid_attributes WHERE cid = ? LIMIT 1", id)
   137  	if err != nil {
   138  		return nil, err
   139  	}
   140  	defer rows.Close()
   141  	return orcidResults(rows)
   142  }
   143  
   144  func orcidResults(rows *sql.Rows) ([]*api.OrcidAttributes, error) {
   145  
   146  	var optionalUpdatedAt sql.NullString
   147  
   148  	results := make([]*api.OrcidAttributes, 0)
   149  	for rows.Next() {
   150  		reg := new(api.OrcidAttributes)
   151  		err := rows.Scan(&reg.ID,
   152  			&reg.Cid,
   153  			&reg.Orcid,
   154  			&reg.OauthAccessToken,
   155  			&reg.OauthRefreshToken,
   156  			&reg.OauthScope,
   157  			&reg.CreatedAt,
   158  			&optionalUpdatedAt)
   159  		if err != nil {
   160  			return nil, err
   161  		}
   162  
   163  		if optionalUpdatedAt.Valid {
   164  			reg.UpdatedAt = optionalUpdatedAt.String
   165  		}
   166  
   167  		// hack for now...
   168  		reg.URI = fmt.Sprintf("%s/%s", config.Configuration.OrcidOauthURL, reg.Orcid)
   169  
   170  		results = append(results, reg)
   171  	}
   172  	if err := rows.Err(); err != nil {
   173  		return nil, err
   174  	}
   175  
   176  	logger.Log(fmt.Sprintf("INFO: OrcidAttributes request returns %d row(s)", len(results)))
   177  	return results, nil
   178  }
   179  
   180  //
   181  // end of file
   182  //