github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/crypto/ssh/agent/server.go (about)

     1  // Copyright 2012 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package agent
     6  
     7  import (
     8  	"crypto/dsa"
     9  	"crypto/ecdsa"
    10  	"crypto/ed25519"
    11  	"crypto/elliptic"
    12  	"crypto/rsa"
    13  	"encoding/binary"
    14  	"errors"
    15  	"fmt"
    16  	"io"
    17  	"log"
    18  	"math/big"
    19  
    20  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/crypto/ssh"
    21  )
    22  
    23  // Server wraps an Agent and uses it to implement the agent side of
    24  // the SSH-agent, wire protocol.
    25  type server struct {
    26  	agent Agent
    27  }
    28  
    29  func (s *server) processRequestBytes(reqData []byte) []byte {
    30  	rep, err := s.processRequest(reqData)
    31  	if err != nil {
    32  		if err != errLocked {
    33  			// TODO(hanwen): provide better logging interface?
    34  			log.Printf("agent %d: %v", reqData[0], err)
    35  		}
    36  		return []byte{agentFailure}
    37  	}
    38  
    39  	if err == nil && rep == nil {
    40  		return []byte{agentSuccess}
    41  	}
    42  
    43  	return ssh.Marshal(rep)
    44  }
    45  
    46  func marshalKey(k *Key) []byte {
    47  	var record struct {
    48  		Blob    []byte
    49  		Comment string
    50  	}
    51  	record.Blob = k.Marshal()
    52  	record.Comment = k.Comment
    53  
    54  	return ssh.Marshal(&record)
    55  }
    56  
    57  // See [PROTOCOL.agent], section 2.5.1.
    58  const agentV1IdentitiesAnswer = 2
    59  
    60  type agentV1IdentityMsg struct {
    61  	Numkeys uint32 `sshtype:"2"`
    62  }
    63  
    64  type agentRemoveIdentityMsg struct {
    65  	KeyBlob []byte `sshtype:"18"`
    66  }
    67  
    68  type agentLockMsg struct {
    69  	Passphrase []byte `sshtype:"22"`
    70  }
    71  
    72  type agentUnlockMsg struct {
    73  	Passphrase []byte `sshtype:"23"`
    74  }
    75  
    76  func (s *server) processRequest(data []byte) (interface{}, error) {
    77  	switch data[0] {
    78  	case agentRequestV1Identities:
    79  		return &agentV1IdentityMsg{0}, nil
    80  
    81  	case agentRemoveAllV1Identities:
    82  		return nil, nil
    83  
    84  	case agentRemoveIdentity:
    85  		var req agentRemoveIdentityMsg
    86  		if err := ssh.Unmarshal(data, &req); err != nil {
    87  			return nil, err
    88  		}
    89  
    90  		var wk wireKey
    91  		if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
    92  			return nil, err
    93  		}
    94  
    95  		return nil, s.agent.Remove(&Key{Format: wk.Format, Blob: req.KeyBlob})
    96  
    97  	case agentRemoveAllIdentities:
    98  		return nil, s.agent.RemoveAll()
    99  
   100  	case agentLock:
   101  		var req agentLockMsg
   102  		if err := ssh.Unmarshal(data, &req); err != nil {
   103  			return nil, err
   104  		}
   105  
   106  		return nil, s.agent.Lock(req.Passphrase)
   107  
   108  	case agentUnlock:
   109  		var req agentUnlockMsg
   110  		if err := ssh.Unmarshal(data, &req); err != nil {
   111  			return nil, err
   112  		}
   113  		return nil, s.agent.Unlock(req.Passphrase)
   114  
   115  	case agentSignRequest:
   116  		var req signRequestAgentMsg
   117  		if err := ssh.Unmarshal(data, &req); err != nil {
   118  			return nil, err
   119  		}
   120  
   121  		var wk wireKey
   122  		if err := ssh.Unmarshal(req.KeyBlob, &wk); err != nil {
   123  			return nil, err
   124  		}
   125  
   126  		k := &Key{
   127  			Format: wk.Format,
   128  			Blob:   req.KeyBlob,
   129  		}
   130  
   131  		var sig *ssh.Signature
   132  		var err error
   133  		if extendedAgent, ok := s.agent.(ExtendedAgent); ok {
   134  			sig, err = extendedAgent.SignWithFlags(k, req.Data, SignatureFlags(req.Flags))
   135  		} else {
   136  			sig, err = s.agent.Sign(k, req.Data)
   137  		}
   138  
   139  		if err != nil {
   140  			return nil, err
   141  		}
   142  		return &signResponseAgentMsg{SigBlob: ssh.Marshal(sig)}, nil
   143  
   144  	case agentRequestIdentities:
   145  		keys, err := s.agent.List()
   146  		if err != nil {
   147  			return nil, err
   148  		}
   149  
   150  		rep := identitiesAnswerAgentMsg{
   151  			NumKeys: uint32(len(keys)),
   152  		}
   153  		for _, k := range keys {
   154  			rep.Keys = append(rep.Keys, marshalKey(k)...)
   155  		}
   156  		return rep, nil
   157  
   158  	case agentAddIDConstrained, agentAddIdentity:
   159  		return nil, s.insertIdentity(data)
   160  
   161  	case agentExtension:
   162  		// Return a stub object where the whole contents of the response gets marshaled.
   163  		var responseStub struct {
   164  			Rest []byte `ssh:"rest"`
   165  		}
   166  
   167  		if extendedAgent, ok := s.agent.(ExtendedAgent); !ok {
   168  			// If this agent doesn't implement extensions, [PROTOCOL.agent] section 4.7
   169  			// requires that we return a standard SSH_AGENT_FAILURE message.
   170  			responseStub.Rest = []byte{agentFailure}
   171  		} else {
   172  			var req extensionAgentMsg
   173  			if err := ssh.Unmarshal(data, &req); err != nil {
   174  				return nil, err
   175  			}
   176  			res, err := extendedAgent.Extension(req.ExtensionType, req.Contents)
   177  			if err != nil {
   178  				// If agent extensions are unsupported, return a standard SSH_AGENT_FAILURE
   179  				// message as required by [PROTOCOL.agent] section 4.7.
   180  				if err == ErrExtensionUnsupported {
   181  					responseStub.Rest = []byte{agentFailure}
   182  				} else {
   183  					// As the result of any other error processing an extension request,
   184  					// [PROTOCOL.agent] section 4.7 requires that we return a
   185  					// SSH_AGENT_EXTENSION_FAILURE code.
   186  					responseStub.Rest = []byte{agentExtensionFailure}
   187  				}
   188  			} else {
   189  				if len(res) == 0 {
   190  					return nil, nil
   191  				}
   192  				responseStub.Rest = res
   193  			}
   194  		}
   195  
   196  		return responseStub, nil
   197  	}
   198  
   199  	return nil, fmt.Errorf("unknown opcode %d", data[0])
   200  }
   201  
   202  func parseConstraints(constraints []byte) (lifetimeSecs uint32, confirmBeforeUse bool, extensions []ConstraintExtension, err error) {
   203  	for len(constraints) != 0 {
   204  		switch constraints[0] {
   205  		case agentConstrainLifetime:
   206  			lifetimeSecs = binary.BigEndian.Uint32(constraints[1:5])
   207  			constraints = constraints[5:]
   208  		case agentConstrainConfirm:
   209  			confirmBeforeUse = true
   210  			constraints = constraints[1:]
   211  		case agentConstrainExtension:
   212  			var msg constrainExtensionAgentMsg
   213  			if err = ssh.Unmarshal(constraints, &msg); err != nil {
   214  				return 0, false, nil, err
   215  			}
   216  			extensions = append(extensions, ConstraintExtension{
   217  				ExtensionName:    msg.ExtensionName,
   218  				ExtensionDetails: msg.ExtensionDetails,
   219  			})
   220  			constraints = msg.Rest
   221  		default:
   222  			return 0, false, nil, fmt.Errorf("unknown constraint type: %d", constraints[0])
   223  		}
   224  	}
   225  	return
   226  }
   227  
   228  func setConstraints(key *AddedKey, constraintBytes []byte) error {
   229  	lifetimeSecs, confirmBeforeUse, constraintExtensions, err := parseConstraints(constraintBytes)
   230  	if err != nil {
   231  		return err
   232  	}
   233  
   234  	key.LifetimeSecs = lifetimeSecs
   235  	key.ConfirmBeforeUse = confirmBeforeUse
   236  	key.ConstraintExtensions = constraintExtensions
   237  	return nil
   238  }
   239  
   240  func parseRSAKey(req []byte) (*AddedKey, error) {
   241  	var k rsaKeyMsg
   242  	if err := ssh.Unmarshal(req, &k); err != nil {
   243  		return nil, err
   244  	}
   245  	if k.E.BitLen() > 30 {
   246  		return nil, errors.New("agent: RSA public exponent too large")
   247  	}
   248  	priv := &rsa.PrivateKey{
   249  		PublicKey: rsa.PublicKey{
   250  			E: int(k.E.Int64()),
   251  			N: k.N,
   252  		},
   253  		D:      k.D,
   254  		Primes: []*big.Int{k.P, k.Q},
   255  	}
   256  	priv.Precompute()
   257  
   258  	addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
   259  	if err := setConstraints(addedKey, k.Constraints); err != nil {
   260  		return nil, err
   261  	}
   262  	return addedKey, nil
   263  }
   264  
   265  func parseEd25519Key(req []byte) (*AddedKey, error) {
   266  	var k ed25519KeyMsg
   267  	if err := ssh.Unmarshal(req, &k); err != nil {
   268  		return nil, err
   269  	}
   270  	priv := ed25519.PrivateKey(k.Priv)
   271  
   272  	addedKey := &AddedKey{PrivateKey: &priv, Comment: k.Comments}
   273  	if err := setConstraints(addedKey, k.Constraints); err != nil {
   274  		return nil, err
   275  	}
   276  	return addedKey, nil
   277  }
   278  
   279  func parseDSAKey(req []byte) (*AddedKey, error) {
   280  	var k dsaKeyMsg
   281  	if err := ssh.Unmarshal(req, &k); err != nil {
   282  		return nil, err
   283  	}
   284  	priv := &dsa.PrivateKey{
   285  		PublicKey: dsa.PublicKey{
   286  			Parameters: dsa.Parameters{
   287  				P: k.P,
   288  				Q: k.Q,
   289  				G: k.G,
   290  			},
   291  			Y: k.Y,
   292  		},
   293  		X: k.X,
   294  	}
   295  
   296  	addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
   297  	if err := setConstraints(addedKey, k.Constraints); err != nil {
   298  		return nil, err
   299  	}
   300  	return addedKey, nil
   301  }
   302  
   303  func unmarshalECDSA(curveName string, keyBytes []byte, privScalar *big.Int) (priv *ecdsa.PrivateKey, err error) {
   304  	priv = &ecdsa.PrivateKey{
   305  		D: privScalar,
   306  	}
   307  
   308  	switch curveName {
   309  	case "nistp256":
   310  		priv.Curve = elliptic.P256()
   311  	case "nistp384":
   312  		priv.Curve = elliptic.P384()
   313  	case "nistp521":
   314  		priv.Curve = elliptic.P521()
   315  	default:
   316  		return nil, fmt.Errorf("agent: unknown curve %q", curveName)
   317  	}
   318  
   319  	priv.X, priv.Y = elliptic.Unmarshal(priv.Curve, keyBytes)
   320  	if priv.X == nil || priv.Y == nil {
   321  		return nil, errors.New("agent: point not on curve")
   322  	}
   323  
   324  	return priv, nil
   325  }
   326  
   327  func parseEd25519Cert(req []byte) (*AddedKey, error) {
   328  	var k ed25519CertMsg
   329  	if err := ssh.Unmarshal(req, &k); err != nil {
   330  		return nil, err
   331  	}
   332  	pubKey, err := ssh.ParsePublicKey(k.CertBytes)
   333  	if err != nil {
   334  		return nil, err
   335  	}
   336  	priv := ed25519.PrivateKey(k.Priv)
   337  	cert, ok := pubKey.(*ssh.Certificate)
   338  	if !ok {
   339  		return nil, errors.New("agent: bad ED25519 certificate")
   340  	}
   341  
   342  	addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
   343  	if err := setConstraints(addedKey, k.Constraints); err != nil {
   344  		return nil, err
   345  	}
   346  	return addedKey, nil
   347  }
   348  
   349  func parseECDSAKey(req []byte) (*AddedKey, error) {
   350  	var k ecdsaKeyMsg
   351  	if err := ssh.Unmarshal(req, &k); err != nil {
   352  		return nil, err
   353  	}
   354  
   355  	priv, err := unmarshalECDSA(k.Curve, k.KeyBytes, k.D)
   356  	if err != nil {
   357  		return nil, err
   358  	}
   359  
   360  	addedKey := &AddedKey{PrivateKey: priv, Comment: k.Comments}
   361  	if err := setConstraints(addedKey, k.Constraints); err != nil {
   362  		return nil, err
   363  	}
   364  	return addedKey, nil
   365  }
   366  
   367  func parseRSACert(req []byte) (*AddedKey, error) {
   368  	var k rsaCertMsg
   369  	if err := ssh.Unmarshal(req, &k); err != nil {
   370  		return nil, err
   371  	}
   372  
   373  	pubKey, err := ssh.ParsePublicKey(k.CertBytes)
   374  	if err != nil {
   375  		return nil, err
   376  	}
   377  
   378  	cert, ok := pubKey.(*ssh.Certificate)
   379  	if !ok {
   380  		return nil, errors.New("agent: bad RSA certificate")
   381  	}
   382  
   383  	// An RSA publickey as marshaled by rsaPublicKey.Marshal() in keys.go
   384  	var rsaPub struct {
   385  		Name string
   386  		E    *big.Int
   387  		N    *big.Int
   388  	}
   389  	if err := ssh.Unmarshal(cert.Key.Marshal(), &rsaPub); err != nil {
   390  		return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
   391  	}
   392  
   393  	if rsaPub.E.BitLen() > 30 {
   394  		return nil, errors.New("agent: RSA public exponent too large")
   395  	}
   396  
   397  	priv := rsa.PrivateKey{
   398  		PublicKey: rsa.PublicKey{
   399  			E: int(rsaPub.E.Int64()),
   400  			N: rsaPub.N,
   401  		},
   402  		D:      k.D,
   403  		Primes: []*big.Int{k.Q, k.P},
   404  	}
   405  	priv.Precompute()
   406  
   407  	addedKey := &AddedKey{PrivateKey: &priv, Certificate: cert, Comment: k.Comments}
   408  	if err := setConstraints(addedKey, k.Constraints); err != nil {
   409  		return nil, err
   410  	}
   411  	return addedKey, nil
   412  }
   413  
   414  func parseDSACert(req []byte) (*AddedKey, error) {
   415  	var k dsaCertMsg
   416  	if err := ssh.Unmarshal(req, &k); err != nil {
   417  		return nil, err
   418  	}
   419  	pubKey, err := ssh.ParsePublicKey(k.CertBytes)
   420  	if err != nil {
   421  		return nil, err
   422  	}
   423  	cert, ok := pubKey.(*ssh.Certificate)
   424  	if !ok {
   425  		return nil, errors.New("agent: bad DSA certificate")
   426  	}
   427  
   428  	// A DSA publickey as marshaled by dsaPublicKey.Marshal() in keys.go
   429  	var w struct {
   430  		Name       string
   431  		P, Q, G, Y *big.Int
   432  	}
   433  	if err := ssh.Unmarshal(cert.Key.Marshal(), &w); err != nil {
   434  		return nil, fmt.Errorf("agent: Unmarshal failed to parse public key: %v", err)
   435  	}
   436  
   437  	priv := &dsa.PrivateKey{
   438  		PublicKey: dsa.PublicKey{
   439  			Parameters: dsa.Parameters{
   440  				P: w.P,
   441  				Q: w.Q,
   442  				G: w.G,
   443  			},
   444  			Y: w.Y,
   445  		},
   446  		X: k.X,
   447  	}
   448  
   449  	addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
   450  	if err := setConstraints(addedKey, k.Constraints); err != nil {
   451  		return nil, err
   452  	}
   453  	return addedKey, nil
   454  }
   455  
   456  func parseECDSACert(req []byte) (*AddedKey, error) {
   457  	var k ecdsaCertMsg
   458  	if err := ssh.Unmarshal(req, &k); err != nil {
   459  		return nil, err
   460  	}
   461  
   462  	pubKey, err := ssh.ParsePublicKey(k.CertBytes)
   463  	if err != nil {
   464  		return nil, err
   465  	}
   466  	cert, ok := pubKey.(*ssh.Certificate)
   467  	if !ok {
   468  		return nil, errors.New("agent: bad ECDSA certificate")
   469  	}
   470  
   471  	// An ECDSA publickey as marshaled by ecdsaPublicKey.Marshal() in keys.go
   472  	var ecdsaPub struct {
   473  		Name string
   474  		ID   string
   475  		Key  []byte
   476  	}
   477  	if err := ssh.Unmarshal(cert.Key.Marshal(), &ecdsaPub); err != nil {
   478  		return nil, err
   479  	}
   480  
   481  	priv, err := unmarshalECDSA(ecdsaPub.ID, ecdsaPub.Key, k.D)
   482  	if err != nil {
   483  		return nil, err
   484  	}
   485  
   486  	addedKey := &AddedKey{PrivateKey: priv, Certificate: cert, Comment: k.Comments}
   487  	if err := setConstraints(addedKey, k.Constraints); err != nil {
   488  		return nil, err
   489  	}
   490  	return addedKey, nil
   491  }
   492  
   493  func (s *server) insertIdentity(req []byte) error {
   494  	var record struct {
   495  		Type string `sshtype:"17|25"`
   496  		Rest []byte `ssh:"rest"`
   497  	}
   498  
   499  	if err := ssh.Unmarshal(req, &record); err != nil {
   500  		return err
   501  	}
   502  
   503  	var addedKey *AddedKey
   504  	var err error
   505  
   506  	switch record.Type {
   507  	case ssh.KeyAlgoRSA:
   508  		addedKey, err = parseRSAKey(req)
   509  	case ssh.KeyAlgoDSA:
   510  		addedKey, err = parseDSAKey(req)
   511  	case ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521:
   512  		addedKey, err = parseECDSAKey(req)
   513  	case ssh.KeyAlgoED25519:
   514  		addedKey, err = parseEd25519Key(req)
   515  	case ssh.CertAlgoRSAv01:
   516  		addedKey, err = parseRSACert(req)
   517  	case ssh.CertAlgoDSAv01:
   518  		addedKey, err = parseDSACert(req)
   519  	case ssh.CertAlgoECDSA256v01, ssh.CertAlgoECDSA384v01, ssh.CertAlgoECDSA521v01:
   520  		addedKey, err = parseECDSACert(req)
   521  	case ssh.CertAlgoED25519v01:
   522  		addedKey, err = parseEd25519Cert(req)
   523  	default:
   524  		return fmt.Errorf("agent: not implemented: %q", record.Type)
   525  	}
   526  
   527  	if err != nil {
   528  		return err
   529  	}
   530  	return s.agent.Add(*addedKey)
   531  }
   532  
   533  // ServeAgent serves the agent protocol on the given connection. It
   534  // returns when an I/O error occurs.
   535  func ServeAgent(agent Agent, c io.ReadWriter) error {
   536  	s := &server{agent}
   537  
   538  	var length [4]byte
   539  	for {
   540  		if _, err := io.ReadFull(c, length[:]); err != nil {
   541  			return err
   542  		}
   543  		l := binary.BigEndian.Uint32(length[:])
   544  		if l == 0 {
   545  			return fmt.Errorf("agent: request size is 0")
   546  		}
   547  		if l > maxAgentResponseBytes {
   548  			// We also cap requests.
   549  			return fmt.Errorf("agent: request too large: %d", l)
   550  		}
   551  
   552  		req := make([]byte, l)
   553  		if _, err := io.ReadFull(c, req); err != nil {
   554  			return err
   555  		}
   556  
   557  		repData := s.processRequestBytes(req)
   558  		if len(repData) > maxAgentResponseBytes {
   559  			return fmt.Errorf("agent: reply too large: %d bytes", len(repData))
   560  		}
   561  
   562  		binary.BigEndian.PutUint32(length[:], uint32(len(repData)))
   563  		if _, err := c.Write(length[:]); err != nil {
   564  			return err
   565  		}
   566  		if _, err := c.Write(repData); err != nil {
   567  			return err
   568  		}
   569  	}
   570  }