github.com/prebid/prebid-server/v2@v2.18.0/stored_requests/events/api/api_test.go (about)

     1  package api
     2  
import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/prebid/prebid-server/v2/stored_requests"
	"github.com/prebid/prebid-server/v2/stored_requests/caches/memory"
	"github.com/prebid/prebid-server/v2/stored_requests/events"
)
    16  
    17  func TestGoodRequests(t *testing.T) {
    18  	cache := stored_requests.Cache{
    19  		Requests:  memory.NewCache(256*1024, -1, "Request"),
    20  		Imps:      memory.NewCache(256*1024, -1, "Imp"),
    21  		Responses: memory.NewCache(256*1024, -1, "Responses"),
    22  		Accounts:  memory.NewCache(256*1024, -1, "Account"),
    23  	}
    24  	id := "1"
    25  	config := fmt.Sprintf(`{"id": "%s"}`, id)
    26  	initialValue := map[string]json.RawMessage{id: json.RawMessage(config)}
    27  	cache.Requests.Save(context.Background(), initialValue)
    28  	cache.Imps.Save(context.Background(), initialValue)
    29  	cache.Responses.Save(context.Background(), initialValue)
    30  
    31  	apiEvents, endpoint := NewEventsAPI()
    32  
    33  	// create channels to syncronize
    34  	updateOccurred := make(chan struct{})
    35  	invalidateOccurred := make(chan struct{})
    36  	listener := events.NewEventListener(
    37  		func() { updateOccurred <- struct{}{} },
    38  		func() { invalidateOccurred <- struct{}{} },
    39  	)
    40  
    41  	go listener.Listen(cache, apiEvents)
    42  	defer listener.Stop()
    43  
    44  	config = fmt.Sprintf(`{"id": "%s", "updated": true}`, id)
    45  	update := fmt.Sprintf(`{"requests": {"%s": %s}, "imps": {"%s": %s}, "responses": {"%s": %s}}`, id, config, id, config, id, config)
    46  	request := newRequest("POST", update)
    47  
    48  	recorder := httptest.NewRecorder()
    49  	endpoint(recorder, request, nil)
    50  
    51  	if recorder.Code != http.StatusOK {
    52  		t.Fatalf("Unexpected error from request: %s", recorder.Body.String())
    53  	}
    54  
    55  	<-updateOccurred
    56  	reqData := cache.Requests.Get(context.Background(), []string{id})
    57  	impData := cache.Imps.Get(context.Background(), []string{id})
    58  	respData := cache.Responses.Get(context.Background(), []string{id})
    59  	assertHasValue(t, reqData, id, config)
    60  	assertHasValue(t, impData, id, config)
    61  	assertHasValue(t, respData, id, config)
    62  
    63  	invalidation := fmt.Sprintf(`{"requests": ["%s"], "imps": ["%s"], "responses": ["%s"]}`, id, id, id)
    64  	request = newRequest("DELETE", invalidation)
    65  
    66  	recorder = httptest.NewRecorder()
    67  	endpoint(recorder, request, nil)
    68  
    69  	if recorder.Code != http.StatusOK {
    70  		t.Fatalf("Unexpected error from request: %s", recorder.Body.String())
    71  	}
    72  
    73  	<-invalidateOccurred
    74  	reqData = cache.Requests.Get(context.Background(), []string{id})
    75  	impData = cache.Imps.Get(context.Background(), []string{id})
    76  	respData = cache.Responses.Get(context.Background(), []string{id})
    77  	assertMapLength(t, 0, reqData)
    78  	assertMapLength(t, 0, impData)
    79  	assertMapLength(t, 0, respData)
    80  }
    81  
    82  func TestBadRequests(t *testing.T) {
    83  	cache := stored_requests.Cache{
    84  		Requests:  memory.NewCache(256*1024, -1, "Requests"),
    85  		Imps:      memory.NewCache(256*1024, -1, "Imps"),
    86  		Responses: memory.NewCache(256*1024, -1, "Responses"),
    87  	}
    88  	apiEvents, endpoint := NewEventsAPI()
    89  	listener := events.SimpleEventListener()
    90  	go listener.Listen(cache, apiEvents)
    91  	defer listener.Stop()
    92  
    93  	update := "NOT JSON"
    94  	request := newRequest("POST", update)
    95  
    96  	recorder := httptest.NewRecorder()
    97  	endpoint(recorder, request, nil)
    98  
    99  	if recorder.Code != http.StatusBadRequest {
   100  		t.Errorf("Expected error from request, got OK")
   101  	}
   102  
   103  	invalidation := "NOT JSON"
   104  	request = newRequest("DELETE", invalidation)
   105  
   106  	recorder = httptest.NewRecorder()
   107  	endpoint(recorder, request, nil)
   108  
   109  	if recorder.Code != http.StatusBadRequest {
   110  		t.Errorf("Expected error from request, got OK")
   111  	}
   112  
   113  	request = newRequest("GET", "")
   114  	recorder = httptest.NewRecorder()
   115  	endpoint(recorder, request, nil)
   116  
   117  	if recorder.Code != http.StatusMethodNotAllowed {
   118  		t.Errorf("Expected error from request, got OK")
   119  	}
   120  }
   121  
   122  func newRequest(method string, body string) *http.Request {
   123  	return httptest.NewRequest(method, "/stored_requests", strings.NewReader(body))
   124  }
   125  
   126  func assertMapLength(t *testing.T, expectedLen int, theMap map[string]json.RawMessage) {
   127  	t.Helper()
   128  	if len(theMap) != expectedLen {
   129  		t.Errorf("Wrong map length. Expected %d, Got %d.", expectedLen, len(theMap))
   130  	}
   131  }
   132  
   133  func assertHasValue(t *testing.T, m map[string]json.RawMessage, key string, val string) {
   134  	t.Helper()
   135  	realVal, ok := m[key]
   136  	if !ok {
   137  		t.Errorf("Map missing required key: %s", key)
   138  	}
   139  	if val != string(realVal) {
   140  		t.Errorf("Unexpected value at key %s. Expected %s, Got %s", key, val, string(realVal))
   141  	}
   142  }