github.com/benoitkugler/goacve@v0.0.0-20201217100549-151ce6e55dc8/server/core/rawdata/sql.go (about)

     1  package rawdata
     2  
     import (
     	"database/sql"
     	"database/sql/driver"
     	"fmt"
     	"strings"
     	"time"

     	"github.com/benoitkugler/goACVE/logs"
     	"github.com/lib/pq"
     	_ "github.com/lib/pq" // SQL driver registration
     )
    13  
    14  // Handling of NULL values
    15  func (s *Bool) Scan(src interface{}) error {
    16  	var tmp sql.NullBool
    17  	err := tmp.Scan(src)
    18  	if err != nil {
    19  		return err
    20  	}
    21  	*s = Bool(tmp.Bool)
    22  	return nil
    23  }
    24  
    25  func (s *Int) Scan(src interface{}) error {
    26  	var tmp sql.NullInt64
    27  	err := tmp.Scan(src)
    28  	if err != nil {
    29  		return err
    30  	}
    31  	*s = Int(tmp.Int64)
    32  	return nil
    33  }
    34  
    35  func (s *Float) Scan(src interface{}) error {
    36  	var tmp sql.NullFloat64
    37  	err := tmp.Scan(src)
    38  	if err != nil {
    39  		return err
    40  	}
    41  	*s = Float(tmp.Float64)
    42  	return nil
    43  }
    44  
    45  func (s *String) Scan(src interface{}) error {
    46  	var tmp sql.NullString
    47  	err := tmp.Scan(src)
    48  	if err != nil {
    49  		return err
    50  	}
    51  	*s = String(tmp.String)
    52  	return nil
    53  }
    54  
    55  func (s *Time) Scan(src interface{}) error {
    56  	var tmp pq.NullTime
    57  	err := tmp.Scan(src)
    58  	if err != nil {
    59  		return err
    60  	}
    61  	*s = Time(tmp.Time)
    62  	return nil
    63  }
    64  func (s Time) Value() (driver.Value, error) {
    65  	pqTime := pq.NullTime{Time: time.Time(s), Valid: true}
    66  	if s.Time().IsZero() {
    67  		pqTime = pq.NullTime{}
    68  	}
    69  	return pqTime.Value()
    70  }
    71  
    72  func (s *Date) Scan(src interface{}) error {
    73  	var tmp pq.NullTime
    74  	err := tmp.Scan(src)
    75  	if err != nil {
    76  		return err
    77  	}
    78  	*s = Date(tmp.Time)
    79  	return nil
    80  }
    81  func (s Date) Value() (driver.Value, error) {
    82  	return pq.NullTime{Time: time.Time(s), Valid: true}.Value()
    83  }
    84  
    85  // custom types
    86  func (s *OptionnalId) Scan(src interface{}) error {
    87  	return (*sql.NullInt64)(s).Scan(src)
    88  }
    89  func (s OptionnalId) Value() (driver.Value, error) {
    90  	return (sql.NullInt64)(s).Value()
    91  }
    92  
    93  func (s *Tels) Scan(src interface{}) error {
    94  	return (*pq.StringArray)(s).Scan(src)
    95  }
    96  func (s Tels) Value() (driver.Value, error) {
    97  	return (pq.StringArray)(s).Value()
    98  }
    99  
   100  func (s *Cotisation) Scan(src interface{}) error {
   101  	return (*pq.Int64Array)(s).Scan(src)
   102  }
   103  func (s Cotisation) Value() (driver.Value, error) {
   104  	return (pq.Int64Array)(s).Value()
   105  }
   106  
   107  func (rs *Roles) Scan(src interface{}) error {
   108  	var tmp pq.StringArray
   109  	err := tmp.Scan(src)
   110  	// on convertit les strings en Role
   111  	b := make(Roles, len(tmp))
   112  	for i, v := range tmp {
   113  		b[i] = Role(v)
   114  	}
   115  	// on met à jour la cible du  pointeur
   116  	*rs = b
   117  	return err
   118  }
   119  func (rs Roles) Value() (driver.Value, error) {
   120  	tmp := make(pq.StringArray, len(rs))
   121  	for i, v := range rs {
   122  		tmp[i] = string(v)
   123  	}
   124  	return tmp.Value()
   125  }
   126  
   127  func ConnectDB(credences logs.DB) (*sql.DB, error) {
   128  	port := credences.Port
   129  	if port == 0 {
   130  		port = 5432
   131  	}
   132  	connStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s",
   133  		credences.Host, port, credences.User, credences.Password, credences.Name)
   134  	db, err := sql.Open("postgres", connStr)
   135  	if err != nil {
   136  		return nil, fmt.Errorf("connexion DB : %s", err)
   137  	}
   138  	return db, nil
   139  }
   140  
   141  func ScanIds(rs *sql.Rows) (Ids, error) {
   142  	defer rs.Close()
   143  	ints := make(Ids, 0, 16)
   144  	var err error
   145  	for rs.Next() {
   146  		var s int64
   147  		if err = rs.Scan(&s); err != nil {
   148  			return nil, err
   149  		}
   150  		ints = append(ints, s)
   151  	}
   152  	if err = rs.Err(); err != nil {
   153  		return nil, err
   154  	}
   155  	return ints, nil
   156  }