From 08ebdfb58cd85a5ec07885e15f7ca254507f9916 Mon Sep 17 00:00:00 2001 From: Gareth George Date: Sun, 3 May 2026 17:46:07 -0700 Subject: [PATCH] cancellation behavior fixes --- internal/api/syncapi/cmdstreamutil.go | 78 +++++++++++++++++++++------ internal/api/syncapi/synccommon.go | 7 ++- 2 files changed, 67 insertions(+), 18 deletions(-) diff --git a/internal/api/syncapi/cmdstreamutil.go b/internal/api/syncapi/cmdstreamutil.go index feb4b5e0..4c893dc5 100644 --- a/internal/api/syncapi/cmdstreamutil.go +++ b/internal/api/syncapi/cmdstreamutil.go @@ -6,6 +6,8 @@ import ( "errors" "fmt" "io" + "sync" + "sync/atomic" "time" "connectrpc.com/connect" @@ -23,16 +25,22 @@ var _ syncCommandStreamTrait = (*connect.BidiStream[v1sync.SyncStreamItem, v1syn var _ syncCommandStreamTrait = (*connect.BidiStreamForClient[v1sync.SyncStreamItem, v1sync.SyncStreamItem])(nil) type bidiSyncCommandStream struct { - sendChan chan *v1sync.SyncStreamItem - recvChan chan *v1sync.SyncStreamItem - terminateWithErrChan chan error + sendChan chan *v1sync.SyncStreamItem + recvChan chan *v1sync.SyncStreamItem + + // done is closed exactly once when the stream is terminated. Readers can + // observe termination by selecting on it; the cause (if any) is stored in + // terminateErr. + done chan struct{} + doneOnce sync.Once + terminateErr atomic.Pointer[error] } func newBidiSyncCommandStream() *bidiSyncCommandStream { return &bidiSyncCommandStream{ - sendChan: make(chan *v1sync.SyncStreamItem, 256), - recvChan: make(chan *v1sync.SyncStreamItem, 1), - terminateWithErrChan: make(chan error, 1), + sendChan: make(chan *v1sync.SyncStreamItem, 256), + recvChan: make(chan *v1sync.SyncStreamItem, 1), + done: make(chan struct{}), } } @@ -48,23 +56,57 @@ func (s *bidiSyncCommandStream) Send(item *v1sync.SyncStreamItem) { } } +// SendErrorAndTerminate marks the stream as terminated. The first call wins: +// its err (if non-nil) is the one returned by Err. Subsequent calls are no-ops. +// Safe to call from any goroutine; non-blocking. func (s *bidiSyncCommandStream) SendErrorAndTerminate(err error) { - select { - case s.terminateWithErrChan <- err: - default: + s.doneOnce.Do(func() { + if err != nil { + errCopy := err + s.terminateErr.Store(&errCopy) + } + close(s.done) + }) +} + +// Err returns the termination error, or nil if the stream has not been +// terminated or was terminated without an error. +func (s *bidiSyncCommandStream) Err() error { + if errPtr := s.terminateErr.Load(); errPtr != nil { + return *errPtr } + return nil +} + +// Done returns a channel that is closed when the stream is terminated. +func (s *bidiSyncCommandStream) Done() <-chan struct{} { + return s.done } func (s *bidiSyncCommandStream) ReadChannel() chan *v1sync.SyncStreamItem { return s.recvChan } -func (s *bidiSyncCommandStream) ReceiveWithinDuration(d time.Duration) *v1sync.SyncStreamItem { +// ReceiveWithinDuration waits up to d for the next stream item. The returned +// error explains why no item arrived: ctx.Err() if ctx is cancelled, the +// stream's termination error (which may itself be nil) if the stream is +// terminated, or context.DeadlineExceeded if d elapses first. A nil item with +// a nil error means the stream was terminated cleanly with no cause. +func (s *bidiSyncCommandStream) ReceiveWithinDuration(ctx context.Context, d time.Duration) (*v1sync.SyncStreamItem, error) { + timer := time.NewTimer(d) + defer timer.Stop() select { - case item := <-s.recvChan: - return item - case <-time.After(d): - return nil + case item, ok := <-s.recvChan: + if !ok { + return nil, s.Err() + } + return item, nil + case <-s.done: + return nil, s.Err() + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + return nil, context.DeadlineExceeded } } @@ -78,6 +120,10 @@ func (s *bidiSyncCommandStream) ConnectStream(ctx context.Context, stream syncCo // Perform ECDH key exchange on the raw transport before starting the pump. transport, err := establishEncryption(stream) if err != nil { + // Signal termination so any goroutine parked in ReceiveWithinDuration + // (e.g. runSync waiting for the handshake reply) wakes up immediately + // instead of waiting out its full timeout. + s.SendErrorAndTerminate(err) return err } @@ -110,8 +156,8 @@ func (s *bidiSyncCommandStream) ConnectStream(ctx context.Context, stream syncCo s.SendErrorAndTerminate(err) return err } - case err := <-s.terminateWithErrChan: - return err + case <-s.done: + return s.Err() case <-ctx.Done(): return ctx.Err() } diff --git a/internal/api/syncapi/synccommon.go b/internal/api/syncapi/synccommon.go index 30974f83..dac52ea2 100644 --- a/internal/api/syncapi/synccommon.go +++ b/internal/api/syncapi/synccommon.go @@ -46,9 +46,12 @@ func runSync( commandStream.Send(handshakePacket) // Wait for the handshake packet to be acknowledged by the peer. - handshake := commandStream.ReceiveWithinDuration(15 * time.Second) + handshake, err := commandStream.ReceiveWithinDuration(ctx, 15*time.Second) + if err != nil { + return NewSyncErrorAuth(fmt.Errorf("waiting for handshake packet from peer: %w", err)) + } if handshake == nil { - return NewSyncErrorAuth(fmt.Errorf("no handshake packet received from peer within timeout")) + return NewSyncErrorAuth(fmt.Errorf("no handshake packet received from peer")) } if _, err := verifyHandshakePacket(handshake); err != nil { return NewSyncErrorAuth(fmt.Errorf("verifying handshake packet: %w", err))