github.com/ydb-platform/ydb-go-sdk/v3@v3.89.2/dsn.go (about)

     1  package ydb
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"regexp"
     7  	"strings"
     8  
     9  	"github.com/ydb-platform/ydb-go-sdk/v3/balancers"
    10  	"github.com/ydb-platform/ydb-go-sdk/v3/credentials"
    11  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/bind"
    12  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/dsn"
    13  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xerrors"
    14  	"github.com/ydb-platform/ydb-go-sdk/v3/internal/xsql"
    15  )
    16  
    17  const tablePathPrefixTransformer = "table_path_prefix"
    18  
    19  var dsnParsers = []func(dsn string) (opts []Option, _ error){
    20  	func(dsn string) ([]Option, error) {
    21  		opts, err := parseConnectionString(dsn)
    22  		if err != nil {
    23  			return nil, xerrors.WithStackTrace(err)
    24  		}
    25  
    26  		return opts, nil
    27  	},
    28  }
    29  
    30  // RegisterDsnParser registers DSN parser for ydb.Open and sql.Open driver constructors
    31  //
    32  // Experimental: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#experimental
    33  func RegisterDsnParser(parser func(dsn string) (opts []Option, _ error)) (registrationID int) {
    34  	dsnParsers = append(dsnParsers, parser)
    35  
    36  	return len(dsnParsers) - 1
    37  }
    38  
    39  // UnregisterDsnParser unregisters DSN parser by key
    40  //
    41  // Experimental: https://github.com/ydb-platform/ydb-go-sdk/blob/master/VERSIONING.md#experimental
    42  func UnregisterDsnParser(registrationID int) {
    43  	dsnParsers[registrationID] = nil
    44  }
    45  
    46  //nolint:funlen
    47  func parseConnectionString(dataSourceName string) (opts []Option, _ error) {
    48  	info, err := dsn.Parse(dataSourceName)
    49  	if err != nil {
    50  		return nil, xerrors.WithStackTrace(err)
    51  	}
    52  	opts = append(opts, With(info.Options...))
    53  	if token := info.Params.Get("token"); token != "" {
    54  		opts = append(opts, WithCredentials(credentials.NewAccessTokenCredentials(token)))
    55  	}
    56  	if balancer := info.Params.Get("go_balancer"); balancer != "" {
    57  		opts = append(opts, WithBalancer(balancers.FromConfig(balancer)))
    58  	} else if balancer := info.Params.Get("balancer"); balancer != "" {
    59  		opts = append(opts, WithBalancer(balancers.FromConfig(balancer)))
    60  	}
    61  	if queryMode := info.Params.Get("go_query_mode"); queryMode != "" {
    62  		mode := xsql.QueryModeFromString(queryMode)
    63  		if mode == xsql.UnknownQueryMode {
    64  			return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode))
    65  		}
    66  		opts = append(opts, withConnectorOptions(xsql.WithDefaultQueryMode(mode)))
    67  	} else if queryMode := info.Params.Get("query_mode"); queryMode != "" {
    68  		mode := xsql.QueryModeFromString(queryMode)
    69  		if mode == xsql.UnknownQueryMode {
    70  			return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode))
    71  		}
    72  		opts = append(opts, withConnectorOptions(xsql.WithDefaultQueryMode(mode)))
    73  	}
    74  	if fakeTx := info.Params.Get("go_fake_tx"); fakeTx != "" {
    75  		for _, queryMode := range strings.Split(fakeTx, ",") {
    76  			mode := xsql.QueryModeFromString(queryMode)
    77  			if mode == xsql.UnknownQueryMode {
    78  				return nil, xerrors.WithStackTrace(fmt.Errorf("unknown query mode: %s", queryMode))
    79  			}
    80  			opts = append(opts, withConnectorOptions(xsql.WithFakeTx(mode)))
    81  		}
    82  	}
    83  	if info.Params.Has("go_query_bind") {
    84  		var binders []xsql.ConnectorOption
    85  		queryTransformers := strings.Split(info.Params.Get("go_query_bind"), ",")
    86  		for _, transformer := range queryTransformers {
    87  			switch transformer {
    88  			case "declare":
    89  				binders = append(binders, xsql.WithQueryBind(bind.AutoDeclare{}))
    90  			case "positional":
    91  				binders = append(binders, xsql.WithQueryBind(bind.PositionalArgs{}))
    92  			case "numeric":
    93  				binders = append(binders, xsql.WithQueryBind(bind.NumericArgs{}))
    94  			default:
    95  				if strings.HasPrefix(transformer, tablePathPrefixTransformer) {
    96  					prefix, err := extractTablePathPrefixFromBinderName(transformer)
    97  					if err != nil {
    98  						return nil, xerrors.WithStackTrace(err)
    99  					}
   100  					binders = append(binders, xsql.WithTablePathPrefix(prefix))
   101  				} else {
   102  					return nil, xerrors.WithStackTrace(
   103  						fmt.Errorf("unknown query rewriter: %s", transformer),
   104  					)
   105  				}
   106  			}
   107  		}
   108  		opts = append(opts, withConnectorOptions(binders...))
   109  	}
   110  
   111  	return opts, nil
   112  }
   113  
   114  var (
   115  	tablePathPrefixRe       = regexp.MustCompile(tablePathPrefixTransformer + "\\((.*)\\)")
   116  	errWrongTablePathPrefix = errors.New("wrong '" + tablePathPrefixTransformer + "' query transformer")
   117  )
   118  
   119  func extractTablePathPrefixFromBinderName(binderName string) (string, error) {
   120  	ss := tablePathPrefixRe.FindAllStringSubmatch(binderName, -1)
   121  	if len(ss) != 1 || len(ss[0]) != 2 || ss[0][1] == "" {
   122  		return "", xerrors.WithStackTrace(fmt.Errorf("%w: %s", errWrongTablePathPrefix, binderName))
   123  	}
   124  
   125  	return ss[0][1], nil
   126  }