mirror of
https://github.com/vxcontrol/pentagi.git
synced 2026-05-03 21:40:32 +00:00
feat: enhance task and subtask handling with interruption management
This commit is contained in:
@@ -5,10 +5,13 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"pentagi/pkg/database"
|
||||
obs "pentagi/pkg/observability"
|
||||
"pentagi/pkg/providers"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type TaskUpdater interface {
|
||||
@@ -281,6 +284,7 @@ func (stw *subtaskWorker) Run(ctx context.Context) error {
|
||||
}
|
||||
|
||||
if err := stw.SetStatus(ctx, database.SubtaskStatusRunning); err != nil {
|
||||
stw.handleInterrupting(err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -291,6 +295,7 @@ func (stw *subtaskWorker) Run(ctx context.Context) error {
|
||||
)
|
||||
|
||||
if err := stw.subtaskCtx.Provider.EnsureChainConsistency(ctx, msgChainID); err != nil {
|
||||
stw.handleInterrupting(err)
|
||||
return fmt.Errorf("failed to ensure chain consistency for subtask %d: %w", subtaskID, err)
|
||||
}
|
||||
|
||||
@@ -310,14 +315,17 @@ func (stw *subtaskWorker) Run(ctx context.Context) error {
|
||||
switch performResult {
|
||||
case providers.PerformResultWaiting:
|
||||
if err := stw.SetStatus(ctx, database.SubtaskStatusWaiting); err != nil {
|
||||
stw.handleInterrupting(err)
|
||||
return err
|
||||
}
|
||||
case providers.PerformResultDone:
|
||||
if err := stw.SetStatus(ctx, database.SubtaskStatusFinished); err != nil {
|
||||
stw.handleInterrupting(err)
|
||||
return fmt.Errorf("failed to set subtask %d status to finished: %w", subtaskID, err)
|
||||
}
|
||||
case providers.PerformResultError:
|
||||
if err := stw.SetStatus(ctx, database.SubtaskStatusFailed); err != nil {
|
||||
stw.handleInterrupting(err)
|
||||
return fmt.Errorf("failed to set subtask %d status to failed: %w", subtaskID, err)
|
||||
}
|
||||
default:
|
||||
@@ -327,6 +335,29 @@ func (stw *subtaskWorker) Run(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInterrupting sets this subtask (and task/flow via SetStatus back-propagation)
|
||||
// to Waiting when err is context.Canceled or context.DeadlineExceeded. Use after the subtask
|
||||
// was advanced past Waiting (e.g. Running) but the run aborts before PerformAgentChain's
|
||||
// normal error handler, or when a late SetStatus fails with a context interruption.
|
||||
func (stw *subtaskWorker) handleInterrupting(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
return
|
||||
}
|
||||
if stw.IsCompleted() {
|
||||
return
|
||||
}
|
||||
|
||||
resetCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if errSt := stw.SetStatus(resetCtx, database.SubtaskStatusWaiting); errSt != nil {
|
||||
logrus.WithError(errSt).Warn("failed to set subtask waiting after run interrupt")
|
||||
}
|
||||
}
|
||||
|
||||
func (stw *subtaskWorker) Finish(ctx context.Context) error {
|
||||
if stw.IsCompleted() {
|
||||
return fmt.Errorf("subtask has already completed")
|
||||
|
||||
@@ -5,11 +5,14 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"pentagi/pkg/database"
|
||||
obs "pentagi/pkg/observability"
|
||||
"pentagi/pkg/providers"
|
||||
"pentagi/pkg/tools"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
type FlowUpdater interface {
|
||||
@@ -284,6 +287,7 @@ func (tw *taskWorker) Run(ctx context.Context) error {
|
||||
for len(tw.stc.ListSubtasks(ctx)) < providers.TasksNumberLimit+3 {
|
||||
st, err := tw.stc.PopSubtask(ctx, tw)
|
||||
if err != nil {
|
||||
tw.handleInterrupting(err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -293,6 +297,7 @@ func (tw *taskWorker) Run(ctx context.Context) error {
|
||||
}
|
||||
|
||||
if err := st.Run(ctx); err != nil {
|
||||
tw.handleInterrupting(err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -312,6 +317,7 @@ func (tw *taskWorker) Run(ctx context.Context) error {
|
||||
|
||||
jobResult, err := tw.taskCtx.Provider.GetTaskResult(ctx, tw.taskCtx.TaskID)
|
||||
if err != nil {
|
||||
tw.handleInterrupting(err)
|
||||
return fmt.Errorf("failed to get task %d result: %w", tw.taskCtx.TaskID, err)
|
||||
}
|
||||
|
||||
@@ -323,10 +329,12 @@ func (tw *taskWorker) Run(ctx context.Context) error {
|
||||
}
|
||||
|
||||
if err := tw.SetResult(ctx, jobResult.Result); err != nil {
|
||||
tw.handleInterrupting(err)
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tw.SetStatus(ctx, taskStatus); err != nil {
|
||||
tw.handleInterrupting(err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -341,12 +349,35 @@ func (tw *taskWorker) Run(ctx context.Context) error {
|
||||
format,
|
||||
)
|
||||
if err != nil {
|
||||
tw.handleInterrupting(err)
|
||||
return fmt.Errorf("failed to put report for task %d: %w", tw.taskCtx.TaskID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleInterrupting sets this task (and the flow via taskWorker.SetStatus)
|
||||
// to Waiting when err is context.Canceled or context.DeadlineExceeded. Skips if the task is
|
||||
// already marked completed in memory (Finished/Failed) so we do not revive a finished task.
|
||||
func (tw *taskWorker) handleInterrupting(err error) {
|
||||
if err == nil {
|
||||
return
|
||||
}
|
||||
if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) {
|
||||
return
|
||||
}
|
||||
if tw.IsCompleted() {
|
||||
return
|
||||
}
|
||||
|
||||
resetCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if errSt := tw.SetStatus(resetCtx, database.TaskStatusWaiting); errSt != nil {
|
||||
logrus.WithError(errSt).Warn("failed to set task waiting after run interrupt")
|
||||
}
|
||||
}
|
||||
|
||||
func (tw *taskWorker) Finish(ctx context.Context) error {
|
||||
if tw.IsCompleted() {
|
||||
return fmt.Errorf("task has already completed")
|
||||
|
||||
Reference in New Issue
Block a user