github.com/prebid/prebid-server@v0.275.0/stored_requests/backends/db_fetcher/fetcher_test.go (about)

     1  package db_fetcher
     2  
     3  import (
     4  	"context"
     5  	"database/sql/driver"
     6  	"encoding/json"
     7  	"errors"
     8  	"fmt"
     9  	"regexp"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/DATA-DOG/go-sqlmock"
    14  	"github.com/prebid/prebid-server/stored_requests/backends/db_provider"
    15  	"github.com/stretchr/testify/assert"
    16  )
    17  
    18  func TestEmptyQuery(t *testing.T) {
    19  	provider, _, err := db_provider.NewDbProviderMock()
    20  	if err != nil {
    21  		t.Fatalf("Unexpected error stubbing DB: %v", err)
    22  	}
    23  	defer provider.Close()
    24  
    25  	fetcher := dbFetcher{
    26  		provider:              provider,
    27  		queryTemplate:         "",
    28  		responseQueryTemplate: "",
    29  	}
    30  	storedReqs, storedImps, errs := fetcher.FetchRequests(context.Background(), nil, nil)
    31  	assertErrorCount(t, 0, errs)
    32  	assertMapLength(t, 0, storedReqs)
    33  	assertMapLength(t, 0, storedImps)
    34  
    35  	storedResponses, errs := fetcher.FetchResponses(context.Background(), nil)
    36  	assertErrorCount(t, 0, errs)
    37  	assertMapLength(t, 0, storedResponses)
    38  }
    39  
    40  // TestGoodResponse makes sure we interpret DB responses properly when all the stored requests are there.
    41  func TestGoodResponse(t *testing.T) {
    42  	mockQuery := "SELECT id, data, 'request' AS dataType FROM req_table WHERE id IN (?) UNION ALL SELECT id, data, 'imp' as dataType FROM imp_table WHERE id IN (?, ?)"
    43  	mockReturn := sqlmock.NewRows([]string{"id", "data", "dataType"}).
    44  		AddRow("request-id", `{"req":true}`, "request").
    45  		AddRow("imp-id", `{"imp":true,"value":1}`, "imp").
    46  		AddRow("imp-id-2", `{"imp":true,"value":2}`, "imp")
    47  
    48  	mock, fetcher := newFetcher(t, mockReturn, mockQuery, "request-id")
    49  	defer fetcher.provider.Close()
    50  
    51  	storedReqs, storedImps, errs := fetcher.FetchRequests(context.Background(), []string{"request-id"}, nil)
    52  
    53  	assertMockExpectations(t, mock)
    54  	assertErrorCount(t, 0, errs)
    55  	assertMapLength(t, 1, storedReqs)
    56  	assertMapLength(t, 2, storedImps)
    57  	assertHasData(t, storedReqs, "request-id", `{"req":true}`)
    58  	assertHasData(t, storedImps, "imp-id", `{"imp":true,"value":1}`)
    59  	assertHasData(t, storedImps, "imp-id-2", `{"imp":true,"value":2}`)
    60  }
    61  
    62  func TestFetchResponses(t *testing.T) {
    63  	testCases := []struct {
    64  		description  string
    65  		mockQuery    string
    66  		mockReturn   *sqlmock.Rows
    67  		arguments    []driver.Value
    68  		respIds      []string
    69  		expectedResp map[string]string
    70  	}{
    71  		{
    72  			description: "fetch good response",
    73  			mockQuery:   "SELECT id, data, 'response' AS dataType FROM responses_table WHERE id IN (?, ?)",
    74  			mockReturn: sqlmock.NewRows([]string{"id", "data", "dataType"}).
    75  				AddRow("resp-id-1", `{"resp":false,"value":1}`, "response").
    76  				AddRow("resp-id-2", `{"resp":true,"value":2}`, "response"),
    77  			arguments:    []driver.Value{"resp-id-1", "resp-id-2"},
    78  			respIds:      []string{"resp-id-1", "resp-id-2"},
    79  			expectedResp: map[string]string{"resp-id-1": `{"resp":false,"value":1}`, "resp-id-2": `{"resp":true,"value":2}`},
    80  		},
    81  		{
    82  			description: "fetch partial response",
    83  			mockQuery:   "SELECT id, data, 'response' AS dataType FROM responses_table WHERE id IN (?, ?)",
    84  			mockReturn: sqlmock.NewRows([]string{"id", "data", "dataType"}).
    85  				AddRow("stored-resp-id", "{}", "response"),
    86  			arguments:    []driver.Value{"stored-resp-id", "stored-resp-id-2"},
    87  			respIds:      []string{"stored-resp-id", "stored-resp-id-2"},
    88  			expectedResp: map[string]string{"stored-resp-id": `{}`},
    89  		},
    90  		{
    91  			description:  "fetch empty response",
    92  			mockQuery:    "SELECT id, data, dataType FROM my_table WHERE id IN (?, ?)",
    93  			mockReturn:   sqlmock.NewRows([]string{"id", "data", "dataType"}),
    94  			arguments:    []driver.Value{"stored-resp-id", "stored-resp-id-2"},
    95  			respIds:      []string{"stored-resp-id", "stored-resp-id-2"},
    96  			expectedResp: map[string]string{},
    97  		},
    98  	}
    99  
   100  	for _, test := range testCases {
   101  		mock, fetcher := newFetcher(t, test.mockReturn, test.mockQuery, test.arguments...)
   102  		defer fetcher.provider.Close()
   103  
   104  		storedResponses, errs := fetcher.FetchResponses(context.Background(), test.respIds)
   105  
   106  		assertMockExpectations(t, mock)
   107  		assertErrorCount(t, 0, errs)
   108  		assert.Len(t, storedResponses, len(test.expectedResp))
   109  
   110  		for k, v := range test.expectedResp {
   111  			assertHasData(t, storedResponses, k, v)
   112  		}
   113  
   114  	}
   115  }
   116  
   117  // TestPartialResponse makes sure we unpack things properly when the DB finds some of the stored requests.
   118  func TestPartialResponse(t *testing.T) {
   119  	mockQuery := "SELECT id, data, 'request' AS dataType FROM req_table WHERE id IN (?, ?) UNION ALL SELECT id, data, 'imp' as dataType FROM imp_table WHERE id IN (NULL)"
   120  	mockReturn := sqlmock.NewRows([]string{"id", "data", "dataType"}).
   121  		AddRow("stored-req-id", "{}", "request")
   122  
   123  	mock, fetcher := newFetcher(t, mockReturn, mockQuery, "stored-req-id", "stored-req-id-2")
   124  	defer fetcher.provider.Close()
   125  
   126  	storedReqs, storedImps, errs := fetcher.FetchRequests(context.Background(), []string{"stored-req-id", "stored-req-id-2"}, nil)
   127  
   128  	assertMockExpectations(t, mock)
   129  	assertErrorCount(t, 1, errs)
   130  	assertMapLength(t, 0, storedImps)
   131  	assertMapLength(t, 1, storedReqs)
   132  	assertHasData(t, storedReqs, "stored-req-id", "{}")
   133  }
   134  
   135  // TestEmptyResponse makes sure we handle empty DB responses properly.
   136  func TestEmptyResponse(t *testing.T) {
   137  	mockQuery := "SELECT id, data, dataType FROM my_table WHERE id IN (?, ?)"
   138  	mockReturn := sqlmock.NewRows([]string{"id", "data", "dataType"})
   139  
   140  	mock, fetcher := newFetcher(t, mockReturn, mockQuery, "stored-req-id", "stored-req-id-2", "stored-imp-id")
   141  	defer fetcher.provider.Close()
   142  
   143  	storedReqs, storedImps, errs := fetcher.FetchRequests(context.Background(), []string{"stored-req-id", "stored-req-id-2"}, []string{"stored-imp-id"})
   144  
   145  	assertMockExpectations(t, mock)
   146  	assertErrorCount(t, 3, errs)
   147  	assertMapLength(t, 0, storedReqs)
   148  	assertMapLength(t, 0, storedImps)
   149  }
   150  
   151  // TestDatabaseError makes sure we exit with an error if the DB query fails.
   152  func TestDatabaseError(t *testing.T) {
   153  	provider, mock, err := db_provider.NewDbProviderMock()
   154  	if err != nil {
   155  		t.Fatalf("Failed to create mock: %v", err)
   156  	}
   157  
   158  	mock.ExpectQuery(".*").WillReturnError(errors.New("Invalid query."))
   159  
   160  	fetcher := &dbFetcher{
   161  		provider:      provider,
   162  		queryTemplate: "SELECT id, data, dataType FROM my_table WHERE id IN (?, ?)",
   163  	}
   164  
   165  	storedReqs, storedImps, errs := fetcher.FetchRequests(context.Background(), []string{"stored-req-id"}, nil)
   166  	assertErrorCount(t, 1, errs)
   167  	assertMapLength(t, 0, storedReqs)
   168  	assertMapLength(t, 0, storedImps)
   169  }
   170  
   171  // TestContextDeadlines makes sure a hung query returns when the timeout expires.
   172  func TestContextDeadlines(t *testing.T) {
   173  	provider, mock, err := db_provider.NewDbProviderMock()
   174  	if err != nil {
   175  		t.Fatalf("Failed to create mock: %v", err)
   176  	}
   177  
   178  	mock.ExpectQuery(".*").WillDelayFor(2 * time.Minute)
   179  
   180  	fetcher := &dbFetcher{
   181  		provider:              provider,
   182  		queryTemplate:         "SELECT id, requestData FROM my_table WHERE id IN (?, ?)",
   183  		responseQueryTemplate: "SELECT id, responseData FROM my_table WHERE id IN (?, ?)",
   184  	}
   185  
   186  	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
   187  	defer cancel()
   188  
   189  	_, _, errs := fetcher.FetchRequests(ctx, []string{"id"}, nil)
   190  	if len(errs) < 1 {
   191  		t.Errorf("dbFetcher should return an error when the context times out.")
   192  	}
   193  	_, errs = fetcher.FetchResponses(ctx, []string{"id"})
   194  	if len(errs) < 1 {
   195  		t.Errorf("dbFetcher should return an error when the context times out.")
   196  	}
   197  }
   198  
   199  // TestContextCancelled makes sure a hung query returns when the context is cancelled.
   200  func TestContextCancelled(t *testing.T) {
   201  	provider, mock, err := db_provider.NewDbProviderMock()
   202  	if err != nil {
   203  		t.Fatalf("Failed to create mock: %v", err)
   204  	}
   205  
   206  	mock.ExpectQuery(".*").WillDelayFor(2 * time.Minute)
   207  
   208  	fetcher := &dbFetcher{
   209  		provider:              provider,
   210  		queryTemplate:         "SELECT id, requestData FROM my_table WHERE id IN (?, ?)",
   211  		responseQueryTemplate: "SELECT id, responseData FROM my_table WHERE id IN (?, ?)",
   212  	}
   213  
   214  	ctx, cancel := context.WithCancel(context.Background())
   215  	cancel()
   216  	_, _, errs := fetcher.FetchRequests(ctx, []string{"id"}, nil)
   217  	if len(errs) < 1 {
   218  		t.Errorf("dbFetcher should return an error when the context is cancelled.")
   219  	}
   220  	_, errs = fetcher.FetchResponses(ctx, []string{"id"})
   221  	if len(errs) < 1 {
   222  		t.Errorf("dbFetcher should return an error when the context is cancelled.")
   223  	}
   224  }
   225  
   226  // Prevents #338
   227  func TestRowErrors(t *testing.T) {
   228  	provider, mock, err := db_provider.NewDbProviderMock()
   229  	if err != nil {
   230  		t.Fatalf("Failed to create mock: %v", err)
   231  	}
   232  	rows := sqlmock.NewRows([]string{"id", "data", "dataType"})
   233  	rows.AddRow("foo", []byte(`{"data":1}`), "request")
   234  	rows.AddRow("bar", []byte(`{"data":2}`), "request")
   235  	rows.RowError(1, errors.New("Error reading from row 1"))
   236  	mock.ExpectQuery(".*").WillReturnRows(rows)
   237  	fetcher := &dbFetcher{
   238  		provider:      provider,
   239  		queryTemplate: "SELECT id, data, dataType FROM my_table WHERE id IN (?)",
   240  	}
   241  	data, _, errs := fetcher.FetchRequests(context.Background(), []string{"foo", "bar"}, nil)
   242  	assertErrorCount(t, 1, errs)
   243  	if errs[0].Error() != "Error reading from row 1" {
   244  		t.Errorf("Unexpected error message: %v", errs[0].Error())
   245  	}
   246  	assertMapLength(t, 0, data)
   247  }
   248  
   249  func TestRowErrorsFetchResponses(t *testing.T) {
   250  	provider, mock, err := db_provider.NewDbProviderMock()
   251  	if err != nil {
   252  		t.Fatalf("Failed to create mock: %v", err)
   253  	}
   254  	rows := sqlmock.NewRows([]string{"id", "data", "dataType"})
   255  	rows.AddRow("foo", []byte(`{"data":1}`), "response")
   256  	rows.AddRow("bar", []byte(`{"data":2}`), "response")
   257  	rows.RowError(1, errors.New("Error reading from row 1"))
   258  	mock.ExpectQuery(".*").WillReturnRows(rows)
   259  	fetcher := &dbFetcher{
   260  		provider:              provider,
   261  		queryTemplate:         "SELECT id, data, dataType FROM my_table WHERE id IN (?)",
   262  		responseQueryTemplate: "SELECT id, data, dataType FROM my_table WHERE id IN (?)",
   263  	}
   264  	data, errs := fetcher.FetchResponses(context.Background(), []string{"foo", "bar"})
   265  	assertErrorCount(t, 1, errs)
   266  	if errs[0].Error() != "Error reading from row 1" {
   267  		t.Errorf("Unexpected error message: %v", errs[0].Error())
   268  	}
   269  	assertMapLength(t, 0, data)
   270  }
   271  
   272  func newFetcher(t *testing.T, rows *sqlmock.Rows, query string, args ...driver.Value) (sqlmock.Sqlmock, *dbFetcher) {
   273  	provider, mock, err := db_provider.NewDbProviderMock()
   274  	if err != nil {
   275  		t.Fatalf("Failed to create mock: %v", err)
   276  		return nil, nil
   277  	}
   278  
   279  	queryRegex := fmt.Sprintf("^%s$", regexp.QuoteMeta(query))
   280  	mock.ExpectQuery(queryRegex).WithArgs(args...).WillReturnRows(rows)
   281  	fetcher := &dbFetcher{
   282  		provider:              provider,
   283  		queryTemplate:         query,
   284  		responseQueryTemplate: query,
   285  	}
   286  
   287  	return mock, fetcher
   288  }
   289  
   290  func assertMapLength(t *testing.T, numExpected int, configs map[string]json.RawMessage) {
   291  	t.Helper()
   292  	if len(configs) != numExpected {
   293  		t.Errorf("Wrong num configs. Expected %d, Got %d.", numExpected, len(configs))
   294  	}
   295  }
   296  
   297  func assertMockExpectations(t *testing.T, mock sqlmock.Sqlmock) {
   298  	t.Helper()
   299  	if err := mock.ExpectationsWereMet(); err != nil {
   300  		t.Errorf("Mock expectations not met: %v", err)
   301  	}
   302  }
   303  
   304  func assertHasData(t *testing.T, data map[string]json.RawMessage, key string, value string) {
   305  	t.Helper()
   306  	cfg, ok := data[key]
   307  	if !ok {
   308  		t.Errorf("Missing expected stored request data: %s", key)
   309  	}
   310  	if string(cfg) != value {
   311  		t.Errorf("Bad data[%s] value. Expected %s, Got %s", key, value, cfg)
   312  	}
   313  }
   314  
   315  func assertErrorCount(t *testing.T, num int, errs []error) {
   316  	t.Helper()
   317  	if len(errs) != num {
   318  		t.Errorf("Wrong number of errors. Expected %d. Got %d. Errors are %v", num, len(errs), errs)
   319  	}
   320  }