github.com/bcampbell/scrapeomat@v0.0.0-20220820232205-23e64141c89e/store/sqlstore/sqlstore.go (about)

     1  package sqlstore
     2  
     3  import (
     4  	"database/sql"
     5  	"encoding/json"
     6  	"fmt"
     7  	"os"
     8  	"regexp"
     9  	"strings"
    10  	"time"
    11  
    12  	"github.com/bcampbell/scrapeomat/store"
    13  )
    14  
    15  type nullLogger struct{}
    16  
    17  func (l nullLogger) Printf(format string, v ...interface{}) {
    18  }
    19  
    20  type stderrLogger struct{}
    21  
    22  func (l stderrLogger) Printf(format string, v ...interface{}) {
    23  	fmt.Fprintf(os.Stderr, format, v...)
    24  }
    25  
// SQLStore stashes articles in an SQL database
type SQLStore struct {
	db         *sql.DB        // underlying database handle
	driverName string         // driver passed to sql.Open (eg "postgres", "sqlite3")
	loc        *time.Location // location assumed for timestamps with no explicit zone (UTC by default)
	ErrLog     store.Logger   // destination for error messages (discards by default)
	DebugLog   store.Logger   // destination for debug/query tracing (discards by default)
}
    34  
// SQLArtIter iterates over the result of a Fetch() query, decoding one
// article per Next() call.
type SQLArtIter struct {
	rows    *sql.Rows      // open result set (nil if the query failed or Close() was called)
	ss      *SQLStore      // parent store, used for the per-article child-table queries
	current *store.Article // article decoded by the most recent successful Next()
	err     error          // first error encountered (also holds any query-time error)
}
    41  
// Which method to use to get last insert IDs
const (
	DUNNO     = iota // no known-good method for this driver
	RESULT           // use Result.LastInsertID()
	RETURNING        // use sql "RETURNING" clause
)
    48  
    49  // eg "postgres", "postgres://username@localhost/dbname"
    50  // eg "sqlite3", "/tmp/foo.db"
    51  func New(driver string, connStr string) (*SQLStore, error) {
    52  
    53  	//db, err := sql.Open("postgres", connStr)
    54  	db, err := sql.Open(driver, connStr)
    55  	if err != nil {
    56  		return nil, err
    57  	}
    58  	return NewFromDB(driver, db)
    59  }
    60  
    61  func NewFromDB(driver string, db *sql.DB) (*SQLStore, error) {
    62  	err := db.Ping()
    63  	if err != nil {
    64  		db.Close()
    65  		return nil, err
    66  	}
    67  
    68  	// our assumed location for publication dates, when no timezone given
    69  	// TODO: this is the wrong place for it. Scraper should handle this on a per-publication basis
    70  	//	loc, err := time.LoadLocation("Europe/London")
    71  	//	if err != nil {
    72  	//		return nil, err
    73  	//	}
    74  
    75  	ss := SQLStore{
    76  		db:         db,
    77  		driverName: driver,
    78  		loc:        time.UTC,
    79  		ErrLog:     nullLogger{}, // TODO: should log to stderr by default?
    80  		DebugLog:   nullLogger{},
    81  	}
    82  
    83  	// TODO: would be nice to have logger set up before here...
    84  	err = ss.checkSchema()
    85  	if err != nil {
    86  		db.Close()
    87  		return nil, err
    88  	}
    89  
    90  	return &ss, nil
    91  }
    92  
    93  // Same as New(), but if driver or connStr is missing, will try and read them
    94  // from environment vars: SCRAPEOMAT_DRIVER & SCRAPEOMAT_DB.
    95  // If both driver and SCRAPEOMAT_DRIVER are empty, default is "sqlite3".
    96  func NewWithEnv(driver string, connStr string) (*SQLStore, error) {
    97  	if connStr == "" {
    98  		connStr = os.Getenv("SCRAPEOMAT_DB")
    99  	}
   100  	if driver == "" {
   101  		driver = os.Getenv("SCRAPEOMAT_DRIVER")
   102  		if driver == "" {
   103  			driver = "sqlite3"
   104  		}
   105  	}
   106  
   107  	if connStr == "" {
   108  		return nil, fmt.Errorf("no database specified (set SCRAPEOMAT_DB?)")
   109  	}
   110  
   111  	return New(driver, connStr)
   112  }
   113  
   114  func (ss *SQLStore) Close() {
   115  	if ss.db != nil {
   116  		ss.db.Close()
   117  		ss.db = nil
   118  	}
   119  }
   120  
   121  func (ss *SQLStore) rebind(q string) string {
   122  	return rebind(bindType(ss.driverName), q)
   123  }
   124  
   125  // can we use Result.LastInsertID() or do we need to fiddle the SQL?
   126  func (ss *SQLStore) insertIDType() int {
   127  	switch ss.driverName {
   128  	case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "ql":
   129  		return RETURNING
   130  	case "sqlite3", "mysql":
   131  		return RESULT
   132  	case "oci8", "ora", "goracle":
   133  		// ora: https://godoc.org/gopkg.in/rana/ora.v4#hdr-LastInsertId
   134  		return DUNNO
   135  	case "sqlserver":
   136  		// https://github.com/denisenkom/go-mssqldb#important-notes
   137  		return DUNNO
   138  	default:
   139  		return DUNNO
   140  	}
   141  }
   142  
   143  // return a string with sql fn to return current timestamp.
   144  func (ss *SQLStore) nowSQL() string {
   145  	switch ss.driverName {
   146  	case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres", "ql":
   147  		return "NOW()"
   148  	case "sqlite3":
   149  		return "datetime('now','localtime')"
   150  	case "mysql":
   151  		return "PROPER_FN_GOES_HERE_PLEASE()"
   152  	case "oci8", "ora", "goracle":
   153  		return "PROPER_FN_GOES_HERE_PLEASE()"
   154  	case "sqlserver":
   155  		return "PROPER_FN_GOES_HERE_PLEASE()"
   156  	default:
   157  		return "PROPER_FN_GOES_HERE_PLEASE()"
   158  	}
   159  }
   160  
