diff --git a/cmd/auth.go b/cmd/auth.go index f7876fc..2b618bf 100644 --- a/cmd/auth.go +++ b/cmd/auth.go @@ -25,6 +25,7 @@ import ( "github.com/spf13/viper" "go.uploadedlobster.com/scotty/internal/auth" "go.uploadedlobster.com/scotty/internal/backends" + "go.uploadedlobster.com/scotty/internal/cli" "go.uploadedlobster.com/scotty/internal/models" "go.uploadedlobster.com/scotty/internal/storage" "golang.org/x/oauth2" @@ -36,7 +37,7 @@ var authCmd = &cobra.Command{ Short: "Authenticate with a backend", Long: `For backends requiring authentication this command can be used to authenticate.`, Run: func(cmd *cobra.Command, args []string) { - serviceName, serviceConfig := getConfigFromFlag(cmd, "service") + serviceName, serviceConfig := cli.GetConfigFromFlag(cmd, "service") backend, err := backends.ResolveBackend[models.OAuth2Authenticator](serviceConfig) cobra.CheckErr(err) diff --git a/cmd/common.go b/cmd/common.go deleted file mode 100644 index cee8c16..0000000 --- a/cmd/common.go +++ /dev/null @@ -1,161 +0,0 @@ -/* -Copyright © 2023 Philipp Wolfer - -This file is part of Scotty. - -Scotty is free software: you can redistribute it and/or modify it under the -terms of the GNU General Public License as published by the Free Software -Foundation, either version 3 of the License, or (at your option) any later version. - -Scotty is distributed in the hope that it will be useful, but WITHOUT ANY -WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR -A PARTICULAR PURPOSE. See the GNU General Public License for more details. - -You should have received a copy of the GNU General Public License along with -Scotty. If not, see . -*/ -package cmd - -import ( - "fmt" - "sync" - "time" - - "github.com/spf13/cobra" - "github.com/spf13/viper" - "go.uploadedlobster.com/scotty/internal/backends" - "go.uploadedlobster.com/scotty/internal/models" - "go.uploadedlobster.com/scotty/internal/storage" -) - -func getConfigFromFlag(cmd *cobra.Command, flagName string) (string, *viper.Viper) { - configName := cmd.Flag(flagName).Value.String() - var config *viper.Viper - servicesConfig := viper.Sub("service") - if servicesConfig != nil { - config = servicesConfig.Sub(configName) - } - if config == nil { - cobra.CheckErr(fmt.Sprintf("Invalid source configuration \"%s\"", configName)) - } - return configName, config -} - -func getInt64FromFlag(cmd *cobra.Command, flagName string) (result int64) { - result, err := cmd.Flags().GetInt64(flagName) - if err != nil { - result = 0 - } - return -} - -type backendInfo[T models.Backend, R models.ListensResult | models.LovesResult] struct { - configName string - backend T -} - -type exportBackendInfo[T models.Backend, R models.ListensResult | models.LovesResult] struct { - backendInfo[T, R] - processor backends.ExportProcessor[R] -} - -type importBackendInfo[T models.Backend, R models.ListensResult | models.LovesResult] struct { - backendInfo[T, R] - processor backends.ImportProcessor[R] -} - -func resolveBackends[E models.Backend, I models.ImportBackend, R models.ListensResult | models.LovesResult](cmd *cobra.Command) (*exportBackendInfo[E, R], *importBackendInfo[I, R], error) { - sourceName, sourceConfig := getConfigFromFlag(cmd, "from") - targetName, targetConfig := getConfigFromFlag(cmd, "to") - // Initialize backends - exportBackend, err := backends.ResolveBackend[E](sourceConfig) - if err != nil { - return nil, nil, err - } - importBackend, err := backends.ResolveBackend[I](targetConfig) - if err != nil { - return nil, nil, err - } - - exportInfo := exportBackendInfo[E, R]{ - backendInfo: backendInfo[E, R]{ - configName: sourceName, - backend: exportBackend, - }, - } - - importInfo := importBackendInfo[I, R]{ - backendInfo: backendInfo[I, R]{ - configName: targetName, - backend: importBackend, - }, - } - - return &exportInfo, &importInfo, nil -} - -func cmdExportImport[E models.Backend, I models.ImportBackend, R models.ListensResult | models.LovesResult](cmd *cobra.Command, entity string, exp *exportBackendInfo[E, R], imp *importBackendInfo[I, R]) { - sourceName := exp.configName - targetName := imp.configName - fmt.Printf("Transferring %s from %s to %s...\n", entity, sourceName, targetName) - - // Setup database - db, err := storage.New(viper.GetString("database")) - cobra.CheckErr(err) - - // Authenticate backends, if needed - config := viper.GetViper() - _, err = backends.Authenticate(sourceName, exp.backend, db, config) - cobra.CheckErr(err) - - _, err = backends.Authenticate(targetName, imp.backend, db, config) - cobra.CheckErr(err) - - // Read timestamp - timestamp := time.Unix(getInt64FromFlag(cmd, "timestamp"), 0) - if timestamp == time.Unix(0, 0) { - timestamp, err = db.GetImportTimestamp(sourceName, targetName, entity) - cobra.CheckErr(err) - } - fmt.Printf("From timestamp: %v (%v)\n", timestamp, timestamp.Unix()) - - // Prepare progress bars - exportProgress := make(chan models.Progress) - importProgress := make(chan models.Progress) - var wg sync.WaitGroup - progress := progressBar(&wg, exportProgress, importProgress) - - // Export from source - exportChan := make(chan R, 1000) - go exp.processor.Process(timestamp, exportChan, exportProgress) - - // Import into target - resultChan := make(chan models.ImportResult) - go imp.processor.Process(exportChan, resultChan, importProgress) - result := <-resultChan - close(exportProgress) - wg.Wait() - progress.Wait() - if result.Error != nil { - fmt.Printf("Import failed, last reported timestamp was %v (%v)\n", result.LastTimestamp, result.LastTimestamp.Unix()) - cobra.CheckErr(result.Error) - } - fmt.Printf("Imported %v of %v %s into %v.\n", - result.ImportCount, result.TotalCount, entity, targetName) - - // Update timestamp - if result.LastTimestamp.Unix() < timestamp.Unix() { - result.LastTimestamp = timestamp - } - fmt.Printf("Latest timestamp: %v (%v)\n", result.LastTimestamp, result.LastTimestamp.Unix()) - err = db.SetImportTimestamp(sourceName, targetName, entity, result.LastTimestamp) - cobra.CheckErr(err) - - // Print errors - if len(result.ImportErrors) > 0 { - fmt.Printf("\nDuring the import the following errors occurred:\n") - for _, err := range result.ImportErrors { - fmt.Printf("Error: %v\n", err) - } - } -} diff --git a/cmd/listens.go b/cmd/listens.go index 1d52dc5..8bf3c3b 100644 --- a/cmd/listens.go +++ b/cmd/listens.go @@ -18,8 +18,11 @@ package cmd import ( "github.com/spf13/cobra" + "github.com/spf13/viper" "go.uploadedlobster.com/scotty/internal/backends" + "go.uploadedlobster.com/scotty/internal/cli" "go.uploadedlobster.com/scotty/internal/models" + "go.uploadedlobster.com/scotty/internal/storage" ) // listensCmd represents the listens command @@ -28,15 +31,18 @@ var listensCmd = &cobra.Command{ Short: "Transfer listens between two services", Long: `Transfers listens between two configured services.`, Run: func(cmd *cobra.Command, args []string) { - exp, imp, err := resolveBackends[models.ListensExport, models.ListensImport, models.ListensResult](cmd) + db, err := storage.New(viper.GetString("database")) + cobra.CheckErr(err) + c, err := cli.NewTransferCmd[ + models.ListensExport, + models.ListensImport, + models.ListensResult, + ](cmd, &db, "listens") + cobra.CheckErr(err) + exp := backends.ListensExportProcessor{Backend: c.ExpBackend} + imp := backends.ListensImportProcessor{Backend: c.ImpBackend} + err = c.Transfer(exp, imp) cobra.CheckErr(err) - exp.processor = backends.ListensExportProcessor{ - Backend: exp.backend, - } - imp.processor = backends.ListensImportProcessor{ - Backend: imp.backend, - } - cmdExportImport(cmd, "listens", exp, imp) }, } diff --git a/cmd/loves.go b/cmd/loves.go index ea257a0..f44cb5b 100644 --- a/cmd/loves.go +++ b/cmd/loves.go @@ -18,8 +18,11 @@ package cmd import ( "github.com/spf13/cobra" + "github.com/spf13/viper" "go.uploadedlobster.com/scotty/internal/backends" + "go.uploadedlobster.com/scotty/internal/cli" "go.uploadedlobster.com/scotty/internal/models" + "go.uploadedlobster.com/scotty/internal/storage" ) // lovesCmd represents the loves command @@ -28,15 +31,18 @@ var lovesCmd = &cobra.Command{ Short: "Transfer loves between two services", Long: `Transfers loves between two configured services.`, Run: func(cmd *cobra.Command, args []string) { - exp, imp, err := resolveBackends[models.LovesExport, models.LovesImport, models.LovesResult](cmd) + db, err := storage.New(viper.GetString("database")) + cobra.CheckErr(err) + c, err := cli.NewTransferCmd[ + models.LovesExport, + models.LovesImport, + models.LovesResult, + ](cmd, &db, "loves") + cobra.CheckErr(err) + exp := backends.LovesExportProcessor{Backend: c.ExpBackend} + imp := backends.LovesImportProcessor{Backend: c.ImpBackend} + err = c.Transfer(exp, imp) cobra.CheckErr(err) - exp.processor = backends.LovesExportProcessor{ - Backend: exp.backend, - } - imp.processor = backends.LovesImportProcessor{ - Backend: imp.backend, - } - cmdExportImport(cmd, "loves", exp, imp) }, } diff --git a/internal/cli/common.go b/internal/cli/common.go new file mode 100644 index 0000000..a26253d --- /dev/null +++ b/internal/cli/common.go @@ -0,0 +1,44 @@ +/* +Copyright © 2023 Philipp Wolfer + +Scotty is free software: you can redistribute it and/or modify it under the +terms of the GNU General Public License as published by the Free Software +Foundation, either version 3 of the License, or (at your option) any later version. + +Scotty is distributed in the hope that it will be useful, but WITHOUT ANY +WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. See the GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along with +Scotty. If not, see . +*/ + +package cli + +import ( + "fmt" + + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +func GetConfigFromFlag(cmd *cobra.Command, flagName string) (string, *viper.Viper) { + configName := cmd.Flag(flagName).Value.String() + var config *viper.Viper + servicesConfig := viper.Sub("service") + if servicesConfig != nil { + config = servicesConfig.Sub(configName) + } + if config == nil { + cobra.CheckErr(fmt.Sprintf("Invalid source configuration \"%s\"", configName)) + } + return configName, config +} + +func getInt64FromFlag(cmd *cobra.Command, flagName string) (result int64) { + result, err := cmd.Flags().GetInt64(flagName) + if err != nil { + result = 0 + } + return +} diff --git a/cmd/progress.go b/internal/cli/progress.go similarity index 99% rename from cmd/progress.go rename to internal/cli/progress.go index b328f29..457383c 100644 --- a/cmd/progress.go +++ b/internal/cli/progress.go @@ -15,7 +15,7 @@ You should have received a copy of the GNU General Public License along with Scotty. If not, see . */ -package cmd +package cli import ( "sync" diff --git a/internal/cli/transfer.go b/internal/cli/transfer.go new file mode 100644 index 0000000..35add01 --- /dev/null +++ b/internal/cli/transfer.go @@ -0,0 +1,161 @@ +/* +Copyright © 2023 Philipp Wolfer + +Scotty is free software: you can redistribute it and/or modify it under the +terms of the GNU General Public License as published by the Free Software +Foundation, either version 3 of the License, or (at your option) any later version. + +Scotty is distributed in the hope that it will be useful, but WITHOUT ANY +WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. See the GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along with +Scotty. If not, see . +*/ + +package cli + +import ( + "fmt" + "sync" + "time" + + "github.com/spf13/cobra" + "github.com/spf13/viper" + "go.uploadedlobster.com/scotty/internal/backends" + "go.uploadedlobster.com/scotty/internal/models" + "go.uploadedlobster.com/scotty/internal/storage" +) + +func NewTransferCmd[ + E models.Backend, + I models.ImportBackend, + R models.ListensResult | models.LovesResult, +]( + cmd *cobra.Command, + db *storage.Database, + entity string, +) (TransferCmd[E, I, R], error) { + c := TransferCmd[E, I, R]{ + cmd: cmd, + db: db, + entity: entity, + } + err := c.resolveBackends() + if err != nil { + return c, err + } + return c, nil +} + +type TransferCmd[E models.Backend, I models.ImportBackend, R models.ListensResult | models.LovesResult] struct { + cmd *cobra.Command + db *storage.Database + entity string + sourceName string + targetName string + ExpBackend E + ImpBackend I +} + +func (c *TransferCmd[E, I, R]) resolveBackends() error { + sourceName, sourceConfig := GetConfigFromFlag(c.cmd, "from") + targetName, targetConfig := GetConfigFromFlag(c.cmd, "to") + + // Initialize backends + expBackend, err := backends.ResolveBackend[E](sourceConfig) + if err != nil { + return err + } + impBackend, err := backends.ResolveBackend[I](targetConfig) + if err != nil { + return err + } + + c.sourceName = sourceName + c.targetName = targetName + c.ExpBackend = expBackend + c.ImpBackend = impBackend + return nil +} + +func (c *TransferCmd[E, I, R]) Transfer(exp backends.ExportProcessor[R], imp backends.ImportProcessor[R]) error { + fmt.Printf("Transferring %s from %s to %s...\n", c.entity, c.sourceName, c.targetName) + + // Authenticate backends, if needed + config := viper.GetViper() + _, err := backends.Authenticate(c.sourceName, c.ExpBackend, *c.db, config) + if err != nil { + return err + } + + _, err = backends.Authenticate(c.targetName, c.ImpBackend, *c.db, config) + if err != nil { + return err + } + + // Read timestamp + timestamp, err := c.timestamp() + if err != nil { + return err + } + fmt.Printf("From timestamp: %v (%v)\n", timestamp, timestamp.Unix()) + + // Prepare progress bars + exportProgress := make(chan models.Progress) + importProgress := make(chan models.Progress) + var wg sync.WaitGroup + progress := progressBar(&wg, exportProgress, importProgress) + + // Export from source + exportChan := make(chan R, 1000) + go exp.Process(timestamp, exportChan, exportProgress) + + // Import into target + resultChan := make(chan models.ImportResult) + go imp.Process(exportChan, resultChan, importProgress) + result := <-resultChan + close(exportProgress) + wg.Wait() + progress.Wait() + if result.Error != nil { + fmt.Printf("Import failed, last reported timestamp was %v (%v)\n", result.LastTimestamp, result.LastTimestamp.Unix()) + return result.Error + } + fmt.Printf("Imported %v of %v %s into %v.\n", + result.ImportCount, result.TotalCount, c.entity, c.targetName) + + // Update timestamp + err = c.updateTimestamp(result, timestamp) + if err != nil { + return err + } + + // Print errors + if len(result.ImportErrors) > 0 { + fmt.Printf("\nDuring the import the following errors occurred:\n") + for _, err := range result.ImportErrors { + fmt.Printf("Error: %v\n", err) + } + } + + return nil +} + +func (c *TransferCmd[E, I, R]) timestamp() (time.Time, error) { + timestamp := time.Unix(getInt64FromFlag(c.cmd, "timestamp"), 0) + if timestamp == time.Unix(0, 0) { + timestamp, err := c.db.GetImportTimestamp(c.sourceName, c.targetName, c.entity) + return timestamp, err + } + return timestamp, nil +} + +func (c *TransferCmd[E, I, R]) updateTimestamp(result models.ImportResult, oldTimestamp time.Time) error { + if result.LastTimestamp.Unix() < oldTimestamp.Unix() { + result.LastTimestamp = oldTimestamp + } + fmt.Printf("Latest timestamp: %v (%v)\n", result.LastTimestamp, result.LastTimestamp.Unix()) + err := c.db.SetImportTimestamp(c.sourceName, c.targetName, c.entity, result.LastTimestamp) + return err +}