github.com/ava-labs/avalanchego@v1.11.11/wallet/chain/p/signer/visitor.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package signer
     5  
     6  import (
     7  	"context"
     8  	"errors"
     9  	"fmt"
    10  
    11  	"github.com/ava-labs/avalanchego/database"
    12  	"github.com/ava-labs/avalanchego/ids"
    13  	"github.com/ava-labs/avalanchego/utils/constants"
    14  	"github.com/ava-labs/avalanchego/utils/crypto/keychain"
    15  	"github.com/ava-labs/avalanchego/utils/crypto/secp256k1"
    16  	"github.com/ava-labs/avalanchego/utils/hashing"
    17  	"github.com/ava-labs/avalanchego/vms/components/avax"
    18  	"github.com/ava-labs/avalanchego/vms/components/verify"
    19  	"github.com/ava-labs/avalanchego/vms/platformvm/stakeable"
    20  	"github.com/ava-labs/avalanchego/vms/platformvm/txs"
    21  	"github.com/ava-labs/avalanchego/vms/secp256k1fx"
    22  )
    23  
    24  var (
    25  	_ txs.Visitor = (*visitor)(nil)
    26  
    27  	ErrUnsupportedTxType     = errors.New("unsupported tx type")
    28  	ErrUnknownInputType      = errors.New("unknown input type")
    29  	ErrUnknownOutputType     = errors.New("unknown output type")
    30  	ErrInvalidUTXOSigIndex   = errors.New("invalid UTXO signature index")
    31  	ErrUnknownSubnetAuthType = errors.New("unknown subnet auth type")
    32  	ErrUnknownOwnerType      = errors.New("unknown owner type")
    33  	ErrUnknownCredentialType = errors.New("unknown credential type")
    34  
    35  	emptySig [secp256k1.SignatureLen]byte
    36  )
    37  