// timeFmts lists the layouts tried, in order, when parsing timestamp
// strings - full RFC3339 first, then progressively less precise forms.
var timeFmts = []string{
	time.RFC3339,
	"2006-01-02T15:04Z07:00",
	//	"2006-01-02T15:04:05Z",
	"2006-01-02T15:04:05",
	"2006-01-02T15:04",
	"2006-01-02",
}
   169  
   170  func (ss *SQLStore) cvtTime(timestamp string) sql.NullTime {
   171  	for _, layout := range timeFmts {
   172  		t, err := time.ParseInLocation(layout, timestamp, ss.loc)
   173  		if err == nil {
   174  			return sql.NullTime{Time: t, Valid: true}
   175  		}
   176  	}
   177  
   178  	return sql.NullTime{Valid: false}
   179  }
   180  
// datePat matches a leading YYYY-MM-DD date prefix.
// NOTE(review): not referenced anywhere in this file - presumably used
// elsewhere in the package (or vestigial); confirm before removing.
var datePat = regexp.MustCompile(`^\d\d\d\d-\d\d-\d\d`)
   182  
   183  // FindURLs Looks up article urls, returning a list of matching article IDs.
   184  // usually you'd use this on the URLs for a single article, expecting zero or one IDs back,
   185  // but there's no reason you can't look up a whole bunch of articles at once, although you won't
   186  // know which ones match which URLs.
   187  // remember that there can be multiple URLs for a single article, AND also multiple articles can
   188  // share the same URL (hopefully much much more rare).
   189  func (ss *SQLStore) FindURLs(urls []string) ([]int, error) {
   190  
   191  	params := make([]interface{}, len(urls))
   192  	placeholders := make([]string, len(urls))
   193  	for i, u := range urls {
   194  		params[i] = u
   195  		placeholders[i] = "?"
   196  	}
   197  
   198  	s := `SELECT distinct article_id FROM article_url WHERE url IN (` + strings.Join(placeholders, ",") + `)`
   199  	rows, err := ss.db.Query(ss.rebind(s), params...)
   200  	if err != nil {
   201  		return nil, err
   202  	}
   203  	defer rows.Close()
   204  
   205  	out := []int{}
   206  	for rows.Next() {
   207  		var artID int
   208  		if err := rows.Scan(&artID); err != nil {
   209  			return nil, err
   210  		}
   211  
   212  		out = append(out, artID)
   213  	}
   214  	if err := rows.Err(); err != nil {
   215  		return nil, err
   216  	}
   217  	return out, nil
   218  }
   219  
   220  // NOTE: remember article urls don't _have_ to be unique. If you only pass
   221  // canonical urls in here you should be ok :-)
   222  func (ss *SQLStore) WhichAreNew(artURLs []string) ([]string, error) {
   223  
   224  	stmt, err := ss.db.Prepare(ss.rebind(`SELECT article_id FROM article_url WHERE url=?`))
   225  	if err != nil {
   226  		return nil, err
   227  	}
   228  
   229  	newArts := []string{}
   230  	for _, u := range artURLs {
   231  		var artID int
   232  		err = stmt.QueryRow(u).Scan(&artID)
   233  		if err == sql.ErrNoRows {
   234  			newArts = append(newArts, u)
   235  		} else if err != nil {
   236  			return nil, err
   237  		}
   238  	}
   239  	return newArts, nil
   240  }
   241  
   242  // Build a WHERE clause from a filter.
   243  func buildWhere(filt *store.Filter) (string, []interface{}) {
   244  	params := []interface{}{}
   245  	frags := []string{}
   246  
   247  	if !filt.PubFrom.IsZero() {
   248  		frags = append(frags, "a.published>=?")
   249  		params = append(params, filt.PubFrom)
   250  	}
   251  	if !filt.PubTo.IsZero() {
   252  		frags = append(frags, "a.published<?")
   253  		params = append(params, filt.PubTo)
   254  	}
   255  	if !filt.AddedFrom.IsZero() {
   256  		frags = append(frags, "a.added>=?")
   257  		params = append(params, filt.AddedFrom)
   258  	}
   259  	if !filt.AddedTo.IsZero() {
   260  		frags = append(frags, "a.added<?")
   261  		params = append(params, filt.AddedTo)
   262  	}
   263  	if filt.SinceID > 0 {
   264  		frags = append(frags, "a.id>?")
   265  		params = append(params, filt.SinceID)
   266  	}
   267  
   268  	if len(filt.PubCodes) > 0 {
   269  		foo := []string{}
   270  		bar := []interface{}{}
   271  		for _, code := range filt.PubCodes {
   272  			foo = append(foo, "?")
   273  			bar = append(bar, code)
   274  		}
   275  		frags = append(frags, "p.code IN ("+strings.Join(foo, ",")+")")
   276  		params = append(params, bar...)
   277  	}
   278  
   279  	if len(filt.XPubCodes) > 0 {
   280  		foo := []string{}
   281  		bar := []interface{}{}
   282  		for _, code := range filt.XPubCodes {
   283  			foo = append(foo, "?")
   284  			bar = append(bar, code)
   285  		}
   286  		frags = append(frags, "p.code NOT IN ("+strings.Join(foo, ",")+")")
   287  		params = append(params, bar...)
   288  	}
   289  
   290  	var whereClause string
   291  	if len(frags) > 0 {
   292  		whereClause = "WHERE " + strings.Join(frags, " AND ")
   293  	}
   294  	return whereClause, params
   295  }
   296  
   297  func (ss *SQLStore) FetchCount(filt *store.Filter) (int, error) {
   298  	whereClause, params := buildWhere(filt)
   299  	q := `SELECT COUNT(*)
   300             FROM (article a INNER JOIN publication p ON a.publication_id=p.id)
   301             ` + whereClause
   302  	var cnt int
   303  	//ss.DebugLog.Printf("fetchcount: %s\n", q)
   304  	//ss.DebugLog.Printf("fetchcount params: %+v\n", params)
   305  	err := ss.db.QueryRow(ss.rebind(q), params...).Scan(&cnt)
   306  	return cnt, err
   307  }
   308  
   309  func (ss *SQLStore) Fetch(filt *store.Filter) store.ArtIter {
   310  
   311  	whereClause, params := buildWhere(filt)
   312  
   313  	q := `SELECT a.id,a.headline,a.canonical_url,a.content,a.published,a.updated,a.section,a.extra,p.code,p.name,p.domain
   314  	               FROM (article a INNER JOIN publication p ON a.publication_id=p.id)
   315  	               ` + whereClause + ` ORDER BY a.id`
   316  
   317  	if filt.Count > 0 {
   318  		q += fmt.Sprintf(" LIMIT %d", filt.Count)
   319  	}
   320  
   321  	ss.DebugLog.Printf("fetch: %s\n", q)
   322  	ss.DebugLog.Printf("fetch params: %+v\n", params)
   323  
   324  	rows, err := ss.db.Query(ss.rebind(q), params...)
   325  	return &SQLArtIter{ss: ss, rows: rows, err: err}
   326  }
   327  
   328  func (it *SQLArtIter) Close() error {
   329  	// may not even have got as far as initing rows!
   330  	var err error
   331  	if it.rows != nil {
   332  		err = it.rows.Close()
   333  		it.rows = nil
   334  	}
   335  	return err
   336  }
   337  
