github.com/Jeffail/benthos/v3@v3.65.0/lib/processor/sql.go (about)

     1  package processor
     2  
     3  import (
     4  	"database/sql"
     5  	"errors"
     6  	"fmt"
     7  	"strings"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/Jeffail/benthos/v3/internal/bloblang/field"
    12  	"github.com/Jeffail/benthos/v3/internal/bloblang/mapping"
    13  	"github.com/Jeffail/benthos/v3/internal/docs"
    14  	"github.com/Jeffail/benthos/v3/internal/interop"
    15  	"github.com/Jeffail/benthos/v3/internal/tracing"
    16  	"github.com/Jeffail/benthos/v3/lib/log"
    17  	"github.com/Jeffail/benthos/v3/lib/metrics"
    18  	"github.com/Jeffail/benthos/v3/lib/types"
    19  
    20  	// SQL Drivers
    21  	_ "github.com/ClickHouse/clickhouse-go"
    22  	_ "github.com/denisenkom/go-mssqldb"
    23  	_ "github.com/go-sql-driver/mysql"
    24  )
    25  
    26  //------------------------------------------------------------------------------
    27  
    28  func init() {
    29  	Constructors[TypeSQL] = TypeSpec{
    30  		constructor: NewSQL,
    31  		Categories: []Category{
    32  			CategoryIntegration,
    33  		},
    34  		Status: docs.StatusStable,
    35  		Summary: `
    36  Runs an SQL prepared query against a target database for each message and, for
    37  queries that return rows, replaces it with the result according to a
    38  [codec](#result-codecs).`,
    39  		Description: `
    40  ## Alternatives
    41  
    42  For basic inserts or select queries use use either the ` + "[`sql_insert`](/docs/components/processors/sql_insert)" + ` or the ` + "[`sql_select`](/docs/components/processors/sql_select)" + ` processor.
    43  
    44  For more complex queries use the ` + "[`sql_raw`](/docs/components/processors/sql_raw)" + ` processor.`,
    45  		Examples: []docs.AnnotatedExample{
    46  			{
    47  				Title: "Table Insert (MySQL)",
    48  				Summary: `
    49  The following example inserts rows into the table footable with the columns foo,
    50  bar and baz populated with values extracted from messages:`,
    51  				Config: `
    52  pipeline:
    53    processors:
    54      - sql:
    55          driver: mysql
    56          data_source_name: foouser:foopassword@tcp(localhost:3306)/foodb
    57          query: "INSERT INTO footable (foo, bar, baz) VALUES (?, ?, ?);"
    58          args_mapping: '[ document.foo, document.bar, meta("kafka_topic") ]'
    59  `,
    60  			},
    61  			{
    62  				Title: "Table Query (PostgreSQL)",
    63  				Summary: `
    64  Here we query a database for columns of footable that share a ` + "`user_id`" + `
    65  with the message ` + "`user.id`" + `. The ` + "`result_codec`" + ` is set to
    66  ` + "`json_array`" + ` and a ` + "[`branch` processor](/docs/components/processors/branch)" + `
    67  is used in order to insert the resulting array into the original message at the
    68  path ` + "`foo_rows`" + `:`,
    69  				Config: `
    70  pipeline:
    71    processors:
    72      - branch:
    73          processors:
    74            - sql:
    75                driver: postgres
    76                result_codec: json_array
    77                data_source_name: postgres://foouser:foopass@localhost:5432/testdb?sslmode=disable
    78                query: "SELECT * FROM footable WHERE user_id = $1;"
    79                args_mapping: '[ this.user.id ]'
    80          result_map: 'root.foo_rows = this'
    81  `,
    82  			},
    83  		},
    84  		FieldSpecs: docs.FieldSpecs{
    85  			docs.FieldCommon(
    86  				"driver",
    87  				"A database [driver](#drivers) to use.",
    88  			).HasOptions("mysql", "postgres", "clickhouse", "mssql"),
    89  			docs.FieldCommon(
    90  				"data_source_name", "A Data Source Name to identify the target database.",
    91  				"tcp://host1:9000?username=user&password=qwerty&database=clicks&read_timeout=10&write_timeout=20&alt_hosts=host2:9000,host3:9000",
    92  				"foouser:foopassword@tcp(localhost:3306)/foodb",
    93  				"postgres://foouser:foopass@localhost:5432/foodb?sslmode=disable",
    94  			),
    95  			docs.FieldDeprecated("dsn", ""),
    96  			docs.FieldCommon(
    97  				"query", "The query to run against the database.",
    98  				"INSERT INTO footable (foo, bar, baz) VALUES (?, ?, ?);",
    99  			),
   100  			docs.FieldBool(
   101  				"unsafe_dynamic_query",
   102  				"Whether to enable dynamic queries that support interpolation functions. WARNING: This feature opens up the possibility of SQL injection attacks and is considered unsafe.",
   103  			).Advanced().HasDefault(false),
   104  			docs.FieldDeprecated(
   105  				"args",
   106  				"A list of arguments for the query to be resolved for each message.",
   107  			).IsInterpolated().Array(),
   108  			docs.FieldBloblang(
   109  				"args_mapping",
   110  				"A [Bloblang mapping](/docs/guides/bloblang/about) that produces the arguments for the query. The mapping must return an array containing the number of arguments in the query.",
   111  				`[ this.foo, this.bar.not_empty().catch(null), meta("baz") ]`,
   112  				`root = [ uuid_v4() ].merge(this.document.args)`,
   113  			).AtVersion("3.47.0"),
   114  			docs.FieldCommon(
   115  				"result_codec",
   116  				"A [codec](#result-codecs) to determine how resulting rows are converted into messages.",
   117  			).HasOptions("none", "json_array"),
   118  		},
   119  		Footnotes: `
   120  ## Result Codecs
   121  
   122  When a query returns rows they are serialised according to a chosen codec, and
   123  the message contents are replaced with the serialised result.
   124  
   125  ### ` + "`none`" + `
   126  
   127  The result of the query is ignored and the message remains unchanged. If your
   128  query does not return rows then this is the appropriate codec.
   129  
   130  ### ` + "`json_array`" + `
   131  
   132  The resulting rows are serialised into an array of JSON objects, where each
   133  object represents a row, where the key is the column name and the value is that
   134  columns value in the row.`,
   135  	}
   136  }
   137  
   138  //------------------------------------------------------------------------------
   139  