// visitor handles signing transactions for the signer.
//
// It implements txs.Visitor. Each visit method collects the keys needed for
// the credentials of the visited transaction and then signs [tx] in place.
type visitor struct {
	// kc supplies the keys (looked up by address) that produce signatures.
	kc      keychain.Keychain
	// backend is used to fetch the UTXOs being spent and subnet owners.
	backend Backend
	// ctx is passed to every backend lookup.
	// NOTE(review): storing a context in a struct is discouraged in Go; it is
	// kept here because the txs.Visitor methods take no context parameter.
	ctx     context.Context
	// tx is the transaction being signed; its credentials and bytes are
	// overwritten by sign.
	tx      *txs.Tx
}
    45  
    46  func (*visitor) AdvanceTimeTx(*txs.AdvanceTimeTx) error {
    47  	return ErrUnsupportedTxType
    48  }
    49  
    50  func (*visitor) RewardValidatorTx(*txs.RewardValidatorTx) error {
    51  	return ErrUnsupportedTxType
    52  }
    53  
    54  func (s *visitor) BaseTx(tx *txs.BaseTx) error {
    55  	txSigners, err := s.getSigners(constants.PlatformChainID, tx.Ins)
    56  	if err != nil {
    57  		return err
    58  	}
    59  	return sign(s.tx, false, txSigners)
    60  }
    61  
    62  func (s *visitor) AddValidatorTx(tx *txs.AddValidatorTx) error {
    63  	txSigners, err := s.getSigners(constants.PlatformChainID, tx.Ins)
    64  	if err != nil {
    65  		return err
    66  	}
    67  	return sign(s.tx, false, txSigners)
    68  }
    69  
    70  func (s *visitor) AddSubnetValidatorTx(tx *txs.AddSubnetValidatorTx) error {
    71  	txSigners, err := s.getSigners(constants.PlatformChainID, tx.Ins)
    72  	if err != nil {
    73  		return err
    74  	}
    75  	subnetAuthSigners, err := s.getSubnetSigners(tx.SubnetValidator.Subnet, tx.SubnetAuth)
    76  	if err != nil {
    77  		return err
    78  	}
    79  	txSigners = append(txSigners, subnetAuthSigners)
    80  	return sign(s.tx, false, txSigners)
    81  }
    82  
    83  func (s *visitor) AddDelegatorTx(tx *txs.AddDelegatorTx) error {
    84  	txSigners, err := s.getSigners(constants.PlatformChainID, tx.Ins)
    85  	if err != nil {
    86  		return err
    87  	}
    88  	return sign(s.tx, false, txSigners)
    89  }
    90  
    91  func (s *visitor) CreateChainTx(tx *txs.CreateChainTx) error {
    92  	txSigners, err := s.getSigners(constants.PlatformChainID, tx.Ins)
    93  	if err != nil {
    94  		return err
    95  	}
    96  	subnetAuthSigners, err := s.getSubnetSigners(tx.SubnetID, tx.SubnetAuth)
    97  	if err != nil {
    98  		return err
    99  	}
   100  	txSigners = append(txSigners, subnetAuthSigners)
   101  	return sign(s.tx, false, txSigners)
   102  }
   103  
   104  func (s *visitor) CreateSubnetTx(tx *txs.CreateSubnetTx) error {
   105  	txSigners, err := s.getSigners(constants.PlatformChainID, tx.Ins)
   106  	if err != nil {
   107  		return err
   108  	}
   109  	return sign(s.tx, false, txSigners)
   110  }
   111  
   112  func (s *visitor) ImportTx(tx *txs.ImportTx) error {
   113  	txSigners, err := s.getSigners(constants.PlatformChainID, tx.Ins)
   114  	if err != nil {
   115  		return err
   116  	}
   117  	txImportSigners, err := s.getSigners(tx.SourceChain, tx.ImportedInputs)
   118  	if err != nil {
   119  		return err
   120  	}
   121  	txSigners = append(txSigners, txImportSigners...)
   122  	return sign(s.tx, false, txSigners)
   123  }
   124  
   125  func (s *visitor) ExportTx(tx *txs.ExportTx) error {
   126  	txSigners, err := s.getSigners(constants.PlatformChainID, tx.Ins)
   127  	if err != nil {
   128  		return err
   129  	}
   130  	return sign(s.tx, false, txSigners)
   131  }
   132  
   133  func (s *visitor) RemoveSubnetValidatorTx(tx *txs.RemoveSubnetValidatorTx) error {
   134  	txSigners, err := s.getSigners(constants.PlatformChainID, tx.Ins)
   135  	if err != nil {
   136  		return err
   137  	}
   138  	subnetAuthSigners, err := s.getSubnetSigners(tx.Subnet, tx.SubnetAuth)
   139  	if err != nil {
   140  		return err
   141  	}
   142  	txSigners = append(txSigners, subnetAuthSigners)
   143  	return sign(s.tx, true, txSigners)
   144  }
   145  
   146  func (s *visitor) TransferSubnetOwnershipTx(tx *txs.TransferSubnetOwnershipTx) error {
   147  	txSigners, err := s.getSigners(constants.PlatformChainID, tx.Ins)
   148  	if err != nil {
   149  		return err
   150  	}
   151  	subnetAuthSigners, err := s.getSubnetSigners(tx.Subnet, tx.SubnetAuth)
   152  	if err != nil {
   153  		return err
   154  	}
   155  	txSigners = append(txSigners, subnetAuthSigners)
   156  	return sign(s.tx, true, txSigners)
   157  }
   158  
   159  func (s *visitor) TransformSubnetTx(tx *txs.TransformSubnetTx) error {
   160  	txSigners, err := s.getSigners(constants.PlatformChainID, tx.Ins)
   161  	if err != nil {
   162  		return err
   163  	}
   164  	subnetAuthSigners, err := s.getSubnetSigners(tx.Subnet, tx.SubnetAuth)
   165  	if err != nil {
   166  		return err
   167  	}
   168  	txSigners = append(txSigners, subnetAuthSigners)
   169  	return sign(s.tx, true, txSigners)
   170  }
   171  
   172  func (s *visitor) AddPermissionlessValidatorTx(tx *txs.AddPermissionlessValidatorTx) error {
   173  	txSigners, err := s.getSigners(constants.PlatformChainID, tx.Ins)
   174  	if err != nil {
   175  		return err
   176  	}
   177  	return sign(s.tx, true, txSigners)
   178  }
   179  
   180  func (s *visitor) AddPermissionlessDelegatorTx(tx *txs.AddPermissionlessDelegatorTx) error {
   181  	txSigners, err := s.getSigners(constants.PlatformChainID, tx.Ins)
   182  	if err != nil {
   183  		return err
   184  	}
   185  	return sign(s.tx, true, txSigners)
   186  }
   187  
   188  func (s *visitor) getSigners(sourceChainID ids.ID, ins []*avax.TransferableInput) ([][]keychain.Signer, error) {
   189  	txSigners := make([][]keychain.Signer, len(ins))
   190  	for credIndex, transferInput := range ins {
   191  		inIntf := transferInput.In
   192  		if stakeableIn, ok := inIntf.(*stakeable.LockIn); ok {
   193  			inIntf = stakeableIn.TransferableIn
   194  		}
   195  
   196  		input, ok := inIntf.(*secp256k1fx.TransferInput)
   197  		if !ok {
   198  			return nil, ErrUnknownInputType
   199  		}
   200  
   201  		inputSigners := make([]keychain.Signer, len(input.SigIndices))
   202  		txSigners[credIndex] = inputSigners
   203  
   204  		utxoID := transferInput.InputID()
   205  		utxo, err := s.backend.GetUTXO(s.ctx, sourceChainID, utxoID)
   206  		if err == database.ErrNotFound {
   207  			// If we don't have access to the UTXO, then we can't sign this
   208  			// transaction. However, we can attempt to partially sign it.
   209  			continue
   210  		}
   211  		if err != nil {
   212  			return nil, err
   213  		}
   214  
   215  		outIntf := utxo.Out
   216  		if stakeableOut, ok := outIntf.(*stakeable.LockOut); ok {
   217  			outIntf = stakeableOut.TransferableOut
   218  		}
   219  
   220  		out, ok := outIntf.(*secp256k1fx.TransferOutput)
   221  		if !ok {
   222  			return nil, ErrUnknownOutputType
   223  		}
   224  
   225  		for sigIndex, addrIndex := range input.SigIndices {
   226  			if addrIndex >= uint32(len(out.Addrs)) {
   227  				return nil, ErrInvalidUTXOSigIndex
   228  			}
   229  
   230  			addr := out.Addrs[addrIndex]
   231  			key, ok := s.kc.Get(addr)
   232  			if !ok {
   233  				// If we don't have access to the key, then we can't sign this
   234  				// transaction. However, we can attempt to partially sign it.
   235  				continue
   236  			}
   237  			inputSigners[sigIndex] = key
   238  		}
   239  	}
   240  	return txSigners, nil
   241  }
   242  
   243  func (s *visitor) getSubnetSigners(subnetID ids.ID, subnetAuth verify.Verifiable) ([]keychain.Signer, error) {
   244  	subnetInput, ok := subnetAuth.(*secp256k1fx.Input)
   245  	if !ok {
   246  		return nil, ErrUnknownSubnetAuthType
   247  	}
   248  
   249  	ownerIntf, err := s.backend.GetSubnetOwner(s.ctx, subnetID)
   250  	if err != nil {
   251  		return nil, fmt.Errorf(
   252  			"failed to fetch subnet owner for %q: %w",
   253  			subnetID,
   254  			err,
   255  		)
   256  	}
   257  	owner, ok := ownerIntf.(*secp256k1fx.OutputOwners)
   258  	if !ok {
   259  		return nil, ErrUnknownOwnerType
   260  	}
   261  
   262  	authSigners := make([]keychain.Signer, len(subnetInput.SigIndices))
   263  	for sigIndex, addrIndex := range subnetInput.SigIndices {
   264  		if addrIndex >= uint32(len(owner.Addrs)) {
   265  			return nil, ErrInvalidUTXOSigIndex
   266  		}
   267  
   268  		addr := owner.Addrs[addrIndex]
   269  		key, ok := s.kc.Get(addr)
   270  		if !ok {
   271  			// If we don't have access to the key, then we can't sign this
   272  			// transaction. However, we can attempt to partially sign it.
   273  			continue
   274  		}
   275  		authSigners[sigIndex] = key
   276  	}
   277  	return authSigners, nil
   278  }
   279  
   280  // TODO: remove [signHash] after the ledger supports signing all transactions.
   281  func sign(tx *txs.Tx, signHash bool, txSigners [][]keychain.Signer) error {
   282  	unsignedBytes, err := txs.Codec.Marshal(txs.CodecVersion, &tx.Unsigned)
   283  	if err != nil {
   284  		return fmt.Errorf("couldn't marshal unsigned tx: %w", err)
   285  	}
   286  	unsignedHash := hashing.ComputeHash256(unsignedBytes)
   287  
   288  	if expectedLen := len(txSigners); expectedLen != len(tx.Creds) {
   289  		tx.Creds = make([]verify.Verifiable, expectedLen)
   290  	}
   291  
   292  	sigCache := make(map[ids.ShortID][secp256k1.SignatureLen]byte)
   293  	for credIndex, inputSigners := range txSigners {
   294  		credIntf := tx.Creds[credIndex]
   295  		if credIntf == nil {
   296  			credIntf = &secp256k1fx.Credential{}
   297  			tx.Creds[credIndex] = credIntf
   298  		}
   299  
   300  		cred, ok := credIntf.(*secp256k1fx.Credential)
   301  		if !ok {
   302  			return ErrUnknownCredentialType
   303  		}
   304  		if expectedLen := len(inputSigners); expectedLen != len(cred.Sigs) {
   305  			cred.Sigs = make([][secp256k1.SignatureLen]byte, expectedLen)
   306  		}
   307  
   308  		for sigIndex, signer := range inputSigners {
   309  			if signer == nil {
   310  				// If we don't have access to the key, then we can't sign this
   311  				// transaction. However, we can attempt to partially sign it.
   312  				continue
   313  			}
   314  			addr := signer.Address()
   315  			if sig := cred.Sigs[sigIndex]; sig != emptySig {
   316  				// If this signature has already been populated, we can just
   317  				// copy the needed signature for the future.
   318  				sigCache[addr] = sig
   319  				continue
   320  			}
   321  
   322  			if sig, exists := sigCache[addr]; exists {
   323  				// If this key has already produced a signature, we can just
   324  				// copy the previous signature.
   325  				cred.Sigs[sigIndex] = sig
   326  				continue
   327  			}
   328  
   329  			var sig []byte
   330  			if signHash {
   331  				sig, err = signer.SignHash(unsignedHash)
   332  			} else {
   333  				sig, err = signer.Sign(unsignedBytes)
   334  			}
   335  			if err != nil {
   336  				return fmt.Errorf("problem signing tx: %w", err)
   337  			}
   338  			copy(cred.Sigs[sigIndex][:], sig)
   339  			sigCache[addr] = cred.Sigs[sigIndex]
   340  		}
   341  	}
   342  
   343  	signedBytes, err := txs.Codec.Marshal(txs.CodecVersion, tx)
   344  	if err != nil {
   345  		return fmt.Errorf("couldn't marshal tx: %w", err)
   346  	}
   347  	tx.SetBytes(unsignedBytes, signedBytes)
   348  	return nil
   349  }