// Err returns the first error encountered while issuing the query or
// iterating over its results, if any.
func (it *SQLArtIter) Err() error {
	return it.err
}
   341  
   342  // if it returns true there will be an article.
   343  func (it *SQLArtIter) Next() bool {
   344  	it.current = nil
   345  	if it.err != nil {
   346  		return false // no more, if we're in error state
   347  	}
   348  	if !it.rows.Next() {
   349  		it.err = it.rows.Err()
   350  		return false // all done
   351  	}
   352  
   353  	art := &store.Article{}
   354  	var p = &art.Publication
   355  
   356  	var published, updated sql.NullTime
   357  	var extra []byte
   358  	err := it.rows.Scan(&art.ID, &art.Headline, &art.CanonicalURL, &art.Content, &published, &updated, &art.Section, &extra, &p.Code, &p.Name, &p.Domain)
   359  	if err != nil {
   360  		it.err = err
   361  		return false
   362  	}
   363  
   364  	if published.Valid {
   365  		art.Published = published.Time.Format(time.RFC3339)
   366  	}
   367  	if updated.Valid {
   368  		art.Updated = updated.Time.Format(time.RFC3339)
   369  	}
   370  
   371  	urls, err := it.ss.fetchURLs(art.ID)
   372  	if err != nil {
   373  		it.err = err
   374  		return false
   375  	}
   376  	art.URLs = urls
   377  
   378  	keywords, err := it.ss.fetchKeywords(art.ID)
   379  	if err != nil {
   380  		it.err = err
   381  		return false
   382  	}
   383  	art.Keywords = keywords
   384  
   385  	authors, err := it.ss.fetchAuthors(art.ID)
   386  	if err != nil {
   387  		it.err = err
   388  		return false
   389  	}
   390  	art.Authors = authors
   391  
   392  	// decode extra data
   393  	if len(extra) > 0 {
   394  		err = json.Unmarshal(extra, &art.Extra)
   395  		if err != nil {
   396  			it.err = err
   397  			return false
   398  		}
   399  	}
   400  
   401  	// if we get this far there's an article ready.
   402  	it.current = art
   403  	return true
   404  }
   405  
