diff --git a/internal/backends/export.go b/internal/backends/export.go index 54daafb..0deebc6 100644 --- a/internal/backends/export.go +++ b/internal/backends/export.go @@ -16,6 +16,7 @@ Scotty. If not, see . package backends import ( + "context" "sync" "time" @@ -24,7 +25,7 @@ import ( type ExportProcessor[T models.ListensResult | models.LovesResult] interface { ExportBackend() models.Backend - Process(wg *sync.WaitGroup, oldestTimestamp time.Time, results chan T, progress chan models.TransferProgress) + Process(ctx context.Context, wg *sync.WaitGroup, oldestTimestamp time.Time, results chan T, progress chan models.TransferProgress) } type ListensExportProcessor struct { @@ -35,7 +36,7 @@ func (p ListensExportProcessor) ExportBackend() models.Backend { return p.Backend } -func (p ListensExportProcessor) Process(wg *sync.WaitGroup, oldestTimestamp time.Time, results chan models.ListensResult, progress chan models.TransferProgress) { +func (p ListensExportProcessor) Process(ctx context.Context, wg *sync.WaitGroup, oldestTimestamp time.Time, results chan models.ListensResult, progress chan models.TransferProgress) { wg.Add(1) defer wg.Done() defer close(results) @@ -50,7 +51,7 @@ func (p LovesExportProcessor) ExportBackend() models.Backend { return p.Backend } -func (p LovesExportProcessor) Process(wg *sync.WaitGroup, oldestTimestamp time.Time, results chan models.LovesResult, progress chan models.TransferProgress) { +func (p LovesExportProcessor) Process(ctx context.Context, wg *sync.WaitGroup, oldestTimestamp time.Time, results chan models.LovesResult, progress chan models.TransferProgress) { wg.Add(1) defer wg.Done() defer close(results) diff --git a/internal/backends/import.go b/internal/backends/import.go index 0a2e341..e7006bd 100644 --- a/internal/backends/import.go +++ b/internal/backends/import.go @@ -18,6 +18,7 @@ Scotty. If not, see . package backends import ( + "context" "sync" "go.uploadedlobster.com/scotty/internal/models" @@ -25,7 +26,7 @@ import ( type ImportProcessor[T models.ListensResult | models.LovesResult] interface { ImportBackend() models.ImportBackend - Process(wg *sync.WaitGroup, results chan T, out chan models.ImportResult, progress chan models.TransferProgress) + Process(ctx context.Context, wg *sync.WaitGroup, results chan T, out chan models.ImportResult, progress chan models.TransferProgress) Import(export T, result models.ImportResult, out chan models.ImportResult, progress chan models.TransferProgress) (models.ImportResult, error) } @@ -37,8 +38,8 @@ func (p ListensImportProcessor) ImportBackend() models.ImportBackend { return p.Backend } -func (p ListensImportProcessor) Process(wg *sync.WaitGroup, results chan models.ListensResult, out chan models.ImportResult, progress chan models.TransferProgress) { - process(wg, p, results, out, progress) +func (p ListensImportProcessor) Process(ctx context.Context, wg *sync.WaitGroup, results chan models.ListensResult, out chan models.ImportResult, progress chan models.TransferProgress) { + process(ctx, wg, p, results, out, progress) } func (p ListensImportProcessor) Import(export models.ListensResult, result models.ImportResult, out chan models.ImportResult, progress chan models.TransferProgress) (models.ImportResult, error) { @@ -66,8 +67,8 @@ func (p LovesImportProcessor) ImportBackend() models.ImportBackend { return p.Backend } -func (p LovesImportProcessor) Process(wg *sync.WaitGroup, results chan models.LovesResult, out chan models.ImportResult, progress chan models.TransferProgress) { - process(wg, p, results, out, progress) +func (p LovesImportProcessor) Process(ctx context.Context, wg *sync.WaitGroup, results chan models.LovesResult, out chan models.ImportResult, progress chan models.TransferProgress) { + process(ctx, wg, p, results, out, progress) } func (p LovesImportProcessor) Import(export models.LovesResult, result models.ImportResult, out chan models.ImportResult, progress chan models.TransferProgress) (models.ImportResult, error) { @@ -87,7 +88,12 @@ func (p LovesImportProcessor) Import(export models.LovesResult, result models.Im return importResult, nil } -func process[R models.LovesResult | models.ListensResult, P ImportProcessor[R]](wg *sync.WaitGroup, processor P, results chan R, out chan models.ImportResult, progress chan models.TransferProgress) { +func process[R models.LovesResult | models.ListensResult, P ImportProcessor[R]]( + ctx context.Context, wg *sync.WaitGroup, + processor P, results chan R, + out chan models.ImportResult, + progress chan models.TransferProgress, +) { wg.Add(1) defer wg.Done() defer close(out) @@ -100,14 +106,21 @@ func process[R models.LovesResult | models.ListensResult, P ImportProcessor[R]]( } for exportResult := range results { - importResult, err := processor.Import(exportResult, result, out, progress) - result.Update(importResult) - if err != nil { + select { + case <-ctx.Done(): processor.ImportBackend().FinishImport() - out <- handleError(result, err, progress) + out <- handleError(result, ctx.Err(), progress) return + default: + importResult, err := processor.Import(exportResult, result, out, progress) + result.Update(importResult) + if err != nil { + processor.ImportBackend().FinishImport() + out <- handleError(result, err, progress) + return + } + progress <- p.FromImportResult(result, false) } - progress <- p.FromImportResult(result, false) } if err := processor.ImportBackend().FinishImport(); err != nil { diff --git a/internal/cli/transfer.go b/internal/cli/transfer.go index 62dd079..79be3f0 100644 --- a/internal/cli/transfer.go +++ b/internal/cli/transfer.go @@ -16,6 +16,7 @@ Scotty. If not, see . package cli import ( + "context" "errors" "fmt" "strconv" @@ -113,16 +114,22 @@ func (c *TransferCmd[E, I, R]) Transfer(exp backends.ExportProcessor[R], imp bac progressChan := make(chan models.TransferProgress) progress := setupProgressBars(progressChan) + ctx, cancel := context.WithCancel(context.Background()) wg := &sync.WaitGroup{} // Export from source exportChan := make(chan R, 1000) - go exp.Process(wg, timestamp, exportChan, progressChan) + go exp.Process(ctx, wg, timestamp, exportChan, progressChan) // Import into target resultChan := make(chan models.ImportResult) - go imp.Process(wg, exportChan, resultChan, progressChan) + go imp.Process(ctx, wg, exportChan, resultChan, progressChan) result := <-resultChan + + // Once import is done, the context can be cancelled + cancel() + + // Wait for all goroutines to finish wg.Wait() progress.close()