feat: enhance task and subtask handling with interruption management

This commit is contained in:
Dmitry Ng
2026-05-01 16:05:12 +03:00
parent 3da56244dd
commit b6f7fd7ef8
2 changed files with 62 additions and 0 deletions
+31
View File
@@ -5,10 +5,13 @@ import (
"errors"
"fmt"
"sync"
"time"
"pentagi/pkg/database"
obs "pentagi/pkg/observability"
"pentagi/pkg/providers"
"github.com/sirupsen/logrus"
)
type TaskUpdater interface {
@@ -281,6 +284,7 @@ func (stw *subtaskWorker) Run(ctx context.Context) error {
}
if err := stw.SetStatus(ctx, database.SubtaskStatusRunning); err != nil {
stw.handleInterrupting(err)
return err
}
@@ -291,6 +295,7 @@ func (stw *subtaskWorker) Run(ctx context.Context) error {
)
if err := stw.subtaskCtx.Provider.EnsureChainConsistency(ctx, msgChainID); err != nil {
stw.handleInterrupting(err)
return fmt.Errorf("failed to ensure chain consistency for subtask %d: %w", subtaskID, err)
}
@@ -310,14 +315,17 @@ func (stw *subtaskWorker) Run(ctx context.Context) error {
switch performResult {
case providers.PerformResultWaiting:
if err := stw.SetStatus(ctx, database.SubtaskStatusWaiting); err != nil {
stw.handleInterrupting(err)
return err
}
case providers.PerformResultDone:
if err := stw.SetStatus(ctx, database.SubtaskStatusFinished); err != nil {
stw.handleInterrupting(err)
return fmt.Errorf("failed to set subtask %d status to finished: %w", subtaskID, err)
}
case providers.PerformResultError:
if err := stw.SetStatus(ctx, database.SubtaskStatusFailed); err != nil {
stw.handleInterrupting(err)
return fmt.Errorf("failed to set subtask %d status to failed: %w", subtaskID, err)
}
default:
@@ -327,6 +335,29 @@ func (stw *subtaskWorker) Run(ctx context.Context) error {
return nil
}
// handleInterrupting sets this subtask (and task/flow via SetStatus back-propagation)
// to Waiting when err is context.Canceled or context.DeadlineExceeded. Use after the subtask
// was advanced past Waiting (e.g. Running) but the run aborts before PerformAgentChain's
// normal error handler, or when a late SetStatus fails with a context interruption.
func (stw *subtaskWorker) handleInterrupting(err error) {
if err == nil {
return
}
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
return
}
if stw.IsCompleted() {
return
}
resetCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if errSt := stw.SetStatus(resetCtx, database.SubtaskStatusWaiting); errSt != nil {
logrus.WithError(errSt).Warn("failed to set subtask waiting after run interrupt")
}
}
func (stw *subtaskWorker) Finish(ctx context.Context) error {
if stw.IsCompleted() {
return fmt.Errorf("subtask has already completed")
+31
View File
@@ -5,11 +5,14 @@ import (
"errors"
"fmt"
"sync"
"time"
"pentagi/pkg/database"
obs "pentagi/pkg/observability"
"pentagi/pkg/providers"
"pentagi/pkg/tools"
"github.com/sirupsen/logrus"
)
type FlowUpdater interface {
@@ -284,6 +287,7 @@ func (tw *taskWorker) Run(ctx context.Context) error {
for len(tw.stc.ListSubtasks(ctx)) < providers.TasksNumberLimit+3 {
st, err := tw.stc.PopSubtask(ctx, tw)
if err != nil {
tw.handleInterrupting(err)
return err
}
@@ -293,6 +297,7 @@ func (tw *taskWorker) Run(ctx context.Context) error {
}
if err := st.Run(ctx); err != nil {
tw.handleInterrupting(err)
return err
}
@@ -312,6 +317,7 @@ func (tw *taskWorker) Run(ctx context.Context) error {
jobResult, err := tw.taskCtx.Provider.GetTaskResult(ctx, tw.taskCtx.TaskID)
if err != nil {
tw.handleInterrupting(err)
return fmt.Errorf("failed to get task %d result: %w", tw.taskCtx.TaskID, err)
}
@@ -323,10 +329,12 @@ func (tw *taskWorker) Run(ctx context.Context) error {
}
if err := tw.SetResult(ctx, jobResult.Result); err != nil {
tw.handleInterrupting(err)
return err
}
if err := tw.SetStatus(ctx, taskStatus); err != nil {
tw.handleInterrupting(err)
return err
}
@@ -341,12 +349,35 @@ func (tw *taskWorker) Run(ctx context.Context) error {
format,
)
if err != nil {
tw.handleInterrupting(err)
return fmt.Errorf("failed to put report for task %d: %w", tw.taskCtx.TaskID, err)
}
return nil
}
// handleInterrupting sets this task (and the flow via taskWorker.SetStatus)
// to Waiting when err is context.Canceled or context.DeadlineExceeded. Skips if the task is
// already marked completed in memory (Finished/Failed) so we do not revive a finished task.
func (tw *taskWorker) handleInterrupting(err error) {
if err == nil {
return
}
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
return
}
if tw.IsCompleted() {
return
}
resetCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if errSt := tw.SetStatus(resetCtx, database.TaskStatusWaiting); errSt != nil {
logrus.WithError(errSt).Warn("failed to set task waiting after run interrupt")
}
}
func (tw *taskWorker) Finish(ctx context.Context) error {
if tw.IsCompleted() {
return fmt.Errorf("task has already completed")