feat: support task cancellation

This commit is contained in:
Gareth George
2023-12-20 08:54:45 +00:00
parent 95ca96a31f
commit fc9c06df00
15 changed files with 366 additions and 74 deletions

View File

@@ -6,6 +6,7 @@ import (
"fmt"
"slices"
"sync"
"sync/atomic"
"time"
v1 "github.com/garethgeorge/restora/gen/go/v1"
@@ -19,6 +20,7 @@ import (
var ErrRepoNotFound = errors.New("repo not found")
var ErrRepoInitializationFailed = errors.New("repo initialization failed")
var ErrPlanNotFound = errors.New("plan not found")
var ErrTaskNotFound = errors.New("task not found")
const (
TaskPriorityDefault = iota
@@ -38,6 +40,8 @@ type Orchestrator struct {
// now for the purpose of testing; used by Run() to get the current time.
now func() time.Time
runningTask atomic.Pointer[taskExecutionInfo]
}
func NewOrchestrator(resticBin string, cfg *v1.Config, oplog *oplog.OpLog) (*Orchestrator, error) {
@@ -163,6 +167,41 @@ func (o *Orchestrator) GetPlan(planId string) (*v1.Plan, error) {
return nil, ErrPlanNotFound
}
func (o *Orchestrator) CancelOperation(operationId int64, status v1.OperationStatus) error {
o.mu.Lock()
defer o.mu.Unlock()
// note: if the task is running the requested status will not be set.
if running := o.runningTask.Load(); running != nil && running.operationId == operationId {
running.cancel()
}
tasks := o.taskQueue.Reset()
remaining := make([]scheduledTask, 0, len(tasks))
for _, t := range tasks {
if t.task.OperationId() == operationId {
if err := t.task.Cancel(status); err != nil {
return fmt.Errorf("cancel task %q: %w", t.task.Name(), err)
}
// check if the task has a next after it's current 'runAt' time, if it does then we will schedule the next run.
if nextTime := t.task.Next(t.runAt); nextTime != nil {
remaining = append(remaining, scheduledTask{
task: t.task,
runAt: *nextTime,
})
}
} else {
remaining = append(remaining, *t)
}
}
o.taskQueue.Push(remaining...)
return nil
}
// Run is the main orchestration loop. Cancel the context to stop the loop.
func (o *Orchestrator) Run(mainCtx context.Context) {
zap.L().Info("starting orchestrator loop")
@@ -179,12 +218,24 @@ func (o *Orchestrator) Run(mainCtx context.Context) {
}
zap.L().Info("running task", zap.String("task", t.task.Name()))
if err := t.task.Run(mainCtx); err != nil {
taskCtx, cancel := context.WithCancel(mainCtx)
if swapped := o.runningTask.CompareAndSwap(nil, &taskExecutionInfo{
operationId: t.task.OperationId(),
cancel: cancel,
}); !swapped {
zap.L().Fatal("failed to start task, another task is already running. Was Run() called twice?")
}
if err := t.task.Run(taskCtx); err != nil {
zap.L().Error("task failed", zap.String("task", t.task.Name()), zap.Error(err))
} else {
zap.L().Info("task finished", zap.String("task", t.task.Name()))
}
o.runningTask.Store(nil)
if nextTime := t.task.Next(o.curTime()); nextTime != nil {
o.taskQueue.Push(scheduledTask{
task: t.task,
@@ -268,3 +319,8 @@ func (rp *resticRepoPool) GetRepo(repoId string) (repo *RepoOrchestrator, err er
rp.repos[repoId] = repo
return repo, nil
}
type taskExecutionInfo struct {
operationId int64
cancel func()
}