diff --git a/cmd/resticui/resticui.go b/cmd/resticui/resticui.go index 53228ce..3d90d97 100644 --- a/cmd/resticui/resticui.go +++ b/cmd/resticui/resticui.go @@ -61,19 +61,10 @@ func main() { zap.S().Fatalf("Error finding or installing restic: %v", err) } - orchestrator, err := orchestrator.NewOrchestrator(resticPath, cfg, oplog) - if err != nil { - zap.S().Fatalf("Error creating orchestrator: %v", err) - } + orchestrator := orchestrator.NewOrchestrator(resticPath, cfg, oplog) - // Start orchestration loop. - go func() { - err := orchestrator.Run(ctx) - if err != nil && !errors.Is(err, context.Canceled) { - zap.S().Fatal("Orchestrator loop exited with error: ", zap.Error(err)) - cancel() // cancel the context when the orchestrator exits (e.g. on fatal error) - } - }() + // Start orchestration loop. Only exits when ctx is cancelled. + go orchestrator.Run(ctx) apiServer := api.NewServer( configStore, diff --git a/internal/api/server.go b/internal/api/server.go index 9ba44e4..984e4c4 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -236,7 +236,7 @@ func (s *Server) Backup(ctx context.Context, req *types.StringValue) (*emptypb.E if err != nil { return nil, fmt.Errorf("failed to get plan %q: %w", req.Value, err) } - s.orchestrator.EnqueueTask(orchestrator.NewOneofBackupTask(s.orchestrator, plan, time.Now())) + s.orchestrator.ScheduleTask(orchestrator.NewOneofBackupTask(s.orchestrator, plan, time.Now())) return &emptypb.Empty{}, nil } diff --git a/internal/orchestrator/orchestrator.go b/internal/orchestrator/orchestrator.go index e39aebf..c3998cc 100644 --- a/internal/orchestrator/orchestrator.go +++ b/internal/orchestrator/orchestrator.go @@ -21,23 +21,33 @@ var ErrPlanNotFound = errors.New("plan not found") // Orchestrator is responsible for managing repos and backups. 
type Orchestrator struct { - mu sync.Mutex - config *v1.Config - OpLog *oplog.OpLog - repoPool *resticRepoPool + mu sync.Mutex + config *v1.Config + OpLog *oplog.OpLog + repoPool *resticRepoPool + taskQueue taskQueue - configUpdates chan *v1.Config // configUpdates chan makes config changes available to Run() - externTasks chan Task // externTasks is a channel that externally added tasks can be added to, they will be consumed by Run() + // now for the purpose of testing; used by Run() to get the current time. + now func() time.Time } -func NewOrchestrator(resticBin string, cfg *v1.Config, oplog *oplog.OpLog) (*Orchestrator, error) { - return &Orchestrator{ +func NewOrchestrator(resticBin string, cfg *v1.Config, oplog *oplog.OpLog) *Orchestrator { + var o *Orchestrator + o = &Orchestrator{ config: cfg, OpLog: oplog, // repoPool created with a memory store to ensure the config is updated in an atomic operation with the repo pool's config value. - repoPool: newResticRepoPool(resticBin, &config.MemoryStore{Config: cfg}), - externTasks: make(chan Task, 2), - }, nil + repoPool: newResticRepoPool(resticBin, &config.MemoryStore{Config: cfg}), + taskQueue: taskQueue{ + Now: func() time.Time { + if o.now != nil { + return o.now() + } + return time.Now() + }, + }, + } + return o } func (o *Orchestrator) ApplyConfig(cfg *v1.Config) error { @@ -52,9 +62,15 @@ func (o *Orchestrator) ApplyConfig(cfg *v1.Config) error { return fmt.Errorf("failed to update repo pool config: %w", err) } - if o.configUpdates != nil { - // orchestrator loop is running, notify it of the config change. - o.configUpdates <- cfg + o.taskQueue.Reset() // reset queued tasks, this may lose any ephemeral operations scheduled by RPC. Tasks in progress are not cancelled. + + // Requeue tasks that are affected by the config change. 
+ for _, plan := range cfg.Plans { + t, err := NewScheduledBackupTask(o, plan) + if err != nil { + return fmt.Errorf("schedule backup task for plan %q: %w", plan.Id, err) + } + o.ScheduleTask(t) } return nil @@ -66,7 +82,7 @@ func (o *Orchestrator) GetRepo(repoId string) (repo *RepoOrchestrator, err error r, err := o.repoPool.GetRepo(repoId) if err != nil { - return nil, fmt.Errorf("failed to get repo %q: %w", repoId, err) + return nil, fmt.Errorf("get repo %q: %w", repoId, err) } return r, nil } @@ -89,110 +105,55 @@ func (o *Orchestrator) GetPlan(planId string) (*v1.Plan, error) { } // Run is the main orchestration loop. Cancel the context to stop the loop. -func (o *Orchestrator) Run(mainCtx context.Context) error { +func (o *Orchestrator) Run(mainCtx context.Context) { zap.L().Info("starting orchestrator loop") - o.mu.Lock() - o.configUpdates = make(chan *v1.Config) - o.mu.Unlock() - for { - o.mu.Lock() - config := o.config - o.mu.Unlock() - if o.runVersion(mainCtx, config) { - zap.L().Info("restarting orchestrator loop") - } else { - zap.L().Info("exiting orchestrator loop, context cancelled.") + if mainCtx.Err() != nil { + zap.L().Info("shutting down orchestrator loop, context cancelled.") break } - } - return nil -} -// runImmutable is a helper function for Run() that runs the orchestration loop with a single version of the config. 
-func (o *Orchestrator) runVersion(mainCtx context.Context, config *v1.Config) bool { - var lock sync.Mutex - ctx, cancel := context.WithCancel(mainCtx) + t := o.taskQueue.Dequeue(mainCtx) + if t == nil { + continue + } - var wg sync.WaitGroup + zap.L().Info("running task", zap.String("task", t.task.Name())) + if err := t.task.Run(mainCtx); err != nil { + zap.L().Error("task failed", zap.String("task", t.task.Name()), zap.Error(err)) + } else { + zap.L().Debug("task finished", zap.String("task", t.task.Name())) + } - var execTask func(t Task) - execTask = func(t Task) { curTime := time.Now() - - runAt := t.Next(curTime) - if runAt == nil { - zap.L().Debug("task has no next run, not scheduling.", zap.String("task", t.Name())) - return + if o.now != nil { + curTime = o.now() } - timer := time.NewTimer(runAt.Sub(curTime)) - zap.L().Info("scheduling task", zap.String("task", t.Name()), zap.String("runAt", runAt.Format(time.RFC3339))) - - wg.Add(1) - go func() { - defer wg.Done() - select { - case <-ctx.Done(): - if !timer.Stop() { - <-timer.C - } - zap.L().Debug("cancelled scheduled (but not running) task, orchestrator context is cancelled.", zap.String("task", t.Name())) - return - case <-timer.C: - lock.Lock() - defer lock.Unlock() - zap.L().Info("running task", zap.String("task", t.Name())) - - // Task execution runs with mainCtx meaning config changes do not interrupt it, but cancelling the orchestration loop will. - if err := t.Run(mainCtx); err != nil { - zap.L().Error("task failed", zap.String("task", t.Name()), zap.Error(err)) - } else { - zap.L().Debug("task finished", zap.String("task", t.Name())) - } - - if ctx.Err() != nil { - zap.L().Debug("not attempting to reschedule task, orchestrator context is cancelled.", zap.String("task", t.Name())) - return - } - - execTask(t) - } - }() - } - - // Schedule all backup tasks. 
- for _, plan := range config.Plans { - t, err := NewScheduledBackupTask(o, plan) - if err != nil { - zap.L().Error("failed to create backup task for plan", zap.String("plan", plan.Id), zap.Error(err)) - } - - execTask(t) - } - - // wait for either an error or the context to be cancelled, then wait for all tasks. - for { - select { - case t := <-o.externTasks: - execTask(t) - case <-mainCtx.Done(): - zap.L().Info("orchestrator context cancelled, shutting down orchestrator") - cancel() - wg.Wait() - return false - case <-o.configUpdates: - zap.L().Info("orchestrator received config change, waiting for in-progress operations then restarting") - cancel() - wg.Wait() - return true + if nextTime := t.task.Next(curTime); nextTime != nil { + o.taskQueue.Push(scheduledTask{ + task: t.task, + runAt: *nextTime, + }) } } } -func (o *Orchestrator) EnqueueTask(t Task) { - o.externTasks <- t +func (o *Orchestrator) ScheduleTask(t Task) { + curTime := time.Now() + if o.now != nil { + curTime = o.now() + } + nextRun := t.Next(curTime) + if nextRun == nil { + return + } + zap.L().Info("scheduling task", zap.String("task", t.Name()), zap.String("runAt", nextRun.Format(time.RFC3339))) + o.taskQueue.Push(scheduledTask{ + task: t, + runAt: *nextRun, + }) } // resticRepoPool caches restic repos. 
diff --git a/internal/orchestrator/orchestrator_test.go b/internal/orchestrator/orchestrator_test.go new file mode 100644 index 0000000..06213a1 --- /dev/null +++ b/internal/orchestrator/orchestrator_test.go @@ -0,0 +1,181 @@ +package orchestrator + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/garethgeorge/resticui/internal/config" +) + +type testTask struct { + onRun func() error + onNext func(curTime time.Time) *time.Time +} + +func (t *testTask) Name() string { + return "test" +} + +func (t *testTask) Next(now time.Time) *time.Time { + return t.onNext(now) +} + +func (t *testTask) Run(ctx context.Context) error { + return t.onRun() +} + +func TestTaskScheduling(t *testing.T) { + t.Parallel() + + // Arrange + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + orch := NewOrchestrator("", config.NewDefaultConfig(), nil) + + var wg sync.WaitGroup + wg.Add(1) + task := &testTask{ + onRun: func() error { + wg.Done() + cancel() + return nil + }, + onNext: func(t time.Time) *time.Time { + t = t.Add(10 * time.Millisecond) + return &t + }, + } + + wg.Add(1) + go func() { + defer wg.Done() + orch.Run(ctx) + }() + + // Act + orch.ScheduleTask(task) + + // Assert passes if all tasks run and the orchestrator exits when cancelled. 
+ wg.Wait() +} + +func TestTaskRescheduling(t *testing.T) { + t.Parallel() + + // Arrange + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + orch := NewOrchestrator("", config.NewDefaultConfig(), nil) + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + orch.Run(ctx) + }() + + // Act + count := 0 + ranTimes := 0 + + orch.ScheduleTask(&testTask{ + onNext: func(t time.Time) *time.Time { + if count < 10 { + count += 1 + return &t + } + return nil + }, + onRun: func() error { + ranTimes += 1 + if ranTimes == 10 { + cancel() + } + return nil + }, + }) + + wg.Wait() + + if count != 10 { + t.Errorf("expected 10 Next calls, got %d", count) + } + + if ranTimes != 10 { + t.Errorf("expected 10 Run calls, got %d", ranTimes) + } +} + +func TestGracefulShutdown(t *testing.T) { + t.Parallel() + + // Arrange + orch := NewOrchestrator("", config.NewDefaultConfig(), nil) + ctx, cancel := context.WithCancel(context.Background()) + + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + + // Act + orch.Run(ctx) +} + +func TestSchedulerWait(t *testing.T) { + t.Parallel() + + // Arrange + curTime := time.Now() + orch := NewOrchestrator("", config.NewDefaultConfig(), nil) + orch.now = func() time.Time { + return curTime + } + + ran := make(chan struct{}) + orch.ScheduleTask(&testTask{ + onNext: func(t time.Time) *time.Time { + t = t.Add(5 * time.Millisecond) + return &t + }, + onRun: func() error { + close(ran) + return nil + }, + }) + + // Act + go orch.Run(context.Background()) + + // Assert + select { + case <-time.NewTimer(20 * time.Millisecond).C: + case <-ran: + t.Errorf("expected task to not run yet") + } + + curTime = time.Now() + + // Schedule another task just to trigger a queue refresh + orch.ScheduleTask(&testTask{ + onNext: func(t time.Time) *time.Time { + t = t.Add(5 * time.Millisecond) + return &t + }, + onRun: func() error { + t.Fatalf("should never run") + return nil + }, + }) + + select { + case 
<-time.NewTimer(1000 * time.Millisecond).C: + t.Errorf("expected task to run") + case <-ran: + } +} diff --git a/internal/orchestrator/scheduledtaskheap.go b/internal/orchestrator/scheduledtaskheap.go new file mode 100644 index 0000000..d389a87 --- /dev/null +++ b/internal/orchestrator/scheduledtaskheap.go @@ -0,0 +1,144 @@ +package orchestrator + +import ( + "container/heap" + "context" + "sync" + "time" +) + +type taskQueue struct { + dequeueMu sync.Mutex + mu sync.Mutex + heap scheduledTaskHeap + notify chan struct{} + + Now func() time.Time +} + +func (t *taskQueue) curTime() time.Time { + if t.Now == nil { + return time.Now() + } + return t.Now() +} + +func (t *taskQueue) Push(task scheduledTask) { + t.mu.Lock() + defer t.mu.Unlock() + + if task.task == nil { + panic("task cannot be nil") + } + + heap.Push(&t.heap, &task) + if t.notify != nil { + t.notify <- struct{}{} + } +} + +func (t *taskQueue) Reset() { + t.mu.Lock() + defer t.mu.Unlock() + + t.heap.tasks = nil + if t.notify != nil { + t.notify <- struct{}{} + } +} + +func (t *taskQueue) Dequeue(ctx context.Context) *scheduledTask { + t.dequeueMu.Lock() + defer t.dequeueMu.Unlock() + + t.notify = make(chan struct{}, 1) + defer func() { + t.notify = nil + }() + + t.mu.Lock() + for { + first, ok := t.heap.Peek().(*scheduledTask) + if !ok { // no tasks in heap. + t.mu.Unlock() + select { + case <-ctx.Done(): + return nil + case <-t.notify: + } + t.mu.Lock() + continue + } + t.mu.Unlock() + timer := time.NewTimer(first.runAt.Sub(t.curTime())) + + t.mu.Lock() + select { + case <-timer.C: + if t.heap.Len() == 0 { + break + } + first = t.heap.Peek().(*scheduledTask) + if first.runAt.After(t.curTime()) { + // task is not yet ready to run + break + } + + heap.Pop(&t.heap) // remove the task from the heap + t.mu.Unlock() + return first + case <-t.notify: // new task was added, loop again to ensure we have the earliest task. 
+ if !timer.Stop() { + <-timer.C + } + case <-ctx.Done(): + if !timer.Stop() { + <-timer.C + } + t.mu.Unlock() + return nil + } + } +} + +type scheduledTask struct { + task Task + runAt time.Time +} + +type scheduledTaskHeap struct { + tasks []*scheduledTask +} + +var _ heap.Interface = &scheduledTaskHeap{} + +func (h *scheduledTaskHeap) Len() int { + return len(h.tasks) +} + +func (h *scheduledTaskHeap) Less(i, j int) bool { + return h.tasks[i].runAt.Before(h.tasks[j].runAt) +} + +func (h *scheduledTaskHeap) Swap(i, j int) { + h.tasks[i], h.tasks[j] = h.tasks[j], h.tasks[i] +} + +func (h *scheduledTaskHeap) Push(x interface{}) { + h.tasks = append(h.tasks, x.(*scheduledTask)) +} + +func (h *scheduledTaskHeap) Pop() interface{} { + old := h.tasks + n := len(old) + x := old[n-1] + h.tasks = old[0 : n-1] + return x +} + +func (h *scheduledTaskHeap) Peek() interface{} { + if len(h.tasks) == 0 { + return nil + } + return h.tasks[0] +} diff --git a/internal/orchestrator/scheduledtaskheap_test.go b/internal/orchestrator/scheduledtaskheap_test.go new file mode 100644 index 0000000..4f000ea --- /dev/null +++ b/internal/orchestrator/scheduledtaskheap_test.go @@ -0,0 +1,85 @@ +package orchestrator + +import ( + "context" + "reflect" + "testing" + "time" +) + +type heapTestTask struct { + name string +} + +var _ Task = &heapTestTask{} + +func (t *heapTestTask) Name() string { + return t.name +} + +func (t *heapTestTask) Next(now time.Time) *time.Time { + return nil +} + +func (t *heapTestTask) Run(ctx context.Context) error { + return nil +} + +func TestTaskQueueOrdering(t *testing.T) { + h := taskQueue{} + + h.Push(scheduledTask{runAt: time.Now().Add(1 * time.Millisecond), task: &heapTestTask{name: "1"}}) + h.Push(scheduledTask{runAt: time.Now().Add(2 * time.Millisecond), task: &heapTestTask{name: "2"}}) + h.Push(scheduledTask{runAt: time.Now().Add(2 * time.Millisecond), task: &heapTestTask{name: "3"}}) + + wantSeq := []string{"1", "2", "3"} + seq := []string{} + for i := 0; 
i < 3; i++ { + task := h.Dequeue(context.Background()) + if task == nil || task.task == nil { + t.Fatal("expected task") + } + seq = append(seq, task.task.Name()) + } + + if !reflect.DeepEqual(seq, wantSeq) { + t.Errorf("got %v, want %v", seq, wantSeq) + } +} + +func TestLiveTaskEnqueue(t *testing.T) { + h := taskQueue{} + + go func() { + time.Sleep(1 * time.Millisecond) + h.Push(scheduledTask{runAt: time.Now().Add(1 * time.Millisecond), task: &heapTestTask{name: "1"}}) + }() + + t1 := h.Dequeue(context.Background()) + if t1.task.Name() != "1" { + t.Errorf("got %s, want 1", t1.task.Name()) + } +} + +func TestTaskQueueReset(t *testing.T) { + h := taskQueue{} + + h.Push(scheduledTask{runAt: time.Now().Add(1 * time.Millisecond), task: &heapTestTask{name: "1"}}) + h.Push(scheduledTask{runAt: time.Now().Add(2 * time.Millisecond), task: &heapTestTask{name: "2"}}) + h.Push(scheduledTask{runAt: time.Now().Add(2 * time.Millisecond), task: &heapTestTask{name: "3"}}) + + if h.Dequeue(context.Background()).task.Name() != "1" { + t.Fatal("expected 1") + } + h.Reset() + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(1 * time.Millisecond) + cancel() + }() + + if h.Dequeue(ctx) != nil { + t.Fatal("expected nil task") + } +}