From b6f7fd7ef899c0fe82b1d2f512b2b0d15c8cfc0f Mon Sep 17 00:00:00 2001 From: Dmitry Ng <19asdek91@gmail.com> Date: Fri, 1 May 2026 16:05:12 +0300 Subject: [PATCH] feat: enhance task and subtask handling with interruption management --- backend/pkg/controller/subtask.go | 31 +++++++++++++++++++++++++++++++ backend/pkg/controller/task.go | 31 +++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/backend/pkg/controller/subtask.go b/backend/pkg/controller/subtask.go index 6adab45..ba00229 100644 --- a/backend/pkg/controller/subtask.go +++ b/backend/pkg/controller/subtask.go @@ -5,10 +5,13 @@ import ( "errors" "fmt" "sync" + "time" "pentagi/pkg/database" obs "pentagi/pkg/observability" "pentagi/pkg/providers" + + "github.com/sirupsen/logrus" ) type TaskUpdater interface { @@ -281,6 +284,7 @@ func (stw *subtaskWorker) Run(ctx context.Context) error { } if err := stw.SetStatus(ctx, database.SubtaskStatusRunning); err != nil { + stw.handleInterrupting(err) return err } @@ -291,6 +295,7 @@ func (stw *subtaskWorker) Run(ctx context.Context) error { ) if err := stw.subtaskCtx.Provider.EnsureChainConsistency(ctx, msgChainID); err != nil { + stw.handleInterrupting(err) return fmt.Errorf("failed to ensure chain consistency for subtask %d: %w", subtaskID, err) } @@ -310,14 +315,17 @@ func (stw *subtaskWorker) Run(ctx context.Context) error { switch performResult { case providers.PerformResultWaiting: if err := stw.SetStatus(ctx, database.SubtaskStatusWaiting); err != nil { + stw.handleInterrupting(err) return err } case providers.PerformResultDone: if err := stw.SetStatus(ctx, database.SubtaskStatusFinished); err != nil { + stw.handleInterrupting(err) return fmt.Errorf("failed to set subtask %d status to finished: %w", subtaskID, err) } case providers.PerformResultError: if err := stw.SetStatus(ctx, database.SubtaskStatusFailed); err != nil { + stw.handleInterrupting(err) return fmt.Errorf("failed to set subtask %d status to failed: %w", subtaskID, err) } default: @@ -327,6 +335,29 @@ func (stw *subtaskWorker) Run(ctx context.Context) error { return nil } +// handleInterrupting sets this subtask (and task/flow via SetStatus back-propagation) +// to Waiting when err is context.Canceled or context.DeadlineExceeded. Use after the subtask +// was advanced past Waiting (e.g. Running) but the run aborts before PerformAgentChain's +// normal error handler, or when a late SetStatus fails with a context interruption. +func (stw *subtaskWorker) handleInterrupting(err error) { + if err == nil { + return + } + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + return + } + if stw.IsCompleted() { + return + } + + resetCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if errSt := stw.SetStatus(resetCtx, database.SubtaskStatusWaiting); errSt != nil { + logrus.WithError(errSt).Warn("failed to set subtask waiting after run interrupt") + } +} + func (stw *subtaskWorker) Finish(ctx context.Context) error { if stw.IsCompleted() { return fmt.Errorf("subtask has already completed") diff --git a/backend/pkg/controller/task.go b/backend/pkg/controller/task.go index 311c630..4de1739 100644 --- a/backend/pkg/controller/task.go +++ b/backend/pkg/controller/task.go @@ -5,11 +5,14 @@ import ( "errors" "fmt" "sync" + "time" "pentagi/pkg/database" obs "pentagi/pkg/observability" "pentagi/pkg/providers" "pentagi/pkg/tools" + + "github.com/sirupsen/logrus" ) type FlowUpdater interface { @@ -284,6 +287,7 @@ func (tw *taskWorker) Run(ctx context.Context) error { for len(tw.stc.ListSubtasks(ctx)) < providers.TasksNumberLimit+3 { st, err := tw.stc.PopSubtask(ctx, tw) if err != nil { + tw.handleInterrupting(err) return err } @@ -293,6 +297,7 @@ func (tw *taskWorker) Run(ctx context.Context) error { } if err := st.Run(ctx); err != nil { + tw.handleInterrupting(err) return err } @@ -312,6 +317,7 @@ func (tw *taskWorker) Run(ctx context.Context) error { jobResult, err := tw.taskCtx.Provider.GetTaskResult(ctx, tw.taskCtx.TaskID) if err != nil { + tw.handleInterrupting(err) return fmt.Errorf("failed to get task %d result: %w", tw.taskCtx.TaskID, err) } @@ -323,10 +329,12 @@ func (tw *taskWorker) Run(ctx context.Context) error { } if err := tw.SetResult(ctx, jobResult.Result); err != nil { + tw.handleInterrupting(err) return err } if err := tw.SetStatus(ctx, taskStatus); err != nil { + tw.handleInterrupting(err) return err } @@ -341,12 +349,35 @@ func (tw *taskWorker) Run(ctx context.Context) error { format, ) if err != nil { + tw.handleInterrupting(err) return fmt.Errorf("failed to put report for task %d: %w", tw.taskCtx.TaskID, err) } return nil } +// handleInterrupting sets this task (and the flow via taskWorker.SetStatus) +// to Waiting when err is context.Canceled or context.DeadlineExceeded. Skips if the task is +// already marked completed in memory (Finished/Failed) so we do not revive a finished task. +func (tw *taskWorker) handleInterrupting(err error) { + if err == nil { + return + } + if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { + return + } + if tw.IsCompleted() { + return + } + + resetCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if errSt := tw.SetStatus(resetCtx, database.TaskStatusWaiting); errSt != nil { + logrus.WithError(errSt).Warn("failed to set task waiting after run interrupt") + } +} + func (tw *taskWorker) Finish(ctx context.Context) error { if tw.IsCompleted() { return fmt.Errorf("task has already completed")