github.com/imgk/caddy-trojan@v0.0.0-20221206043256-2631719e16c8/trojan/trojan.go (about)

     1  package trojan
     2  
     3  import (
     4  	"crypto/sha256"
     5  	"encoding/hex"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"time"
    11  
    12  	"github.com/imgk/caddy-trojan/socks"
    13  	"github.com/imgk/caddy-trojan/utils"
    14  )
    15  
    16  // HeaderLen is ...
    17  const HeaderLen = 56
    18  
    19  const (
    20  	// CmdConnect is ...
    21  	CmdConnect = 1
    22  	// CmdAssociate is ...
    23  	CmdAssociate = 3
    24  )
    25  
    26  // GenKey is ...
    27  func GenKey(s string, key []byte) {
    28  	hash := sha256.Sum224(utils.StringToByteSlice(s))
    29  	hex.Encode(key, hash[:])
    30  }
    31  
    32  // Handle is ...
    33  func Handle(r io.Reader, w io.Writer) (int64, int64, error) {
    34  	return HandleWithDialer(r, w, (*netDialer)(nil))
    35  }
    36  
    37  // Dialer is ...
    38  type Dialer interface {
    39  	// Dial is ...
    40  	Dial(string, string) (net.Conn, error)
    41  	// ListenPacket is ...
    42  	ListenPacket(string, string) (net.PacketConn, error)
    43  }
    44  
    45  type netDialer struct{}
    46  
    47  func (*netDialer) Dial(network, addr string) (net.Conn, error) {
    48  	return net.Dial(network, addr)
    49  }
    50  
    51  func (*netDialer) ListenPacket(network, addr string) (net.PacketConn, error) {
    52  	return net.ListenPacket(network, addr)
    53  }
    54  
    55  // HandleWithDialer is ...
    56  func HandleWithDialer(r io.Reader, w io.Writer, d Dialer) (int64, int64, error) {
    57  	b := [1 + socks.MaxAddrLen + 2]byte{}
    58  
    59  	// read command
    60  	if _, err := io.ReadFull(r, b[:1]); err != nil {
    61  		return 0, 0, fmt.Errorf("read command error: %w", err)
    62  	}
    63  	if b[0] != CmdConnect && b[0] != CmdAssociate {
    64  		return 0, 0, errors.New("command error")
    65  	}
    66  
    67  	// read address
    68  	addr, err := socks.ReadAddrBuffer(r, b[3:])
    69  	if err != nil {
    70  		return 0, 0, fmt.Errorf("read addr error: %w", err)
    71  	}
    72  
    73  	// read 0x0d, 0x0a
    74  	if _, err := io.ReadFull(r, b[1:3]); err != nil {
    75  		return 0, 0, fmt.Errorf("read 0x0d 0x0a error: %w", err)
    76  	}
    77  
    78  	switch b[0] {
    79  	case CmdConnect:
    80  		nr, nw, err := HandleTCP(r, w, addr, d)
    81  		if err != nil {
    82  			return nr, nw, fmt.Errorf("handle tcp error: %w", err)
    83  		}
    84  		return nr, nw, nil
    85  	case CmdAssociate:
    86  		nr, nw, err := HandleUDP(r, w, time.Minute*10, d)
    87  		if err != nil {
    88  			return nr, nw, fmt.Errorf("handle udp error: %w", err)
    89  		}
    90  		return nr, nw, nil
    91  	default:
    92  	}
    93  	return 0, 0, errors.New("command error")
    94  }