github.com/dolthub/dolt/go@v0.40.5-0.20240520175717-68db7794bea6/libraries/doltcore/sqle/dprocedures/dolt_add.go (about)

     1  // Copyright 2022 Dolthub, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package dprocedures
    16  
    17  import (
    18  	"fmt"
    19  	"strings"
    20  
    21  	"github.com/dolthub/dolt/go/cmd/dolt/cli"
    22  	"github.com/dolthub/dolt/go/libraries/doltcore/branch_control"
    23  	"github.com/dolthub/dolt/go/libraries/doltcore/diff"
    24  	"github.com/dolthub/dolt/go/libraries/doltcore/env/actions"
    25  	"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
    26  
    27  	"github.com/dolthub/go-mysql-server/sql"
    28  )
    29  
    30  // doltAdd is the stored procedure version for the CLI command `dolt add`.
    31  func doltAdd(ctx *sql.Context, args ...string) (sql.RowIter, error) {
    32  	res, err := doDoltAdd(ctx, args)
    33  	if err != nil {
    34  		return nil, err
    35  	}
    36  	return rowToIter(int64(res)), nil
    37  }
    38  
    39  func doDoltAdd(ctx *sql.Context, args []string) (int, error) {
    40  	dbName := ctx.GetCurrentDatabase()
    41  
    42  	if len(dbName) == 0 {
    43  		return 1, fmt.Errorf("Empty database name.")
    44  	}
    45  	if err := branch_control.CheckAccess(ctx, branch_control.Permissions_Write); err != nil {
    46  		return 1, err
    47  	}
    48  
    49  	apr, err := cli.CreateAddArgParser().Parse(args)
    50  	if err != nil {
    51  		return 1, err
    52  	}
    53  
    54  	allFlag := apr.Contains(cli.AllFlag)
    55  
    56  	dSess := dsess.DSessFromSess(ctx.Session)
    57  	roots, ok := dSess.GetRoots(ctx, dbName)
    58  	if apr.NArg() == 0 && !allFlag {
    59  		return 1, fmt.Errorf("Nothing specified, nothing added. Maybe you wanted to say 'dolt add .'?")
    60  	} else if allFlag || apr.NArg() == 1 && apr.Arg(0) == "." {
    61  		if !ok {
    62  			return 1, fmt.Errorf("db session not found")
    63  		}
    64  
    65  		roots, err = actions.StageAllTables(ctx, roots, !apr.Contains(cli.ForceFlag))
    66  		if err != nil {
    67  			return 1, err
    68  		}
    69  
    70  		roots, err = actions.StageDatabase(ctx, roots, !apr.Contains(cli.ForceFlag))
    71  		if err != nil {
    72  			return 1, err
    73  		}
    74  
    75  		err = dSess.SetRoots(ctx, dbName, roots)
    76  		if err != nil {
    77  			return 1, err
    78  		}
    79  	} else {
    80  		// special case to handle __DATABASE__<db>
    81  		for i, arg := range apr.Args {
    82  			if !strings.HasPrefix(arg, diff.DBPrefix) {
    83  				continue
    84  			}
    85  			// remove from slice
    86  			apr.Args = append(apr.Args[:i], apr.Args[i+1:]...)
    87  			roots, err = actions.StageDatabase(ctx, roots, !apr.Contains(cli.ForceFlag))
    88  			if err != nil {
    89  				return 1, err
    90  			}
    91  		}
    92  
    93  		roots, err = actions.StageTables(ctx, roots, apr.Args, !apr.Contains(cli.ForceFlag))
    94  		if err != nil {
    95  			return 1, err
    96  		}
    97  
    98  		err = dSess.SetRoots(ctx, dbName, roots)
    99  		if err != nil {
   100  			return 1, err
   101  		}
   102  	}
   103  
   104  	return 0, nil
   105  }