chore: refactor task scheduler

garethgeorge
2023-11-28 19:57:50 -08:00
parent 4957496787
commit 1b67e2b200
6 changed files with 479 additions and 117 deletions

View File

@@ -61,19 +61,10 @@ func main() {
 		zap.S().Fatalf("Error finding or installing restic: %v", err)
 	}
-	orchestrator, err := orchestrator.NewOrchestrator(resticPath, cfg, oplog)
-	if err != nil {
-		zap.S().Fatalf("Error creating orchestrator: %v", err)
-	}
-	// Start orchestration loop.
-	go func() {
-		err := orchestrator.Run(ctx)
-		if err != nil && !errors.Is(err, context.Canceled) {
-			zap.S().Fatal("Orchestrator loop exited with error: ", zap.Error(err))
-			cancel() // cancel the context when the orchestrator exits (e.g. on fatal error)
-		}
-	}()
+	orchestrator := orchestrator.NewOrchestrator(resticPath, cfg, oplog)
+	// Start orchestration loop. Only exits when ctx is cancelled.
+	go orchestrator.Run(ctx)
 	apiServer := api.NewServer(
 		configStore,

View File

@@ -236,7 +236,7 @@ func (s *Server) Backup(ctx context.Context, req *types.StringValue) (*emptypb.E
 	if err != nil {
 		return nil, fmt.Errorf("failed to get plan %q: %w", req.Value, err)
 	}
-	s.orchestrator.EnqueueTask(orchestrator.NewOneofBackupTask(s.orchestrator, plan, time.Now()))
+	s.orchestrator.ScheduleTask(orchestrator.NewOneofBackupTask(s.orchestrator, plan, time.Now()))
 	return &emptypb.Empty{}, nil
 }

View File

@@ -21,23 +21,33 @@ var ErrPlanNotFound = errors.New("plan not found")
 // Orchestrator is responsible for managing repos and backups.
 type Orchestrator struct {
 	mu       sync.Mutex
 	config   *v1.Config
 	OpLog    *oplog.OpLog
 	repoPool *resticRepoPool
-	configUpdates chan *v1.Config // configUpdates makes config changes available to Run().
-	externTasks   chan Task       // externTasks receives externally added tasks; they are consumed by Run().
+	taskQueue taskQueue
+
+	// now is overridable for testing; used by Run() to get the current time.
+	now func() time.Time
 }
 
-func NewOrchestrator(resticBin string, cfg *v1.Config, oplog *oplog.OpLog) (*Orchestrator, error) {
-	return &Orchestrator{
+func NewOrchestrator(resticBin string, cfg *v1.Config, oplog *oplog.OpLog) *Orchestrator {
+	var o *Orchestrator
+	o = &Orchestrator{
 		config: cfg,
 		OpLog:  oplog,
 		// repoPool is created with a memory store to ensure the config is updated in an atomic operation with the repo pool's config value.
 		repoPool: newResticRepoPool(resticBin, &config.MemoryStore{Config: cfg}),
-		externTasks: make(chan Task, 2),
-	}, nil
+		taskQueue: taskQueue{
+			Now: func() time.Time {
+				if o.now != nil {
+					return o.now()
+				}
+				return time.Now()
+			},
+		},
+	}
+	return o
 }
 
 func (o *Orchestrator) ApplyConfig(cfg *v1.Config) error {
@@ -52,9 +62,15 @@ func (o *Orchestrator) ApplyConfig(cfg *v1.Config) error {
 		return fmt.Errorf("failed to update repo pool config: %w", err)
 	}
-	if o.configUpdates != nil {
-		// orchestrator loop is running, notify it of the config change.
-		o.configUpdates <- cfg
+	o.taskQueue.Reset() // Reset queued tasks; this may lose ephemeral operations scheduled by RPC. Tasks in progress are not cancelled.
+
+	// Requeue tasks that are affected by the config change.
+	for _, plan := range cfg.Plans {
+		t, err := NewScheduledBackupTask(o, plan)
+		if err != nil {
+			return fmt.Errorf("schedule backup task for plan %q: %w", plan.Id, err)
+		}
+		o.ScheduleTask(t)
 	}
 	return nil
@@ -66,7 +82,7 @@ func (o *Orchestrator) GetRepo(repoId string) (repo *RepoOrchestrator, err error
 	r, err := o.repoPool.GetRepo(repoId)
 	if err != nil {
-		return nil, fmt.Errorf("failed to get repo %q: %w", repoId, err)
+		return nil, fmt.Errorf("get repo %q: %w", repoId, err)
 	}
 	return r, nil
 }
@@ -89,110 +105,55 @@ func (o *Orchestrator) GetPlan(planId string) (*v1.Plan, error) {
 }
 
 // Run is the main orchestration loop. Cancel the context to stop the loop.
-func (o *Orchestrator) Run(mainCtx context.Context) error {
+func (o *Orchestrator) Run(mainCtx context.Context) {
 	zap.L().Info("starting orchestrator loop")
-	o.mu.Lock()
-	o.configUpdates = make(chan *v1.Config)
-	o.mu.Unlock()
 
 	for {
-		o.mu.Lock()
-		config := o.config
-		o.mu.Unlock()
-
-		if o.runVersion(mainCtx, config) {
-			zap.L().Info("restarting orchestrator loop")
-		} else {
-			zap.L().Info("exiting orchestrator loop, context cancelled.")
+		if mainCtx.Err() != nil {
+			zap.L().Info("shutting down orchestrator loop, context cancelled.")
 			break
 		}
-	}
-	return nil
-}
-
-// runVersion is a helper for Run() that runs the orchestration loop with a single version of the config.
-func (o *Orchestrator) runVersion(mainCtx context.Context, config *v1.Config) bool {
-	var lock sync.Mutex
-	ctx, cancel := context.WithCancel(mainCtx)
-	var wg sync.WaitGroup
-
-	var execTask func(t Task)
-	execTask = func(t Task) {
-		curTime := time.Now()
-		runAt := t.Next(curTime)
-		if runAt == nil {
-			zap.L().Debug("task has no next run, not scheduling.", zap.String("task", t.Name()))
-			return
-		}
-		timer := time.NewTimer(runAt.Sub(curTime))
-		zap.L().Info("scheduling task", zap.String("task", t.Name()), zap.String("runAt", runAt.Format(time.RFC3339)))
-
-		wg.Add(1)
-		go func() {
-			defer wg.Done()
-			select {
-			case <-ctx.Done():
-				if !timer.Stop() {
-					<-timer.C
-				}
-				zap.L().Debug("cancelled scheduled (but not running) task, orchestrator context is cancelled.", zap.String("task", t.Name()))
-				return
-			case <-timer.C:
-				lock.Lock()
-				defer lock.Unlock()
-				zap.L().Info("running task", zap.String("task", t.Name()))
-				// Task execution runs with mainCtx, meaning config changes do not interrupt it, but cancelling the orchestration loop will.
-				if err := t.Run(mainCtx); err != nil {
-					zap.L().Error("task failed", zap.String("task", t.Name()), zap.Error(err))
-				} else {
-					zap.L().Debug("task finished", zap.String("task", t.Name()))
-				}
-				if ctx.Err() != nil {
-					zap.L().Debug("not attempting to reschedule task, orchestrator context is cancelled.", zap.String("task", t.Name()))
-					return
-				}
-				execTask(t)
-			}
-		}()
-	}
-
-	// Schedule all backup tasks.
-	for _, plan := range config.Plans {
-		t, err := NewScheduledBackupTask(o, plan)
-		if err != nil {
-			zap.L().Error("failed to create backup task for plan", zap.String("plan", plan.Id), zap.Error(err))
-		}
-		execTask(t)
-	}
-
-	// Wait for either an error or the context to be cancelled, then wait for all tasks.
-	for {
-		select {
-		case t := <-o.externTasks:
-			execTask(t)
-		case <-mainCtx.Done():
-			zap.L().Info("orchestrator context cancelled, shutting down orchestrator")
-			cancel()
-			wg.Wait()
-			return false
-		case <-o.configUpdates:
-			zap.L().Info("orchestrator received config change, waiting for in-progress operations then restarting")
-			cancel()
-			wg.Wait()
-			return true
+
+		t := o.taskQueue.Dequeue(mainCtx)
+		if t == nil {
+			continue
+		}
+
+		zap.L().Info("running task", zap.String("task", t.task.Name()))
+		if err := t.task.Run(mainCtx); err != nil {
+			zap.L().Error("task failed", zap.String("task", t.task.Name()), zap.Error(err))
+		} else {
+			zap.L().Debug("task finished", zap.String("task", t.task.Name()))
+		}
+
+		curTime := time.Now()
+		if o.now != nil {
+			curTime = o.now()
+		}
+
+		if nextTime := t.task.Next(curTime); nextTime != nil {
+			o.taskQueue.Push(scheduledTask{
+				task:  t.task,
+				runAt: *nextTime,
+			})
 		}
 	}
 }
 
-func (o *Orchestrator) EnqueueTask(t Task) {
-	o.externTasks <- t
+func (o *Orchestrator) ScheduleTask(t Task) {
+	curTime := time.Now()
+	if o.now != nil {
+		curTime = o.now()
+	}
+	nextRun := t.Next(curTime)
+	if nextRun == nil {
+		return
+	}
+	zap.L().Info("scheduling task", zap.String("task", t.Name()), zap.String("runAt", nextRun.Format(time.RFC3339)))
+	o.taskQueue.Push(scheduledTask{
+		task:  t,
+		runAt: *nextRun,
+	})
 }
 
 // resticRepoPool caches restic repos.
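
Note: anything implementing the Task interface used above (Name() string, Next(now time.Time) *time.Time, Run(ctx context.Context) error, as exercised by the tests below) can be scheduled. A minimal one-shot sketch; EchoTask is hypothetical and not part of this commit:

package orchestrator

import (
	"context"
	"time"

	"go.uber.org/zap"
)

// EchoTask is a hypothetical one-shot Task. Next returns a time until the
// task has run and nil thereafter, so Run()/ScheduleTask() push it onto the
// queue exactly once and never reschedule it.
type EchoTask struct {
	msg string
	ran bool
}

func (e *EchoTask) Name() string { return "echo" }

func (e *EchoTask) Next(now time.Time) *time.Time {
	if e.ran {
		return nil // one-shot: never reschedule after the first run.
	}
	return &now // due immediately.
}

func (e *EchoTask) Run(ctx context.Context) error {
	e.ran = true
	zap.S().Infof("echo: %s", e.msg)
	return nil
}

Scheduled with o.ScheduleTask(&EchoTask{msg: "hello"}), this runs once on the next Dequeue and then drops out of the queue; it is the same shape as the NewOneofBackupTask call in the Backup RPC above.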

View File

@@ -0,0 +1,181 @@
package orchestrator
import (
"context"
"sync"
"testing"
"time"
"github.com/garethgeorge/resticui/internal/config"
)
type testTask struct {
onRun func() error
onNext func(curTime time.Time) *time.Time
}
func (t *testTask) Name() string {
return "test"
}
func (t *testTask) Next(now time.Time) *time.Time {
return t.onNext(now)
}
func (t *testTask) Run(ctx context.Context) error {
return t.onRun()
}
func TestTaskScheduling(t *testing.T) {
t.Parallel()
// Arrange
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
orch := NewOrchestrator("", config.NewDefaultConfig(), nil)
var wg sync.WaitGroup
wg.Add(1)
task := &testTask{
onRun: func() error {
wg.Done()
cancel()
return nil
},
onNext: func(t time.Time) *time.Time {
t = t.Add(10 * time.Millisecond)
return &t
},
}
wg.Add(1)
go func() {
defer wg.Done()
orch.Run(ctx)
}()
// Act
orch.ScheduleTask(task)
// Assert: the test passes if all tasks run and the orchestrator exits when cancelled.
wg.Wait()
}
func TestTaskRescheduling(t *testing.T) {
t.Parallel()
// Arrange
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
orch := NewOrchestrator("", config.NewDefaultConfig(), nil)
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
orch.Run(ctx)
}()
// Act
count := 0
ranTimes := 0
orch.ScheduleTask(&testTask{
onNext: func(t time.Time) *time.Time {
if count < 10 {
count += 1
return &t
}
return nil
},
onRun: func() error {
ranTimes += 1
if ranTimes == 10 {
cancel()
}
return nil
},
})
wg.Wait()
if count != 10 {
t.Errorf("expected 10 Next calls, got %d", count)
}
if ranTimes != 10 {
t.Errorf("expected 10 Run calls, got %d", ranTimes)
}
}
func TestGracefulShutdown(t *testing.T) {
t.Parallel()
// Arrange
orch := NewOrchestrator("", config.NewDefaultConfig(), nil)
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(10 * time.Millisecond)
cancel()
}()
// Act
orch.Run(ctx)
}
func TestSchedulerWait(t *testing.T) {
t.Parallel()
// Arrange
curTime := time.Now()
orch := NewOrchestrator("", config.NewDefaultConfig(), nil)
orch.now = func() time.Time {
return curTime
}
ran := make(chan struct{})
orch.ScheduleTask(&testTask{
onNext: func(t time.Time) *time.Time {
t = t.Add(5 * time.Millisecond)
return &t
},
onRun: func() error {
close(ran)
return nil
},
})
// Act
go orch.Run(context.Background())
// Assert
select {
case <-time.NewTimer(20 * time.Millisecond).C:
case <-ran:
t.Errorf("expected task to not run yet")
}
curTime = time.Now()
// Schedule another task just to trigger a queue refresh
orch.ScheduleTask(&testTask{
onNext: func(t time.Time) *time.Time {
t = t.Add(5 * time.Millisecond)
return &t
},
onRun: func() error {
t.Errorf("should never run") // Errorf, not Fatalf: this runs on the orchestrator goroutine.
return nil
},
})
select {
case <-time.NewTimer(1000 * time.Millisecond).C:
t.Errorf("expected task to run")
case <-ran:
}
}

View File

@@ -0,0 +1,144 @@
package orchestrator
import (
"container/heap"
"context"
"sync"
"time"
)
type taskQueue struct {
dequeueMu sync.Mutex
mu sync.Mutex
heap scheduledTaskHeap
notify chan struct{}
Now func() time.Time
}
func (t *taskQueue) curTime() time.Time {
if t.Now == nil {
return time.Now()
}
return t.Now()
}
func (t *taskQueue) Push(task scheduledTask) {
	t.mu.Lock()
	defer t.mu.Unlock()
	if task.task == nil {
		panic("task cannot be nil")
	}
	heap.Push(&t.heap, &task)
	if t.notify != nil {
		// Non-blocking send: notify is a level trigger with capacity 1; a
		// single pending signal is enough for Dequeue to re-examine the heap,
		// and blocking here while holding mu could deadlock against Dequeue.
		select {
		case t.notify <- struct{}{}:
		default:
		}
	}
}
func (t *taskQueue) Reset() {
	t.mu.Lock()
	defer t.mu.Unlock()
	t.heap.tasks = nil
	if t.notify != nil {
		// Non-blocking send, as in Push.
		select {
		case t.notify <- struct{}{}:
		default:
		}
	}
}
func (t *taskQueue) Dequeue(ctx context.Context) *scheduledTask {
	t.dequeueMu.Lock()
	defer t.dequeueMu.Unlock()
	t.mu.Lock()
	t.notify = make(chan struct{}, 1)
	defer func() {
		// Clear notify under mu so Push/Reset never signal a stale channel.
		t.mu.Lock()
		t.notify = nil
		t.mu.Unlock()
	}()
	for {
		first, ok := t.heap.Peek().(*scheduledTask)
		if !ok { // no tasks in heap; wait for a push or cancellation.
			t.mu.Unlock()
			select {
			case <-ctx.Done():
				return nil
			case <-t.notify:
			}
			t.mu.Lock()
			continue
		}
		t.mu.Unlock()
		timer := time.NewTimer(first.runAt.Sub(t.curTime()))
		// mu must not be held while waiting here, otherwise Push and Reset
		// (which signal notify under mu) could never make progress.
		select {
		case <-timer.C:
			t.mu.Lock()
			if t.heap.Len() == 0 {
				break // heap was reset while waiting; re-evaluate.
			}
			first = t.heap.Peek().(*scheduledTask)
			if first.runAt.After(t.curTime()) {
				break // task is not yet ready to run; wait again.
			}
			heap.Pop(&t.heap) // remove the task from the heap
			t.mu.Unlock()
			return first
		case <-t.notify: // a task was added or the queue was reset; loop again to pick the earliest task.
			if !timer.Stop() {
				<-timer.C
			}
			t.mu.Lock()
		case <-ctx.Done():
			if !timer.Stop() {
				<-timer.C
			}
			return nil
		}
	}
}
type scheduledTask struct {
task Task
runAt time.Time
}
type scheduledTaskHeap struct {
tasks []*scheduledTask
}
var _ heap.Interface = &scheduledTaskHeap{}
func (h *scheduledTaskHeap) Len() int {
return len(h.tasks)
}
func (h *scheduledTaskHeap) Less(i, j int) bool {
return h.tasks[i].runAt.Before(h.tasks[j].runAt)
}
func (h *scheduledTaskHeap) Swap(i, j int) {
h.tasks[i], h.tasks[j] = h.tasks[j], h.tasks[i]
}
func (h *scheduledTaskHeap) Push(x interface{}) {
h.tasks = append(h.tasks, x.(*scheduledTask))
}
func (h *scheduledTaskHeap) Pop() interface{} {
old := h.tasks
n := len(old)
x := old[n-1]
h.tasks = old[0 : n-1]
return x
}
func (h *scheduledTaskHeap) Peek() interface{} {
if len(h.tasks) == 0 {
return nil
}
return h.tasks[0]
}
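
The queue is built for a single consumer: dequeueMu serializes concurrent Dequeue calls, while Push and Reset may be called from any goroutine. A minimal sketch of the consume loop, assuming a hypothetical drain helper that is not part of this commit:

package orchestrator

import (
	"context"
	"log"
)

// drain is a hypothetical helper illustrating taskQueue's contract: one
// goroutine blocks in Dequeue until the earliest runAt is due (or ctx is
// cancelled), while producers Push or Reset concurrently.
func drain(ctx context.Context, q *taskQueue) {
	for {
		st := q.Dequeue(ctx)
		if st == nil {
			return // context cancelled.
		}
		if err := st.task.Run(ctx); err != nil {
			log.Printf("task %q failed: %v", st.task.Name(), err)
		}
	}
}

This is essentially what Orchestrator.Run does above, minus the zap logging and the rescheduling via Next.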

View File

@@ -0,0 +1,85 @@
package orchestrator
import (
"context"
"reflect"
"testing"
"time"
)
type heapTestTask struct {
name string
}
var _ Task = &heapTestTask{}
func (t *heapTestTask) Name() string {
return t.name
}
func (t *heapTestTask) Next(now time.Time) *time.Time {
return nil
}
func (t *heapTestTask) Run(ctx context.Context) error {
return nil
}
func TestTaskQueueOrdering(t *testing.T) {
h := taskQueue{}
h.Push(scheduledTask{runAt: time.Now().Add(1 * time.Millisecond), task: &heapTestTask{name: "1"}})
h.Push(scheduledTask{runAt: time.Now().Add(2 * time.Millisecond), task: &heapTestTask{name: "2"}})
h.Push(scheduledTask{runAt: time.Now().Add(2 * time.Millisecond), task: &heapTestTask{name: "3"}})
wantSeq := []string{"1", "2", "3"}
seq := []string{}
for i := 0; i < 3; i++ {
task := h.Dequeue(context.Background())
if task == nil || task.task == nil {
t.Fatal("expected task")
}
seq = append(seq, task.task.Name())
}
if !reflect.DeepEqual(seq, wantSeq) {
t.Errorf("got %v, want %v", seq, wantSeq)
}
}
func TestLiveTaskEnqueue(t *testing.T) {
h := taskQueue{}
go func() {
time.Sleep(1 * time.Millisecond)
h.Push(scheduledTask{runAt: time.Now().Add(1 * time.Millisecond), task: &heapTestTask{name: "1"}})
}()
t1 := h.Dequeue(context.Background())
if t1.task.Name() != "1" {
t.Errorf("got %s, want 1", t1.task.Name())
}
}
func TestTaskQueueReset(t *testing.T) {
h := taskQueue{}
h.Push(scheduledTask{runAt: time.Now().Add(1 * time.Millisecond), task: &heapTestTask{name: "1"}})
h.Push(scheduledTask{runAt: time.Now().Add(2 * time.Millisecond), task: &heapTestTask{name: "2"}})
h.Push(scheduledTask{runAt: time.Now().Add(2 * time.Millisecond), task: &heapTestTask{name: "3"}})
if h.Dequeue(context.Background()).task.Name() != "1" {
t.Fatal("expected 1")
}
h.Reset()
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(1 * time.Millisecond)
cancel()
}()
if h.Dequeue(ctx) != nil {
t.Fatal("expected nil task")
}
}
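
As an aside, the ordering asserted in TestTaskQueueOrdering follows directly from scheduledTaskHeap.Less comparing runAt: container/heap maintains a min-heap, so heap.Pop always yields the earliest-scheduled task. A standalone sketch (hypothetical, not part of this commit):

package orchestrator

import (
	"container/heap"
	"fmt"
	"time"
)

// exampleHeapOrder is a hypothetical illustration: tasks pushed out of order
// pop in runAt order because the heap is a min-heap keyed on runAt.
func exampleHeapOrder() {
	h := &scheduledTaskHeap{}
	base := time.Now()
	heap.Push(h, &scheduledTask{runAt: base.Add(3 * time.Second), task: &heapTestTask{name: "c"}})
	heap.Push(h, &scheduledTask{runAt: base.Add(1 * time.Second), task: &heapTestTask{name: "a"}})
	heap.Push(h, &scheduledTask{runAt: base.Add(2 * time.Second), task: &heapTestTask{name: "b"}})
	for h.Len() > 0 {
		fmt.Println(heap.Pop(h).(*scheduledTask).task.Name()) // prints a, b, c
	}
}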