decred.org/dcrdex@v1.0.5/client/rpcserver/rpcserver_test.go (about)

     1  // This code is available on the terms of the project LICENSE.md file,
     2  // also available online at https://blueoakcouncil.org/license/1.0.0.
     3  
     4  //go:build !live
     5  
     6  package rpcserver
     7  
     8  import (
     9  	"bytes"
    10  	"context"
    11  	"crypto/sha256"
    12  	"encoding/base64"
    13  	"encoding/json"
    14  	"fmt"
    15  	"net/http"
    16  	"os"
    17  	"testing"
    18  	"time"
    19  
    20  	"decred.org/dcrdex/client/asset"
    21  	"decred.org/dcrdex/client/core"
    22  	"decred.org/dcrdex/client/db"
    23  	"decred.org/dcrdex/client/mnemonic"
    24  	"decred.org/dcrdex/client/orderbook"
    25  	"decred.org/dcrdex/dex"
    26  	"decred.org/dcrdex/dex/msgjson"
    27  )
    28  
    29  func init() {
    30  	log = dex.StdOutLogger("TEST", dex.LevelTrace)
    31  }
    32  
    33  var (
    34  	tCtx context.Context
    35  )
    36  
    37  type TCore struct {
    38  	dexExchange              *core.Exchange
    39  	getDEXConfigErr          error
    40  	balanceErr               error
    41  	syncErr                  error
    42  	createWalletErr          error
    43  	newWalletForm            *core.WalletForm
    44  	openWalletErr            error
    45  	rescanWalletErr          error
    46  	walletState              *core.WalletState
    47  	closeWalletErr           error
    48  	walletStatusErr          error
    49  	wallets                  []*core.WalletState
    50  	initializeClientErr      error
    51  	postBondResult           *core.PostBondResult
    52  	postBondErr              error
    53  	bondOptsErr              error
    54  	exchanges                map[string]*core.Exchange
    55  	loginErr                 error
    56  	order                    *core.Order
    57  	tradeErr                 error
    58  	cancelErr                error
    59  	coin                     asset.Coin
    60  	sendErr                  error
    61  	logoutErr                error
    62  	book                     *core.OrderBook
    63  	bookErr                  error
    64  	exportSeed               string
    65  	exportSeedErr            error
    66  	discoverAcctErr          error
    67  	archivedRecords          int
    68  	deleteArchivedRecordsErr error
    69  	setVSPErr                error
    70  	purchaseTicketsErr       error
    71  	stakeStatus              *asset.TicketStakingStatus
    72  	stakeStatusErr           error
    73  	setVotingPrefErr         error
    74  }
    75  
    76  func (c *TCore) Balance(uint32) (uint64, error) {
    77  	return 0, c.balanceErr
    78  }
    79  func (c *TCore) Book(dex string, base, quote uint32) (*core.OrderBook, error) {
    80  	return c.book, c.bookErr
    81  }
    82  func (c *TCore) AckNotes(ids []dex.Bytes) {}
    83  func (c *TCore) AssetBalance(uint32) (*core.WalletBalance, error) {
    84  	return nil, c.balanceErr
    85  }
    86  func (c *TCore) Cancel(oid dex.Bytes) error {
    87  	return c.cancelErr
    88  }
    89  func (c *TCore) CreateWallet(appPW, walletPW []byte, form *core.WalletForm) error {
    90  	c.newWalletForm = form
    91  	return c.createWalletErr
    92  }
    93  func (c *TCore) CloseWallet(assetID uint32) error {
    94  	return c.closeWalletErr
    95  }
    96  func (c *TCore) Exchanges() (exchanges map[string]*core.Exchange) { return c.exchanges }
    97  func (c *TCore) Exchange(host string) (*core.Exchange, error) {
    98  	exchange, ok := c.exchanges[host]
    99  	if !ok {
   100  		return nil, fmt.Errorf("no exchange at %v", host)
   101  	}
   102  	return exchange, nil
   103  }
   104  func (c *TCore) InitializeClient(pw []byte, seed *string) (string, error) {
   105  	var mnemonicSeed string
   106  	if seed == nil {
   107  		_, mnemonicSeed = mnemonic.New()
   108  	}
   109  	return mnemonicSeed, c.initializeClientErr
   110  }
   111  func (c *TCore) Login(appPass []byte) error {
   112  	return c.loginErr
   113  }
   114  func (c *TCore) Logout() error {
   115  	return c.logoutErr
   116  }
   117  func (c *TCore) OpenWallet(assetID uint32, pw []byte) error {
   118  	return c.openWalletErr
   119  }
   120  func (c *TCore) ToggleWalletStatus(assetID uint32, disable bool) error {
   121  	if c.walletStatusErr != nil {
   122  		return c.walletStatusErr
   123  	}
   124  	if c.walletState != nil {
   125  		c.walletState.Disabled = disable
   126  	}
   127  	return c.walletStatusErr
   128  }
   129  func (c *TCore) RescanWallet(assetID uint32, force bool) error {
   130  	return c.rescanWalletErr
   131  }
   132  func (c *TCore) GetDEXConfig(dexAddr string, certI any) (*core.Exchange, error) {
   133  	return c.dexExchange, c.getDEXConfigErr
   134  }
   135  func (c *TCore) PostBond(*core.PostBondForm) (*core.PostBondResult, error) {
   136  	return c.postBondResult, c.postBondErr
   137  }
   138  func (c *TCore) UpdateBondOptions(form *core.BondOptionsForm) error {
   139  	return c.bondOptsErr
   140  }
   141  func (c *TCore) SyncBook(dex string, base, quote uint32) (*orderbook.OrderBook, core.BookFeed, error) {
   142  	return nil, &tBookFeed{}, c.syncErr
   143  }
   144  func (c *TCore) Trade(appPass []byte, form *core.TradeForm) (order *core.Order, err error) {
   145  	return c.order, c.tradeErr
   146  }
   147  func (c *TCore) Wallets() []*core.WalletState {
   148  	return c.wallets
   149  }
   150  func (c *TCore) WalletState(assetID uint32) *core.WalletState {
   151  	return c.walletState
   152  }
   153  func (c *TCore) Send(pw []byte, assetID uint32, value uint64, addr string, subtract bool) (asset.Coin, error) {
   154  	return c.coin, c.sendErr
   155  }
   156  func (c *TCore) ExportSeed(pw []byte) (string, error) {
   157  	return c.exportSeed, c.exportSeedErr
   158  }
   159  func (c *TCore) DiscoverAccount(dexAddr string, pass []byte, certI any) (*core.Exchange, bool, error) {
   160  	return c.dexExchange, false, c.discoverAcctErr
   161  }
   162  func (c *TCore) DeleteArchivedRecords(olderThan *time.Time, matchesFileStr, ordersFileStr string) (int, error) {
   163  	return c.archivedRecords, c.deleteArchivedRecordsErr
   164  }
   165  func (c *TCore) AssetHasActiveOrders(uint32) bool {
   166  	return false
   167  }
   168  func (c *TCore) WalletPeers(assetID uint32) ([]*asset.WalletPeer, error) {
   169  	return nil, nil
   170  }
   171  func (c *TCore) AddWalletPeer(assetID uint32, address string) error {
   172  	return nil
   173  }
   174  func (c *TCore) RemoveWalletPeer(assetID uint32, address string) error {
   175  	return nil
   176  }
   177  func (c *TCore) Notifications(n int) (notes, pokes []*db.Notification, _ error) {
   178  	return nil, nil, nil
   179  }
   180  func (c *TCore) MultiTrade(appPass []byte, form *core.MultiTradeForm) []*core.MultiTradeResult {
   181  	return nil
   182  }
   183  func (c *TCore) SetVSP(assetID uint32, addr string) error {
   184  	return c.setVSPErr
   185  }
   186  func (c *TCore) PurchaseTickets(assetID uint32, pw []byte, n int) error {
   187  	return c.purchaseTicketsErr
   188  }
   189  func (c *TCore) StakeStatus(assetID uint32) (*asset.TicketStakingStatus, error) {
   190  	return c.stakeStatus, c.stakeStatusErr
   191  }
   192  func (c *TCore) SetVotingPreferences(assetID uint32, choices, tSpendPolicy, treasuryPolicy map[string]string) error {
   193  	return c.setVotingPrefErr
   194  }
   195  func (c *TCore) TxHistory(assetID uint32, n int, refID *string, past bool) ([]*asset.WalletTransaction, error) {
   196  	return nil, nil
   197  }
   198  func (c *TCore) WalletTransaction(assetID uint32, txID string) (*asset.WalletTransaction, error) {
   199  	return nil, nil
   200  }
   201  func (c *TCore) GenerateBCHRecoveryTransaction(appPW []byte, recipient string) ([]byte, error) {
   202  	return nil, nil
   203  }
   204  
   205  type tBookFeed struct{}
   206  
   207  func (*tBookFeed) Next() <-chan *core.BookUpdate {
   208  	return make(<-chan *core.BookUpdate)
   209  }
   210  func (*tBookFeed) Close() {}
   211  func (*tBookFeed) Candles(dur string) error {
   212  	return nil
   213  }
   214  
   215  func newTServer(t *testing.T, start bool, user, pass string) (*RPCServer, func()) {
   216  	tSrv, fn, err := newTServerWErr(t, start, user, pass)
   217  	if err != nil {
   218  		t.Fatal(err)
   219  	}
   220  	return tSrv, fn
   221  }
   222  func newTServerWErr(t *testing.T, start bool, user, pass string) (*RPCServer, func(), error) {
   223  	t.Helper()
   224  
   225  	var shutdown func()
   226  	ctx, killCtx := context.WithCancel(tCtx)
   227  	tempDir := t.TempDir()
   228  
   229  	cert, key := tempDir+"/cert.cert", tempDir+"/key.key"
   230  	cfg := &Config{
   231  		Core: &TCore{},
   232  		Addr: "127.0.0.1:0",
   233  		User: user,
   234  		Pass: pass,
   235  		Cert: cert,
   236  		Key:  key,
   237  	}
   238  	s, err := New(cfg)
   239  	if err != nil {
   240  		killCtx()
   241  		return nil, nil, fmt.Errorf("error creating server: %w", err)
   242  	}
   243  	if start {
   244  		cm := dex.NewConnectionMaster(s)
   245  		err := cm.Connect(ctx)
   246  		if err != nil {
   247  			killCtx()
   248  			return nil, nil, fmt.Errorf("error starting RPCServer: %w", err)
   249  		}
   250  		shutdown = func() {
   251  			killCtx()
   252  			cm.Disconnect()
   253  		}
   254  	} else {
   255  		shutdown = killCtx
   256  	}
   257  	return s, shutdown, nil
   258  }
   259  
   260  func TestMain(m *testing.M) {
   261  	var shutdown func()
   262  	tCtx, shutdown = context.WithCancel(context.Background())
   263  	doIt := func() int {
   264  		defer shutdown()
   265  		return m.Run()
   266  	}
   267  	os.Exit(doIt())
   268  }
   269  
   270  func TestConnectBindError(t *testing.T) {
   271  	s0, shutdown := newTServer(t, true, "", "abc")
   272  	defer shutdown()
   273  
   274  	tempDir := t.TempDir()
   275  
   276  	cert, key := tempDir+"/cert.cert", tempDir+"/key.key"
   277  	cfg := &Config{
   278  		Core: &TCore{},
   279  		Addr: s0.addr,
   280  		User: "",
   281  		Pass: "abc",
   282  		Cert: cert,
   283  		Key:  key,
   284  	}
   285  	s, err := New(cfg)
   286  	if err != nil {
   287  		t.Fatalf("error creating server: %v", err)
   288  	}
   289  
   290  	cm := dex.NewConnectionMaster(s)
   291  	if err = cm.Connect(tCtx); err == nil {
   292  		shutdown() // shutdown both servers with shared context
   293  		cm.Disconnect()
   294  		t.Fatal("should have failed to bind")
   295  	}
   296  }
   297  
   298  type tResponseWriter struct {
   299  	b    []byte
   300  	code int
   301  }
   302  
   303  func (w *tResponseWriter) Header() http.Header {
   304  	return make(http.Header)
   305  }
   306  func (w *tResponseWriter) Write(msg []byte) (int, error) {
   307  	w.b = msg
   308  	return len(msg), nil
   309  }
   310  func (w *tResponseWriter) WriteHeader(statusCode int) {
   311  	w.code = statusCode
   312  }
   313  
   314  func TestParseHTTPRequest(t *testing.T) {
   315  	s, shutdown := newTServer(t, false, "", "abc")
   316  	defer shutdown()
   317  	var r *http.Request
   318  
   319  	ensureHTTPError := func(name string, wantCode int) {
   320  		t.Helper()
   321  		w := &tResponseWriter{}
   322  		s.handleJSON(w, r)
   323  		if w.code != wantCode {
   324  			t.Fatalf("%s: Expected HTTP error %d, got %d",
   325  				name, wantCode, w.code)
   326  		}
   327  	}
   328  
   329  	ensureMsgErr := func(name string, wantCode int) {
   330  		t.Helper()
   331  		w := &tResponseWriter{}
   332  		s.handleJSON(w, r)
   333  		if w.code != 200 {
   334  			t.Fatalf("HTTP error when expecting msgjson.Error")
   335  		}
   336  		resp := new(msgjson.Message)
   337  		if err := json.Unmarshal(w.b, resp); err != nil {
   338  			t.Fatalf("unable to unmarshal response: %v", err)
   339  		}
   340  		payload := new(msgjson.ResponsePayload)
   341  		if err := json.Unmarshal(resp.Payload, payload); err != nil {
   342  			t.Fatalf("unable to unmarshal payload: %v", err)
   343  		}
   344  		if payload.Error == nil {
   345  			t.Fatalf("%s: no error", name)
   346  		}
   347  		if wantCode != payload.Error.Code {
   348  			t.Fatalf("%s, wanted %d, got %d",
   349  				name, wantCode, payload.Error.Code)
   350  		}
   351  	}
   352  	ensureNoErr := func(name string) {
   353  		t.Helper()
   354  		w := &tResponseWriter{}
   355  		s.handleJSON(w, r)
   356  		if w.code != 200 {
   357  			t.Fatalf("HTTP error when expecting no error")
   358  		}
   359  		resp := new(msgjson.Message)
   360  		if err := json.Unmarshal(w.b, resp); err != nil {
   361  			t.Fatalf("unable to unmarshal response: %v", err)
   362  		}
   363  		payload := new(msgjson.ResponsePayload)
   364  		if err := json.Unmarshal(resp.Payload, payload); err != nil {
   365  			t.Fatalf("unable to unmarshal payload: %v", err)
   366  		}
   367  		if payload.Error != nil {
   368  			t.Fatalf("%s: errored", name)
   369  		}
   370  	}
   371  
   372  	// Send a response, which is unsupported on the server.
   373  	msg, _ := msgjson.NewResponse(1, nil, nil)
   374  	b, _ := json.Marshal(msg)
   375  	bbuff := bytes.NewBuffer(b)
   376  	r, _ = http.NewRequest("GET", "", bbuff)
   377  	ensureHTTPError("response", http.StatusMethodNotAllowed)
   378  
   379  	// Unknown route.
   380  	msg, _ = msgjson.NewRequest(1, "123", nil)
   381  	b, _ = json.Marshal(msg)
   382  	bbuff = bytes.NewBuffer(b)
   383  	r, _ = http.NewRequest("GET", "", bbuff)
   384  	ensureMsgErr("bad route", msgjson.RPCUnknownRoute)
   385  
   386  	// Use real route.
   387  	msg, _ = msgjson.NewRequest(1, "version", nil)
   388  	b, _ = json.Marshal(msg)
   389  	bbuff = bytes.NewBuffer(b)
   390  	r, _ = http.NewRequest("GET", "", bbuff)
   391  	ensureNoErr("good request")
   392  
   393  	// Use real route with bad args.
   394  	msg, _ = msgjson.NewRequest(1, "version", "something")
   395  	b, _ = json.Marshal(msg)
   396  	bbuff = bytes.NewBuffer(b)
   397  	r, _ = http.NewRequest("GET", "", bbuff)
   398  	ensureMsgErr("bad params", msgjson.RPCParseError)
   399  }
   400  
   401  func TestNew(t *testing.T) {
   402  	authTests := []struct {
   403  		name, user, pass, wantAuth string
   404  		wantErr                    bool
   405  	}{{
   406  		name:     "ok",
   407  		user:     "user",
   408  		pass:     "pass",
   409  		wantAuth: "AK+rg3mIGeouojwZwNRMjBjZouASr4mu4FWMTXQQcD0=",
   410  	}, {
   411  		name:     "ok various input",
   412  		user:     `&!"#$%&'()~=`,
   413  		pass:     `+<>*?,:.;/][{}`,
   414  		wantAuth: "Te4g4+Ke9Q07MYo3iT1OCqq5qXX2ZcB47FBiVaT41hQ=",
   415  	}, {
   416  		name:    "no password",
   417  		user:    "user",
   418  		wantErr: true,
   419  	}}
   420  	for _, test := range authTests {
   421  		s, shutdown, err := newTServerWErr(t, false, test.user, test.pass)
   422  		if test.wantErr {
   423  			if err == nil {
   424  				t.Fatalf("expected error for test %s", test.name)
   425  			}
   426  			continue
   427  		}
   428  		if err != nil {
   429  			t.Fatalf("unexpected error for test %s: %v", test.name, err)
   430  		}
   431  		auth := base64.StdEncoding.EncodeToString((s.authSHA[:]))
   432  		if auth != test.wantAuth {
   433  			t.Fatalf("expected auth %s but got %s", test.wantAuth, auth)
   434  		}
   435  		shutdown()
   436  	}
   437  }
   438  
   439  func TestAuthMiddleware(t *testing.T) {
   440  	s, shutdown := newTServer(t, false, "", "abc")
   441  	defer shutdown()
   442  	am := s.authMiddleware(http.HandlerFunc(
   443  		func(w http.ResponseWriter, r *http.Request) {
   444  			w.WriteHeader(http.StatusOK)
   445  		}))
   446  	r, _ := http.NewRequest("GET", "", nil)
   447  
   448  	wantAuthError := func(name string, want bool) {
   449  		t.Helper()
   450  		w := &tResponseWriter{}
   451  		am.ServeHTTP(w, r)
   452  		if w.code != http.StatusUnauthorized && w.code != http.StatusOK {
   453  			t.Fatalf("unexpected HTTP error %d for test \"%s\"",
   454  				w.code, name)
   455  		}
   456  		switch want {
   457  		case true:
   458  			if w.code != http.StatusUnauthorized {
   459  				t.Fatalf("Expected unauthorized HTTP error for test \"%s\"",
   460  					name)
   461  			}
   462  		case false:
   463  			if w.code != http.StatusOK {
   464  				t.Fatalf("Expected OK HTTP status for test \"%s\"",
   465  					name)
   466  			}
   467  		}
   468  	}
   469  
   470  	user, pass := "Which one is it?", "It's the one that says bmf on it."
   471  	login := user + ":" + pass
   472  	h := "Basic "
   473  	auth := h + base64.StdEncoding.EncodeToString([]byte(login))
   474  	s.authSHA = sha256.Sum256([]byte(auth))
   475  
   476  	tests := []struct {
   477  		name, user, pass, header string
   478  		hasAuth, wantErr         bool
   479  	}{{
   480  		name:    "auth ok",
   481  		user:    user,
   482  		pass:    pass,
   483  		header:  h,
   484  		hasAuth: true,
   485  		wantErr: false,
   486  	}, {
   487  		name:    "wrong pass",
   488  		user:    user,
   489  		pass:    "password123",
   490  		header:  h,
   491  		hasAuth: true,
   492  		wantErr: true,
   493  	}, {
   494  		name:    "unknown user",
   495  		user:    "Jules",
   496  		pass:    pass,
   497  		header:  h,
   498  		hasAuth: true,
   499  		wantErr: true,
   500  	}, {
   501  		name:    "no header",
   502  		user:    user,
   503  		pass:    pass,
   504  		header:  h,
   505  		hasAuth: false,
   506  		wantErr: true,
   507  	}, {
   508  		name:    "malformed header",
   509  		user:    user,
   510  		pass:    pass,
   511  		header:  "basic ",
   512  		hasAuth: true,
   513  		wantErr: true,
   514  	}}
   515  	for _, test := range tests {
   516  		login = test.user + ":" + test.pass
   517  		auth = test.header + base64.StdEncoding.EncodeToString([]byte(login))
   518  		requestHeader := make(http.Header)
   519  		if test.hasAuth {
   520  			requestHeader.Add("Authorization", auth)
   521  		}
   522  		r.Header = requestHeader
   523  		wantAuthError(test.name, test.wantErr)
   524  	}
   525  }