github.com/benoitkugler/goacve@v0.0.0-20201217100549-151ce6e55dc8/server/core/rawdata/sql.go (about) 1 package rawdata 2 3 import ( 4 "database/sql" 5 "database/sql/driver" 6 "fmt" 7 "time" 8 9 "github.com/benoitkugler/goACVE/logs" 10 "github.com/lib/pq" 11 _ "github.com/lib/pq" // SQL driver registration 12 ) 13 14 // Handling of NULL values 15 func (s *Bool) Scan(src interface{}) error { 16 var tmp sql.NullBool 17 err := tmp.Scan(src) 18 if err != nil { 19 return err 20 } 21 *s = Bool(tmp.Bool) 22 return nil 23 } 24 25 func (s *Int) Scan(src interface{}) error { 26 var tmp sql.NullInt64 27 err := tmp.Scan(src) 28 if err != nil { 29 return err 30 } 31 *s = Int(tmp.Int64) 32 return nil 33 } 34 35 func (s *Float) Scan(src interface{}) error { 36 var tmp sql.NullFloat64 37 err := tmp.Scan(src) 38 if err != nil { 39 return err 40 } 41 *s = Float(tmp.Float64) 42 return nil 43 } 44 45 func (s *String) Scan(src interface{}) error { 46 var tmp sql.NullString 47 err := tmp.Scan(src) 48 if err != nil { 49 return err 50 } 51 *s = String(tmp.String) 52 return nil 53 } 54 55 func (s *Time) Scan(src interface{}) error { 56 var tmp pq.NullTime 57 err := tmp.Scan(src) 58 if err != nil { 59 return err 60 } 61 *s = Time(tmp.Time) 62 return nil 63 } 64 func (s Time) Value() (driver.Value, error) { 65 pqTime := pq.NullTime{Time: time.Time(s), Valid: true} 66 if s.Time().IsZero() { 67 pqTime = pq.NullTime{} 68 } 69 return pqTime.Value() 70 } 71 72 func (s *Date) Scan(src interface{}) error { 73 var tmp pq.NullTime 74 err := tmp.Scan(src) 75 if err != nil { 76 return err 77 } 78 *s = Date(tmp.Time) 79 return nil 80 } 81 func (s Date) Value() (driver.Value, error) { 82 return pq.NullTime{Time: time.Time(s), Valid: true}.Value() 83 } 84 85 // custom types 86 func (s *OptionnalId) Scan(src interface{}) error { 87 return (*sql.NullInt64)(s).Scan(src) 88 } 89 func (s OptionnalId) Value() (driver.Value, error) { 90 return (sql.NullInt64)(s).Value() 91 } 92 93 func (s *Tels) Scan(src interface{}) error { 94 return (*pq.StringArray)(s).Scan(src) 95 } 96 func (s Tels) Value() (driver.Value, error) { 97 return (pq.StringArray)(s).Value() 98 } 99 100 func (s *Cotisation) Scan(src interface{}) error { 101 return (*pq.Int64Array)(s).Scan(src) 102 } 103 func (s Cotisation) Value() (driver.Value, error) { 104 return (pq.Int64Array)(s).Value() 105 } 106 107 func (rs *Roles) Scan(src interface{}) error { 108 var tmp pq.StringArray 109 err := tmp.Scan(src) 110 // on convertit les strings en Role 111 b := make(Roles, len(tmp)) 112 for i, v := range tmp { 113 b[i] = Role(v) 114 } 115 // on met à jour la cible du pointeur 116 *rs = b 117 return err 118 } 119 func (rs Roles) Value() (driver.Value, error) { 120 tmp := make(pq.StringArray, len(rs)) 121 for i, v := range rs { 122 tmp[i] = string(v) 123 } 124 return tmp.Value() 125 } 126 127 func ConnectDB(credences logs.DB) (*sql.DB, error) { 128 port := credences.Port 129 if port == 0 { 130 port = 5432 131 } 132 connStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s", 133 credences.Host, port, credences.User, credences.Password, credences.Name) 134 db, err := sql.Open("postgres", connStr) 135 if err != nil { 136 return nil, fmt.Errorf("connexion DB : %s", err) 137 } 138 return db, nil 139 } 140 141 func ScanIds(rs *sql.Rows) (Ids, error) { 142 defer rs.Close() 143 ints := make(Ids, 0, 16) 144 var err error 145 for rs.Next() { 146 var s int64 147 if err = rs.Scan(&s); err != nil { 148 return nil, err 149 } 150 ints = append(ints, s) 151 } 152 if err = rs.Err(); err != nil { 153 return nil, err 154 } 155 return ints, nil 156 }