github.com/Accefy/pop@v0.0.0-20230428174248-e9f677eab5b9/connection_instrumented_test.go (about)

     1  package pop
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"os"
     7  	"strings"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/luna-duclos/instrumentedsql"
    12  	"github.com/stretchr/testify/suite"
    13  )
    14  
    15  func testInstrumentedDriver(p *suite.Suite) {
    16  	r := p.Require()
    17  	deets := *Connections[os.Getenv("SODA_DIALECT")].Dialect.Details()
    18  
    19  	ctx, cancel := context.WithTimeout(context.TODO(), time.Second*5)
    20  	defer cancel()
    21  
    22  	// The WaitGroup and channel ensures that the logger is properly called. This can only happen
    23  	// when the instrumented driver is working as expected and returns the expected query.
    24  	var (
    25  		queryMySQL = "SELECT 1 FROM DUAL WHERE 1=?"
    26  		queryOther = "SELECT 1 WHERE 1=?"
    27  		mc         = make(chan string)
    28  		wg         sync.WaitGroup
    29  		expected   = []string{
    30  			"SELECT 1 FROM DUAL WHERE 1=?",
    31  			"SELECT 1 FROM DUAL WHERE 1=$1",
    32  			"SELECT 1 WHERE 1=?",
    33  			"SELECT 1 WHERE 1=$1",
    34  		}
    35  	)
    36  
    37  	query := queryOther
    38  	if os.Getenv("SODA_DIALECT") == "mysql" {
    39  		query = queryMySQL
    40  	}
    41  
    42  	wg.Add(1)
    43  	go func() {
    44  		defer wg.Done()
    45  		var messages []string
    46  		var found bool
    47  		for {
    48  			select {
    49  			case m := <-mc:
    50  				p.T().Logf("Received message: %s", m)
    51  				messages = append(messages, m)
    52  				for _, e := range expected {
    53  					if strings.Contains(m, e) {
    54  						p.T().Logf("Found part %s in %s", e, m)
    55  						found = true
    56  						break
    57  					}
    58  				}
    59  			case <-ctx.Done():
    60  				if !found {
    61  					r.FailNow(fmt.Sprintf("Expected tracer to return the \"%s\" query but only the following messages have been received:\n\n\t%s", query, strings.Join(messages, "\n\t")))
    62  					return
    63  				}
    64  				return
    65  			}
    66  		}
    67  	}()
    68  
    69  	var checker = instrumentedsql.LoggerFunc(func(ctx context.Context, msg string, keyvals ...interface{}) {
    70  		p.T().Logf("Instrumentation received message: %s - %+v", msg, keyvals)
    71  		mc <- fmt.Sprintf("%s - %+v", msg, keyvals)
    72  	})
    73  
    74  	deets.UseInstrumentedDriver = true
    75  	deets.InstrumentedDriverOptions = []instrumentedsql.Opt{instrumentedsql.WithLogger(checker)}
    76  
    77  	c, err := NewConnection(&deets)
    78  	r.NoError(err)
    79  	r.NoError(c.Open())
    80  
    81  	err = c.WithContext(context.TODO()).RawQuery(query, 1).Exec()
    82  	r.NoError(err)
    83  
    84  	wg.Wait()
    85  }
    86  
    87  func (s *PostgreSQLSuite) Test_Instrumentation() {
    88  	testInstrumentedDriver(&s.Suite)
    89  }
    90  
    91  func (s *MySQLSuite) Test_Instrumentation() {
    92  	testInstrumentedDriver(&s.Suite)
    93  }
    94  
    95  func (s *SQLiteSuite) Test_Instrumentation() {
    96  	testInstrumentedDriver(&s.Suite)
    97  }
    98  
    99  func (s *CockroachSuite) Test_Instrumentation() {
   100  	testInstrumentedDriver(&s.Suite)
   101  }