// SQLConfig contains configuration fields for the SQL processor.
//
// Fields:
//   - Driver: the database/sql driver name handed to sql.Open.
//   - DataSourceName: the connection string identifying the target database.
//   - DSN: deprecated alternative to DataSourceName. Setting it switches the
//     processor into its deprecated mode, and it cannot be combined with
//     DataSourceName.
//   - Query: the query that is prepared (or, with UnsafeDynamicQuery,
//     interpolated) and run for each message.
//   - UnsafeDynamicQuery: when true, Query is parsed as an interpolated field
//     and resolved per message, which opens the door to SQL injection.
//   - Args: deprecated interpolated expressions, each resolved to a string
//     argument per message. Mutually exclusive with ArgsMapping.
//   - ArgsMapping: a Bloblang mapping that must produce an array of query
//     arguments per message. Not available in deprecated mode.
//   - ResultCodec: how returned rows are written back into messages, either
//     "none" or "json_array" outside of deprecated mode.
type SQLConfig struct {
	Driver             string   `json:"driver" yaml:"driver"`
	DataSourceName     string   `json:"data_source_name" yaml:"data_source_name"`
	DSN                string   `json:"dsn" yaml:"dsn"`
	Query              string   `json:"query" yaml:"query"`
	UnsafeDynamicQuery bool     `json:"unsafe_dynamic_query" yaml:"unsafe_dynamic_query"`
	Args               []string `json:"args" yaml:"args"`
	ArgsMapping        string   `json:"args_mapping" yaml:"args_mapping"`
	ResultCodec        string   `json:"result_codec" yaml:"result_codec"`
}
   151  
   152  // NewSQLConfig returns a SQLConfig with default values.
   153  func NewSQLConfig() SQLConfig {
   154  	return SQLConfig{
   155  		Driver:             "mysql",
   156  		DataSourceName:     "",
   157  		DSN:                "",
   158  		Query:              "",
   159  		UnsafeDynamicQuery: false,
   160  		Args:               []string{},
   161  		ArgsMapping:        "",
   162  		ResultCodec:        "none",
   163  	}
   164  }
   165  
   166  //------------------------------------------------------------------------------
   167  
   168  // Some SQL drivers (such as clickhouse) require prepared inserts to be local to
   169  // a transaction, rather than general.
   170  func insertRequiresTransactionPrepare(driver string) bool {
   171  	_, exists := map[string]struct{}{
   172  		"clickhouse": {},
   173  	}[driver]
   174  	return exists
   175  }
   176  
   177  //------------------------------------------------------------------------------
   178  
