github.com/letsencrypt/boulder@v0.20251208.0/test/pardot-test-srv/main.go (about)

     1  package main
     2  
     3  import (
     4  	"crypto/rand"
     5  	"encoding/json"
     6  	"flag"
     7  	"fmt"
     8  	"io"
     9  	"log"
    10  	"net/http"
    11  	"os"
    12  	"slices"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/letsencrypt/boulder/cmd"
    17  )
    18  
    19  var contactsCap = 20
    20  
    21  type config struct {
    22  	// OAuthAddr is the address (e.g. IP:port) on which the Salesforce REST API
    23  	// and OAuth API server will listen.
    24  	//
    25  	// Deprecated: Use SalesforceAddr instead.
    26  	// TODO(#8410): Remove this field.
    27  	OAuthAddr string
    28  
    29  	// SalesforceAddr is the address (e.g. IP:port) on which the Salesforce REST
    30  	// API and OAuth API server will listen.
    31  	SalesforceAddr string
    32  
    33  	// PardotAddr is the address (e.g. IP:port) on which the Pardot server will
    34  	// listen.
    35  	PardotAddr string
    36  
    37  	// ExpectedClientID is the client ID that the server expects to receive in
    38  	// requests to the /services/oauth2/token endpoint.
    39  	ExpectedClientID string `validate:"required"`
    40  
    41  	// ExpectedClientSecret is the client secret that the server expects to
    42  	// receive in requests to the /services/oauth2/token endpoint.
    43  	ExpectedClientSecret string `validate:"required"`
    44  }
    45  
    46  type contacts struct {
    47  	sync.Mutex
    48  	created []string
    49  }
    50  
    51  type cases struct {
    52  	sync.Mutex
    53  	created []map[string]any
    54  }
    55  
    56  type testServer struct {
    57  	expectedClientID     string
    58  	expectedClientSecret string
    59  	token                string
    60  	contacts             contacts
    61  	cases                cases
    62  }
    63  
    64  func (ts *testServer) getTokenHandler(w http.ResponseWriter, r *http.Request) {
    65  	err := r.ParseForm()
    66  	if err != nil {
    67  		http.Error(w, "Invalid request", http.StatusBadRequest)
    68  		return
    69  	}
    70  
    71  	clientID := r.FormValue("client_id")
    72  	clientSecret := r.FormValue("client_secret")
    73  
    74  	if clientID != ts.expectedClientID || clientSecret != ts.expectedClientSecret {
    75  		http.Error(w, "Invalid credentials", http.StatusUnauthorized)
    76  		return
    77  	}
    78  
    79  	response := map[string]any{
    80  		"access_token": ts.token,
    81  		"token_type":   "Bearer",
    82  		"expires_in":   3600,
    83  	}
    84  
    85  	w.Header().Set("Content-Type", "application/json")
    86  	err = json.NewEncoder(w).Encode(response)
    87  	if err != nil {
    88  		log.Printf("Failed to encode token response: %v", err)
    89  		http.Error(w, "Failed to encode token response", http.StatusInternalServerError)
    90  	}
    91  }
    92  
    93  func (ts *testServer) checkToken(w http.ResponseWriter, r *http.Request) {
    94  	token := r.Header.Get("Authorization")
    95  	if token != "Bearer "+ts.token {
    96  		http.Error(w, "Unauthorized", http.StatusUnauthorized)
    97  		return
    98  	}
    99  }
   100  
   101  func (ts *testServer) upsertContactsHandler(w http.ResponseWriter, r *http.Request) {
   102  	ts.checkToken(w, r)
   103  
   104  	businessUnitId := r.Header.Get("Pardot-Business-Unit-Id")
   105  	if businessUnitId == "" {
   106  		http.Error(w, "Missing 'Pardot-Business-Unit-Id' header", http.StatusBadRequest)
   107  		return
   108  	}
   109  
   110  	body, err := io.ReadAll(r.Body)
   111  	if err != nil {
   112  		http.Error(w, "Failed to read request body", http.StatusInternalServerError)
   113  		return
   114  	}
   115  
   116  	type upsertPayload struct {
   117  		MatchEmail string `json:"matchEmail"`
   118  		Prospect   struct {
   119  			Email string `json:"email"`
   120  		} `json:"prospect"`
   121  	}
   122  
   123  	var payload upsertPayload
   124  	err = json.Unmarshal(body, &payload)
   125  	if err != nil {
   126  		http.Error(w, "Failed to parse request body", http.StatusBadRequest)
   127  		return
   128  	}
   129  
   130  	if payload.MatchEmail == "" || payload.Prospect.Email == "" {
   131  		http.Error(w, "Missing 'matchEmail' or 'prospect.email' in request body", http.StatusBadRequest)
   132  		return
   133  	}
   134  
   135  	ts.contacts.Lock()
   136  	if len(ts.contacts.created) >= contactsCap {
   137  		// Copying the slice in memory is inefficient, but this is a test server
   138  		// with a small number of contacts, so it's fine.
   139  		ts.contacts.created = ts.contacts.created[1:]
   140  	}
   141  	ts.contacts.created = append(ts.contacts.created, payload.Prospect.Email)
   142  	ts.contacts.Unlock()
   143  
   144  	w.Header().Set("Content-Type", "application/json")
   145  	w.Write([]byte(`{"status": "success"}`))
   146  }
   147  
   148  func (ts *testServer) queryContactsHandler(w http.ResponseWriter, r *http.Request) {
   149  	ts.checkToken(w, r)
   150  
   151  	ts.contacts.Lock()
   152  	respContacts := slices.Clone(ts.contacts.created)
   153  	ts.contacts.Unlock()
   154  
   155  	w.Header().Set("Content-Type", "application/json")
   156  	err := json.NewEncoder(w).Encode(map[string]any{"contacts": respContacts})
   157  	if err != nil {
   158  		log.Printf("Failed to encode contacts query response: %v", err)
   159  		http.Error(w, "Failed to encode contacts query response", http.StatusInternalServerError)
   160  	}
   161  }
   162  
   163  func (ts *testServer) createCaseHandler(w http.ResponseWriter, r *http.Request) {
   164  	ts.checkToken(w, r)
   165  
   166  	var payload map[string]any
   167  	if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
   168  		http.Error(w, "Invalid JSON", http.StatusBadRequest)
   169  		return
   170  	}
   171  
   172  	_, ok := payload["Origin"]
   173  	if !ok {
   174  		http.Error(w, "Missing required field: Origin", http.StatusBadRequest)
   175  		return
   176  	}
   177  
   178  	ts.cases.Lock()
   179  	ts.cases.created = append(ts.cases.created, payload)
   180  	ts.cases.Unlock()
   181  
   182  	resp := map[string]any{
   183  		"id":      fmt.Sprintf("500xx00000%06dAAA", len(ts.cases.created)+1),
   184  		"success": true,
   185  		"errors":  []string{},
   186  	}
   187  	w.Header().Set("Content-Type", "application/json")
   188  	w.WriteHeader(http.StatusCreated)
   189  	err := json.NewEncoder(w).Encode(resp)
   190  	if err != nil {
   191  		log.Printf("Failed to encode case creation response: %s", err)
   192  		http.Error(w, "Failed to encode case creation response", http.StatusInternalServerError)
   193  	}
   194  }
   195  
   196  func (ts *testServer) queryCasesHandler(w http.ResponseWriter, r *http.Request) {
   197  	ts.checkToken(w, r)
   198  
   199  	ts.cases.Lock()
   200  	respCases := slices.Clone(ts.cases.created)
   201  	ts.cases.Unlock()
   202  
   203  	w.Header().Set("Content-Type", "application/json")
   204  	err := json.NewEncoder(w).Encode(map[string]any{"cases": respCases})
   205  	if err != nil {
   206  		log.Printf("Failed to encode cases query response: %v", err)
   207  		http.Error(w, "Failed to encode cases query response", http.StatusInternalServerError)
   208  	}
   209  }
   210  
   211  func main() {
   212  	// TODO(#8410): Remove the oauthAddr flag.
   213  	oauthAddr := flag.String("oauth-addr", "", "Salesforce REST API and OAuth server listen address override (deprecated: use --salesforce-addr instead)")
   214  	salesforceAddr := flag.String("salesforce-addr", "", "Salesforce REST API and OAuth server listen address override")
   215  	pardotAddr := flag.String("pardot-addr", "", "Pardot server listen address override")
   216  	configFile := flag.String("config", "", "Path to configuration file")
   217  	flag.Parse()
   218  
   219  	if *configFile == "" {
   220  		flag.Usage()
   221  		os.Exit(1)
   222  	}
   223  
   224  	var c config
   225  	err := cmd.ReadConfigFile(*configFile, &c)
   226  	cmd.FailOnError(err, "Reading JSON config file into config structure")
   227  
   228  	// TODO(#8410): Reduce this logic down to just using salesforceAddr once
   229  	// oauthAddr is removed.
   230  	firstNonEmpty := func(vals ...string) string {
   231  		for _, v := range vals {
   232  			if v != "" {
   233  				return v
   234  			}
   235  		}
   236  		return ""
   237  	}
   238  	c.SalesforceAddr = firstNonEmpty(*salesforceAddr, c.SalesforceAddr, *oauthAddr, c.OAuthAddr)
   239  	if c.SalesforceAddr == "" {
   240  		log.Fatal("--salesforce-addr or JSON salesforceAddr must be set (or use deprecated --oauth-addr or JSON oauthAddr until removed)")
   241  	}
   242  
   243  	if *pardotAddr != "" {
   244  		c.PardotAddr = *pardotAddr
   245  	}
   246  
   247  	tokenBytes := make([]byte, 32)
   248  	_, err = rand.Read(tokenBytes)
   249  	if err != nil {
   250  		log.Fatalf("Failed to generate token: %v", err)
   251  	}
   252  
   253  	ts := &testServer{
   254  		expectedClientID:     c.ExpectedClientID,
   255  		expectedClientSecret: c.ExpectedClientSecret,
   256  		token:                fmt.Sprintf("%x", tokenBytes),
   257  		contacts:             contacts{created: make([]string, 0, contactsCap)},
   258  		cases:                cases{created: make([]map[string]any, 0)},
   259  	}
   260  
   261  	// Salesforce REST API and OAuth Server
   262  	oauthMux := http.NewServeMux()
   263  	oauthMux.HandleFunc("/services/oauth2/token", ts.getTokenHandler)
   264  	oauthMux.HandleFunc("/services/data/v64.0/sobjects/Case", ts.createCaseHandler)
   265  	oauthMux.HandleFunc("/cases", ts.queryCasesHandler)
   266  	oauthServer := &http.Server{
   267  		Addr:        c.SalesforceAddr,
   268  		Handler:     oauthMux,
   269  		ReadTimeout: 30 * time.Second,
   270  	}
   271  
   272  	log.Printf("pardot-test-srv Salesforce REST API and OAuth server listening at %s", c.SalesforceAddr)
   273  	go func() {
   274  		err := oauthServer.ListenAndServe()
   275  		if err != nil {
   276  			log.Fatalf("Failed to start Salesforce REST API and OAuth server: %s", err)
   277  		}
   278  	}()
   279  
   280  	// Pardot API Server
   281  	pardotMux := http.NewServeMux()
   282  	pardotMux.HandleFunc("/api/v5/objects/prospects/do/upsertLatestByEmail", ts.upsertContactsHandler)
   283  	pardotMux.HandleFunc("/contacts", ts.queryContactsHandler)
   284  
   285  	pardotServer := &http.Server{
   286  		Addr:        c.PardotAddr,
   287  		Handler:     pardotMux,
   288  		ReadTimeout: 30 * time.Second,
   289  	}
   290  	log.Printf("pardot-test-srv Salesforce Pardot API server listening at %s", c.PardotAddr)
   291  	go func() {
   292  		err := pardotServer.ListenAndServe()
   293  		if err != nil {
   294  			log.Fatalf("Failed to start Salesforce Pardot API server: %s", err)
   295  		}
   296  	}()
   297  
   298  	cmd.WaitForSignal()
   299  }