github.com/dannyzhou2015/migrate/v4@v4.15.2/source/github/github.go (about)

     1  package github
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"golang.org/x/oauth2"
     7  	"io"
     8  	"io/ioutil"
     9  	"net/http"
    10  	nurl "net/url"
    11  	"os"
    12  	"path"
    13  	"strings"
    14  
    15  	"github.com/dannyzhou2015/migrate/v4/source"
    16  	"github.com/google/go-github/v39/github"
    17  )
    18  
    19  func init() {
    20  	source.Register("github", &Github{})
    21  }
    22  
    23  var (
    24  	ErrNoUserInfo          = fmt.Errorf("no username:token provided")
    25  	ErrNoAccessToken       = fmt.Errorf("no access token")
    26  	ErrInvalidRepo         = fmt.Errorf("invalid repo")
    27  	ErrInvalidGithubClient = fmt.Errorf("expected *github.Client")
    28  	ErrNoDir               = fmt.Errorf("no directory")
    29  )
    30  
    31  type Github struct {
    32  	config     *Config
    33  	client     *github.Client
    34  	options    *github.RepositoryContentGetOptions
    35  	migrations *source.Migrations
    36  }
    37  
    38  type Config struct {
    39  	Owner string
    40  	Repo  string
    41  	Path  string
    42  	Ref   string
    43  }
    44  
    45  func (g *Github) Open(url string) (source.Driver, error) {
    46  	u, err := nurl.Parse(url)
    47  	if err != nil {
    48  		return nil, err
    49  	}
    50  
    51  	// client defaults to http.DefaultClient
    52  	var client *http.Client
    53  	if u.User != nil {
    54  		password, ok := u.User.Password()
    55  		if !ok {
    56  			return nil, ErrNoUserInfo
    57  		}
    58  		ts := oauth2.StaticTokenSource(
    59  			&oauth2.Token{AccessToken: password},
    60  		)
    61  		client = oauth2.NewClient(context.Background(), ts)
    62  
    63  	}
    64  
    65  	gn := &Github{
    66  		client:     github.NewClient(client),
    67  		migrations: source.NewMigrations(),
    68  		options:    &github.RepositoryContentGetOptions{Ref: u.Fragment},
    69  	}
    70  
    71  	gn.ensureFields()
    72  
    73  	// set owner, repo and path in repo
    74  	gn.config.Owner = u.Host
    75  	pe := strings.Split(strings.Trim(u.Path, "/"), "/")
    76  	if len(pe) < 1 {
    77  		return nil, ErrInvalidRepo
    78  	}
    79  	gn.config.Repo = pe[0]
    80  	if len(pe) > 1 {
    81  		gn.config.Path = strings.Join(pe[1:], "/")
    82  	}
    83  
    84  	if err := gn.readDirectory(); err != nil {
    85  		return nil, err
    86  	}
    87  
    88  	return gn, nil
    89  }
    90  
    91  func WithInstance(client *github.Client, config *Config) (source.Driver, error) {
    92  	gn := &Github{
    93  		client:     client,
    94  		config:     config,
    95  		migrations: source.NewMigrations(),
    96  		options:    &github.RepositoryContentGetOptions{Ref: config.Ref},
    97  	}
    98  
    99  	if err := gn.readDirectory(); err != nil {
   100  		return nil, err
   101  	}
   102  
   103  	return gn, nil
   104  }
   105  
   106  func (g *Github) readDirectory() error {
   107  	g.ensureFields()
   108  
   109  	fileContent, dirContents, _, err := g.client.Repositories.GetContents(
   110  		context.Background(),
   111  		g.config.Owner,
   112  		g.config.Repo,
   113  		g.config.Path,
   114  		g.options,
   115  	)
   116  
   117  	if err != nil {
   118  		return err
   119  	}
   120  	if fileContent != nil {
   121  		return ErrNoDir
   122  	}
   123  
   124  	for _, fi := range dirContents {
   125  		m, err := source.DefaultParse(*fi.Name)
   126  		if err != nil {
   127  			continue // ignore files that we can't parse
   128  		}
   129  		if !g.migrations.Append(m) {
   130  			return fmt.Errorf("unable to parse file %v", *fi.Name)
   131  		}
   132  	}
   133  
   134  	return nil
   135  }
   136  
   137  func (g *Github) ensureFields() {
   138  	if g.config == nil {
   139  		g.config = &Config{}
   140  	}
   141  }
   142  
   143  func (g *Github) Close() error {
   144  	return nil
   145  }
   146  
   147  func (g *Github) First() (version uint, err error) {
   148  	g.ensureFields()
   149  
   150  	if v, ok := g.migrations.First(); !ok {
   151  		return 0, &os.PathError{Op: "first", Path: g.config.Path, Err: os.ErrNotExist}
   152  	} else {
   153  		return v, nil
   154  	}
   155  }
   156  
   157  func (g *Github) Prev(version uint) (prevVersion uint, err error) {
   158  	g.ensureFields()
   159  
   160  	if v, ok := g.migrations.Prev(version); !ok {
   161  		return 0, &os.PathError{Op: fmt.Sprintf("prev for version %v", version), Path: g.config.Path, Err: os.ErrNotExist}
   162  	} else {
   163  		return v, nil
   164  	}
   165  }
   166  
   167  func (g *Github) Next(version uint) (nextVersion uint, err error) {
   168  	g.ensureFields()
   169  
   170  	if v, ok := g.migrations.Next(version); !ok {
   171  		return 0, &os.PathError{Op: fmt.Sprintf("next for version %v", version), Path: g.config.Path, Err: os.ErrNotExist}
   172  	} else {
   173  		return v, nil
   174  	}
   175  }
   176  
   177  func (g *Github) ReadUp(version uint) (r io.ReadCloser, identifier string, err error) {
   178  	g.ensureFields()
   179  
   180  	if m, ok := g.migrations.Up(version); ok {
   181  		file, _, _, err := g.client.Repositories.GetContents(
   182  			context.Background(),
   183  			g.config.Owner,
   184  			g.config.Repo,
   185  			path.Join(g.config.Path, m.Raw),
   186  			g.options,
   187  		)
   188  
   189  		if err != nil {
   190  			return nil, "", err
   191  		}
   192  		if file != nil {
   193  			r, err := file.GetContent()
   194  			if err != nil {
   195  				return nil, "", err
   196  			}
   197  			return ioutil.NopCloser(strings.NewReader(r)), m.Identifier, nil
   198  		}
   199  	}
   200  	return nil, "", &os.PathError{Op: fmt.Sprintf("read version %v", version), Path: g.config.Path, Err: os.ErrNotExist}
   201  }
   202  
   203  func (g *Github) ReadDown(version uint) (r io.ReadCloser, identifier string, err error) {
   204  	g.ensureFields()
   205  
   206  	if m, ok := g.migrations.Down(version); ok {
   207  		file, _, _, err := g.client.Repositories.GetContents(
   208  			context.Background(),
   209  			g.config.Owner,
   210  			g.config.Repo,
   211  			path.Join(g.config.Path, m.Raw),
   212  			g.options,
   213  		)
   214  
   215  		if err != nil {
   216  			return nil, "", err
   217  		}
   218  		if file != nil {
   219  			r, err := file.GetContent()
   220  			if err != nil {
   221  				return nil, "", err
   222  			}
   223  			return ioutil.NopCloser(strings.NewReader(r)), m.Identifier, nil
   224  		}
   225  	}
   226  	return nil, "", &os.PathError{Op: fmt.Sprintf("read version %v", version), Path: g.config.Path, Err: os.ErrNotExist}
   227  }