// SQL is a processor that executes an SQL query for each message.
//
// Instances are created with NewSQL, which starts a goroutine that releases
// the database handle and prepared statement once CloseAsync is called.
type SQL struct {
	log   log.Modular
	stats metrics.Type

	// conf is the processor config and db the handle opened by NewSQL. dbMux
	// is read-locked for the duration of ProcessMessage and write-locked while
	// db and query are closed, so closing never races an in-flight message.
	// args (deprecated `args`) and argsMapping (`args_mapping`) resolve query
	// arguments per message; NewSQL rejects configs that set both. resCodec
	// writes returned rows into messages and is nil for the "none" codec.
	conf        SQLConfig
	db          *sql.DB
	dbMux       sync.RWMutex
	args        []*field.Expression
	argsMapping *mapping.Executor
	resCodec    sqlResultCodec

	// TODO: V4 Remove this
	// deprecated is set when the legacy `dsn` field is used, in which case
	// ProcessMessage delegates to processMessageDeprecated and the codec is
	// held in resCodecDeprecated instead of resCodec.
	deprecated         bool
	resCodecDeprecated sqlResultCodecDeprecated

	// queryStr is the configured query text. dynQuery is non-nil only when
	// unsafe_dynamic_query is enabled. query is the statement prepared on db
	// by NewSQL. It is nil for dynamic queries, and for non-select queries on
	// drivers that need transaction-local prepares (see
	// insertRequiresTransactionPrepare).
	queryStr string
	dynQuery *field.Expression
	query    *sql.Stmt

	// closeChan is closed (at most once, guarded by closeOnce) by CloseAsync.
	// closedChan is closed by the cleanup goroutine after db and query have
	// been closed.
	closeChan  chan struct{}
	closedChan chan struct{}
	closeOnce  sync.Once

	// In the non-deprecated path, mCount and mBatchSent are incremented once
	// per batch, mSent by the number of messages in it, and mErr per failure.
	mCount     metrics.StatCounter
	mErr       metrics.StatCounter
	mSent      metrics.StatCounter
	mBatchSent metrics.StatCounter
}
   208  
   209  // NewSQL returns a SQL processor.
   210  func NewSQL(
   211  	conf Config, mgr types.Manager, log log.Modular, stats metrics.Type,
   212  ) (Type, error) {
   213  	deprecated := false
   214  	dsn := conf.SQL.DataSourceName
   215  	if len(conf.SQL.DSN) > 0 {
   216  		if len(dsn) > 0 {
   217  			return nil, errors.New("specified both a deprecated `dsn` as well as a `data_source_name`")
   218  		}
   219  		dsn = conf.SQL.DSN
   220  		deprecated = true
   221  	}
   222  
   223  	if len(conf.SQL.Args) > 0 && conf.SQL.ArgsMapping != "" {
   224  		return nil, errors.New("cannot specify both `args` and an `args_mapping` in the same processor")
   225  	}
   226  
   227  	var argsMapping *mapping.Executor
   228  	if conf.SQL.ArgsMapping != "" {
   229  		if deprecated {
   230  			return nil, errors.New("the field `args_mapping` cannot be used when running the `sql` processor in deprecated mode (using the `dsn` field), use the `data_source_name` field instead")
   231  		}
   232  		log.Warnln("using unsafe_dynamic_query leaves you vulnerable to SQL injection attacks")
   233  		var err error
   234  		if argsMapping, err = interop.NewBloblangMapping(mgr, conf.SQL.ArgsMapping); err != nil {
   235  			return nil, fmt.Errorf("failed to parse `args_mapping`: %w", err)
   236  		}
   237  	}
   238  
   239  	var args []*field.Expression
   240  	for i, v := range conf.SQL.Args {
   241  		expr, err := interop.NewBloblangField(mgr, v)
   242  		if err != nil {
   243  			return nil, fmt.Errorf("failed to parse arg %v expression: %v", i, err)
   244  		}
   245  		args = append(args, expr)
   246  	}
   247  
   248  	if conf.SQL.Driver == "mssql" {
   249  		// For MSSQL, if the user part of the connection string is in the
   250  		// `DOMAIN\username` format, then the backslash character needs to be
   251  		// URL-encoded.
   252  		conf.SQL.DataSourceName = strings.ReplaceAll(conf.SQL.DataSourceName, `\`, "%5C")
   253  	}
   254  
   255  	s := &SQL{
   256  		log:         log,
   257  		stats:       stats,
   258  		conf:        conf.SQL,
   259  		args:        args,
   260  		argsMapping: argsMapping,
   261  
   262  		queryStr: conf.SQL.Query,
   263  
   264  		deprecated: deprecated,
   265  		closeChan:  make(chan struct{}),
   266  		closedChan: make(chan struct{}),
   267  		mCount:     stats.GetCounter("count"),
   268  		mErr:       stats.GetCounter("error"),
   269  		mSent:      stats.GetCounter("sent"),
   270  		mBatchSent: stats.GetCounter("batch.sent"),
   271  	}
   272  
   273  	var err error
   274  	if deprecated {
   275  		s.log.Warnln("Using deprecated SQL functionality due to use of field 'dsn'. To switch to the new processor use the field 'data_source_name' instead. The new processor is not backwards compatible due to differences in how message batches are processed. For more information check out the docs at https://www.benthos.dev/docs/components/processors/sql.")
   276  		if conf.SQL.Driver != "mysql" && conf.SQL.Driver != "postgres" && conf.SQL.Driver != "mssql" {
   277  			return nil, fmt.Errorf("driver '%v' is not supported with deprecated SQL features (using field 'dsn')", conf.SQL.Driver)
   278  		}
   279  		if s.resCodecDeprecated, err = strToSQLResultCodecDeprecated(conf.SQL.ResultCodec); err != nil {
   280  			return nil, err
   281  		}
   282  	} else if s.resCodec, err = strToSQLResultCodec(conf.SQL.ResultCodec); err != nil {
   283  		return nil, err
   284  	}
   285  
   286  	if s.db, err = sql.Open(conf.SQL.Driver, dsn); err != nil {
   287  		return nil, err
   288  	}
   289  
   290  	if conf.SQL.UnsafeDynamicQuery {
   291  		if deprecated {
   292  			return nil, errors.New("cannot use dynamic queries when running in deprecated mode")
   293  		}
   294  		if s.dynQuery, err = interop.NewBloblangField(mgr, s.queryStr); err != nil {
   295  			return nil, fmt.Errorf("failed to parse dynamic query expression: %v", err)
   296  		}
   297  	}
   298  
   299  	isSelectQuery := s.resCodecDeprecated != nil || s.resCodec != nil
   300  
   301  	// Some drivers only support transactional prepared inserts.
   302  	if s.dynQuery == nil && (isSelectQuery || !insertRequiresTransactionPrepare(conf.SQL.Driver)) {
   303  		if s.query, err = s.db.Prepare(s.queryStr); err != nil {
   304  			s.db.Close()
   305  			return nil, fmt.Errorf("failed to prepare query: %v", err)
   306  		}
   307  	}
   308  
   309  	go func() {
   310  		defer func() {
   311  			s.dbMux.Lock()
   312  			s.db.Close()
   313  			if s.query != nil {
   314  				s.query.Close()
   315  			}
   316  			s.dbMux.Unlock()
   317  			close(s.closedChan)
   318  		}()
   319  		<-s.closeChan
   320  	}()
   321  	return s, nil
   322  }
   323  
   324  //------------------------------------------------------------------------------
   325  
   326  type sqlResultCodec func(rows *sql.Rows, part types.Part) error
   327  
   328  func sqlResultJSONArrayCodec(rows *sql.Rows, part types.Part) error {
   329  	columnNames, err := rows.Columns()
   330  	if err != nil {
   331  		return err
   332  	}
   333  	jArray := []interface{}{}
   334  	for rows.Next() {
   335  		values := make([]interface{}, len(columnNames))
   336  		valuesWrapped := make([]interface{}, len(columnNames))
   337  		for i := range values {
   338  			valuesWrapped[i] = &values[i]
   339  		}
   340  		if err := rows.Scan(valuesWrapped...); err != nil {
   341  			return err
   342  		}
   343  		jObj := map[string]interface{}{}
   344  		for i, v := range values {
   345  			switch t := v.(type) {
   346  			case string:
   347  				jObj[columnNames[i]] = t
   348  			case []byte:
   349  				jObj[columnNames[i]] = string(t)
   350  			case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
   351  				jObj[columnNames[i]] = t
   352  			case float32, float64:
   353  				jObj[columnNames[i]] = t
   354  			case bool:
   355  				jObj[columnNames[i]] = t
   356  			default:
   357  				jObj[columnNames[i]] = t
   358  			}
   359  		}
   360  		jArray = append(jArray, jObj)
   361  	}
   362  	if err := rows.Err(); err != nil {
   363  		return err
   364  	}
   365  	return part.SetJSON(jArray)
   366  }
   367  
   368  func strToSQLResultCodec(codec string) (sqlResultCodec, error) {
   369  	switch codec {
   370  	case "json_array":
   371  		return sqlResultJSONArrayCodec, nil
   372  	case "none":
   373  		return nil, nil
   374  	}
   375  	return nil, fmt.Errorf("unrecognised result codec: %v", codec)
   376  }
   377  
   378  //------------------------------------------------------------------------------
   379  
   380  func (s *SQL) doExecute(argSets [][]interface{}) (errs []error) {
   381  	var err error
   382  	defer func() {
   383  		if err != nil {
   384  			if len(errs) == 0 {
   385  				errs = make([]error, len(argSets))
   386  			}
   387  			for i := range errs {
   388  				if errs[i] == nil {
   389  					errs[i] = err
   390  				}
   391  			}
   392  		}
   393  	}()
   394  
   395  	var tx *sql.Tx
   396  	if tx, err = s.db.Begin(); err != nil {
   397  		return
   398  	}
   399  
   400  	stmt := s.query
   401  	if stmt == nil {
   402  		if stmt, err = tx.Prepare(s.queryStr); err != nil {
   403  			return
   404  		}
   405  		defer stmt.Close()
   406  	} else {
   407  		stmt = tx.Stmt(stmt)
   408  	}
   409  
   410  	for i, args := range argSets {
   411  		if len(args) == 0 {
   412  			continue
   413  		}
   414  		if _, serr := stmt.Exec(args...); serr != nil {
   415  			if len(errs) == 0 {
   416  				errs = make([]error, len(argSets))
   417  			}
   418  			errs[i] = serr
   419  		}
   420  	}
   421  
   422  	err = tx.Commit()
   423  	return
   424  }
   425  
   426  func (s *SQL) getArgs(index int, msg types.Message) ([]interface{}, error) {
   427  	if len(s.args) > 0 {
   428  		args := make([]interface{}, len(s.args))
   429  		for i, v := range s.args {
   430  			args[i] = v.String(index, msg)
   431  		}
   432  		return args, nil
   433  	}
   434  
   435  	if s.argsMapping == nil {
   436  		return nil, nil
   437  	}
   438  
   439  	pargs, err := s.argsMapping.MapPart(index, msg)
   440  	if err != nil {
   441  		return nil, err
   442  	}
   443  
   444  	iargs, err := pargs.JSON()
   445  	if err != nil {
   446  		return nil, fmt.Errorf("mapping returned non-structured result: %w", err)
   447  	}
   448  
   449  	args, ok := iargs.([]interface{})
   450  	if !ok {
   451  		return nil, fmt.Errorf("mapping returned non-array result: %T", iargs)
   452  	}
   453  	return args, nil
   454  }
   455  
   456  // ProcessMessage logs an event and returns the message unchanged.
   457  func (s *SQL) ProcessMessage(msg types.Message) ([]types.Message, types.Response) {
   458  	s.dbMux.RLock()
   459  	defer s.dbMux.RUnlock()
   460  
   461  	if s.deprecated {
   462  		return s.processMessageDeprecated(msg)
   463  	}
   464  
   465  	s.mCount.Incr(1)
   466  	newMsg := msg.Copy()
   467  
   468  	if s.resCodec == nil && s.dynQuery == nil {
   469  		argSets := make([][]interface{}, newMsg.Len())
   470  		newMsg.Iter(func(index int, p types.Part) error {
   471  			args, err := s.getArgs(index, msg)
   472  			if err != nil {
   473  				s.mErr.Incr(1)
   474  				s.log.Errorf("Args mapping error: %v\n", err)
   475  				FlagErr(newMsg.Get(index), err)
   476  				return nil
   477  			}
   478  			argSets[index] = args
   479  			return nil
   480  		})
   481  
   482  		for i, err := range s.doExecute(argSets) {
   483  			if err != nil {
   484  				s.mErr.Incr(1)
   485  				s.log.Errorf("SQL error: %v\n", err)
   486  				FlagErr(newMsg.Get(i), err)
   487  			}
   488  		}
   489  	} else {
   490  		IteratePartsWithSpanV2(TypeSQL, nil, newMsg, func(index int, span *tracing.Span, part types.Part) error {
   491  			args, err := s.getArgs(index, msg)
   492  			if err != nil {
   493  				s.mErr.Incr(1)
   494  				s.log.Errorf("Args mapping error: %v\n", err)
   495  				return err
   496  			}
   497  
   498  			if s.resCodec == nil {
   499  				if s.dynQuery != nil {
   500  					queryStr := s.dynQuery.String(index, msg)
   501  					_, err = s.db.Exec(queryStr, args...)
   502  				} else {
   503  					_, err = s.query.Exec(args...)
   504  				}
   505  				if err != nil {
   506  					return fmt.Errorf("failed to execute query: %w", err)
   507  				}
   508  				return nil
   509  			}
   510  
   511  			var rows *sql.Rows
   512  			if s.dynQuery != nil {
   513  				queryStr := s.dynQuery.String(index, msg)
   514  				rows, err = s.db.Query(queryStr, args...)
   515  			} else {
   516  				rows, err = s.query.Query(args...)
   517  			}
   518  			if err == nil {
   519  				defer rows.Close()
   520  				if err = s.resCodec(rows, part); err != nil {
   521  					err = fmt.Errorf("failed to apply result codec: %v", err)
   522  				}
   523  			} else {
   524  				err = fmt.Errorf("failed to execute query: %v", err)
   525  			}
   526  			if err != nil {
   527  				s.mErr.Incr(1)
   528  				s.log.Errorf("SQL error: %v\n", err)
   529  				return err
   530  			}
   531  			return nil
   532  		})
   533  	}
   534  
   535  	s.mBatchSent.Incr(1)
   536  	s.mSent.Incr(int64(newMsg.Len()))
   537  	msgs := [1]types.Message{newMsg}
   538  	return msgs[:], nil
   539  }
   540  
   541  // CloseAsync shuts down the processor and stops processing requests.
   542  func (s *SQL) CloseAsync() {
   543  	s.closeOnce.Do(func() {
   544  		close(s.closeChan)
   545  	})
   546  }
   547  
   548  // WaitForClose blocks until the processor has closed down.
   549  func (s *SQL) WaitForClose(timeout time.Duration) error {
   550  	select {
   551  	case <-time.After(timeout):
   552  		return types.ErrTimeout
   553  	case <-s.closedChan:
   554  	}
   555  	return nil
   556  }
   557  
   558  //------------------------------------------------------------------------------