// Article returns the article decoded by the most recent successful
// Next() call, or nil if there isn't one.
func (it *SQLArtIter) Article() *store.Article {
	return it.current
}
   409  
   410  func (ss *SQLStore) fetchURLs(artID int) ([]string, error) {
   411  	q := `SELECT url FROM article_url WHERE article_id=?`
   412  	rows, err := ss.db.Query(ss.rebind(q), artID)
   413  	if err != nil {
   414  		return nil, err
   415  	}
   416  	defer rows.Close()
   417  	out := []string{}
   418  	for rows.Next() {
   419  		var u string
   420  		if err := rows.Scan(&u); err != nil {
   421  			return nil, err
   422  		}
   423  		out = append(out, u)
   424  	}
   425  	if err := rows.Err(); err != nil {
   426  		return nil, err
   427  	}
   428  	return out, nil
   429  }
   430  
   431  func (ss *SQLStore) fetchAuthors(artID int) ([]store.Author, error) {
   432  	q := `SELECT name,rel_link,email,twitter
   433          FROM (author a INNER JOIN author_attr attr ON attr.author_id=a.id)
   434          WHERE article_id=?`
   435  	rows, err := ss.db.Query(ss.rebind(q), artID)
   436  	if err != nil {
   437  		return nil, err
   438  	}
   439  	defer rows.Close()
   440  	out := []store.Author{}
   441  	for rows.Next() {
   442  		var a store.Author
   443  		if err := rows.Scan(&a.Name, &a.RelLink, &a.Email, &a.Twitter); err != nil {
   444  			return nil, err
   445  		}
   446  		out = append(out, a)
   447  	}
   448  	if err := rows.Err(); err != nil {
   449  		return nil, err
   450  	}
   451  	return out, nil
   452  }
   453  
   454  func (ss *SQLStore) fetchKeywords(artID int) ([]store.Keyword, error) {
   455  	q := `SELECT name,url
   456          FROM article_keyword
   457          WHERE article_id=?`
   458  	rows, err := ss.db.Query(ss.rebind(q), artID)
   459  	if err != nil {
   460  		return nil, err
   461  	}
   462  	defer rows.Close()
   463  	out := []store.Keyword{}
   464  	for rows.Next() {
   465  		var k store.Keyword
   466  		if err := rows.Scan(&k.Name, &k.URL); err != nil {
   467  			return nil, err
   468  		}
   469  		out = append(out, k)
   470  	}
   471  	if err := rows.Err(); err != nil {
   472  		return nil, err
   473  	}
   474  	return out, nil
   475  }
   476  
   477  func (ss *SQLStore) FetchPublications() ([]store.Publication, error) {
   478  	q := `SELECT code,name,domain FROM publication ORDER by code`
   479  	rows, err := ss.db.Query(ss.rebind(q))
   480  
   481  	if err != nil {
   482  		return nil, err
   483  	}
   484  	defer rows.Close()
   485  	out := []store.Publication{}
   486  	for rows.Next() {
   487  		var p store.Publication
   488  		if err := rows.Scan(&p.Code, &p.Name, &p.Domain); err != nil {
   489  			return nil, err
   490  		}
   491  		out = append(out, p)
   492  	}
   493  	if err := rows.Err(); err != nil {
   494  		return nil, err
   495  	}
   496  	return out, nil
   497  
   498  }
   499  
   500  func (ss *SQLStore) FetchSummary(filt *store.Filter, group string) ([]store.DatePubCount, error) {
   501  	// TODO: FetchSummary() should probably take a timezone, in order to properly
   502  	// group by day... for now, days are UTC days!
   503  	tz := time.UTC
   504  
   505  	whereClause, params := buildWhere(filt)
   506  
   507  	var dayField string
   508  	switch group {
   509  	case "published":
   510  		dayField = "a.published"
   511  	case "added":
   512  		dayField = "a.Added"
   513  	default:
   514  		return nil, fmt.Errorf("Bad group field (%s)", group)
   515  	}
   516  
   517  	var q string
   518  	q = `SELECT DATE(` + dayField + `) AS day, p.code, COUNT(*)
   519  	    FROM (article a INNER JOIN publication p ON a.publication_id=p.id) ` +
   520  		whereClause + ` GROUP BY day, p.code ORDER BY day ASC ,p.code ASC;`
   521  
   522  	ss.DebugLog.Printf("summary: %s\n", q)
   523  	ss.DebugLog.Printf("summary params: %+v\n", params)
   524  
   525  	rows, err := ss.db.Query(ss.rebind(q), params...)
   526  	if err != nil {
   527  		return nil, err
   528  	}
   529  	defer rows.Close()
   530  	out := []store.DatePubCount{}
   531  	for rows.Next() {
   532  		foo := store.DatePubCount{}
   533  		if ss.driverName == "sqlite3" {
   534  			// TODO: sqlite3 driver can't seem to scan a DATE() to a time.Time (or sql.NullTime)
   535  			// TODO: INVESTIGATE!
   536  			// for now, workaround with string parsing.
   537  			var day sql.NullString
   538  			if err := rows.Scan(&day, &foo.PubCode, &foo.Count); err != nil {
   539  				return nil, err
   540  			}
   541  
   542  			if day.Valid {
   543  				t, err := time.ParseInLocation("2006-01-02", day.String, tz)
   544  				if err == nil {
   545  					foo.Date = t
   546  				}
   547  			}
   548  		} else {
   549  			// the non-sqlite3 version:
   550  			var day sql.NullTime
   551  			if err := rows.Scan(&day, &foo.PubCode, &foo.Count); err != nil {
   552  				return nil, err
   553  			}
   554  			if day.Valid {
   555  				foo.Date = day.Time.In(tz)
   556  			}
   557  		}
   558  		//ss.DebugLog.Printf("summary: %v\n", foo)
   559  
   560  		out = append(out, foo)
   561  	}
   562  	if err := rows.Err(); err != nil {
   563  		return nil, err
   564  	}
   565  
   566  	ss.DebugLog.Printf("summary out: %d\n", len(out))
   567  	return out, nil
   568  
   569  }
   570  
   571  // Fetch a single article by ID
   572  func (ss *SQLStore) FetchArt(artID int) (*store.Article, error) {
   573  
   574  	q := `SELECT a.id,a.headline,a.canonical_url,a.content,a.published,a.updated,a.section,a.extra,p.code,p.name,p.domain
   575  	               FROM (article a INNER JOIN publication p ON a.publication_id=p.id)
   576  	               WHERE a.id=?`
   577  
   578  	ss.DebugLog.Printf("fetch: %s [%d]\n", q, artID)
   579  	row := ss.db.QueryRow(q, artID)
   580  
   581  	/* TODO: split scanning/augmenting out into function, to share with Fetch() */
   582  	var art store.Article
   583  	var p = &art.Publication
   584  
   585  	var published, updated sql.NullTime
   586  	var extra []byte
   587  	if err := row.Scan(&art.ID, &art.Headline, &art.CanonicalURL, &art.Content, &published, &updated, &art.Section, &extra, &p.Code, &p.Name, &p.Domain); err != nil {
   588  		return nil, err
   589  	}
   590  
   591  	if published.Valid {
   592  		art.Published = published.Time.Format(time.RFC3339)
   593  	}
   594  	if updated.Valid {
   595  		art.Updated = updated.Time.Format(time.RFC3339)
   596  	}
   597  
   598  	urls, err := ss.fetchURLs(art.ID)
   599  	if err != nil {
   600  		return nil, err
   601  	}
   602  	art.URLs = urls
   603  
   604  	keywords, err := ss.fetchKeywords(art.ID)
   605  	if err != nil {
   606  		return nil, err
   607  	}
   608  	art.Keywords = keywords
   609  
   610  	authors, err := ss.fetchAuthors(art.ID)
   611  	if err != nil {
   612  		return nil, err
   613  	}
   614  	art.Authors = authors
   615  
   616  	// decode extra data
   617  	if len(extra) > 0 {
   618  		err = json.Unmarshal(extra, &art.Extra)
   619  		if err != nil {
   620  			err = fmt.Errorf("error in 'Extra' (artid %d): %s", art.ID, err)
   621  			return nil, err
   622  		}
   623  	}
   624  
   625  	/* end scanning/augmenting */
   626  
   627  	return &art, nil
   628  }