cancellation behavior fixes

This commit is contained in:
Gareth George
2026-05-03 17:46:07 -07:00
parent db53b114cb
commit 08ebdfb58c
2 changed files with 67 additions and 18 deletions
+62 -16
View File
@@ -6,6 +6,8 @@ import (
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
"connectrpc.com/connect"
@@ -23,16 +25,22 @@ var _ syncCommandStreamTrait = (*connect.BidiStream[v1sync.SyncStreamItem, v1syn
var _ syncCommandStreamTrait = (*connect.BidiStreamForClient[v1sync.SyncStreamItem, v1sync.SyncStreamItem])(nil)
type bidiSyncCommandStream struct {
sendChan chan *v1sync.SyncStreamItem
recvChan chan *v1sync.SyncStreamItem
terminateWithErrChan chan error
sendChan chan *v1sync.SyncStreamItem
recvChan chan *v1sync.SyncStreamItem
// done is closed exactly once when the stream is terminated. Readers can
// observe termination by selecting on it; the cause (if any) is stored in
// terminateErr.
done chan struct{}
doneOnce sync.Once
terminateErr atomic.Pointer[error]
}
func newBidiSyncCommandStream() *bidiSyncCommandStream {
return &bidiSyncCommandStream{
sendChan: make(chan *v1sync.SyncStreamItem, 256),
recvChan: make(chan *v1sync.SyncStreamItem, 1),
terminateWithErrChan: make(chan error, 1),
sendChan: make(chan *v1sync.SyncStreamItem, 256),
recvChan: make(chan *v1sync.SyncStreamItem, 1),
done: make(chan struct{}),
}
}
@@ -48,23 +56,57 @@ func (s *bidiSyncCommandStream) Send(item *v1sync.SyncStreamItem) {
}
}
// SendErrorAndTerminate marks the stream as terminated. The first call wins:
// its err (if non-nil) is the one returned by Err. Subsequent calls are no-ops.
// Safe to call from any goroutine; non-blocking.
func (s *bidiSyncCommandStream) SendErrorAndTerminate(err error) {
select {
case s.terminateWithErrChan <- err:
default:
s.doneOnce.Do(func() {
if err != nil {
errCopy := err
s.terminateErr.Store(&errCopy)
}
close(s.done)
})
}
// Err returns the termination error, or nil if the stream has not been
// terminated or was terminated without an error.
func (s *bidiSyncCommandStream) Err() error {
if errPtr := s.terminateErr.Load(); errPtr != nil {
return *errPtr
}
return nil
}
// Done returns a channel that is closed when the stream is terminated.
func (s *bidiSyncCommandStream) Done() <-chan struct{} {
return s.done
}
func (s *bidiSyncCommandStream) ReadChannel() chan *v1sync.SyncStreamItem {
return s.recvChan
}
func (s *bidiSyncCommandStream) ReceiveWithinDuration(d time.Duration) *v1sync.SyncStreamItem {
// ReceiveWithinDuration waits up to d for the next stream item. The returned
// error explains why no item arrived: ctx.Err() if ctx is cancelled, the
// stream's termination error (which may itself be nil) if the stream is
// terminated, or context.DeadlineExceeded if d elapses first. A nil item with
// a nil error means the stream was terminated cleanly with no cause.
func (s *bidiSyncCommandStream) ReceiveWithinDuration(ctx context.Context, d time.Duration) (*v1sync.SyncStreamItem, error) {
timer := time.NewTimer(d)
defer timer.Stop()
select {
case item := <-s.recvChan:
return item
case <-time.After(d):
return nil
case item, ok := <-s.recvChan:
if !ok {
return nil, s.Err()
}
return item, nil
case <-s.done:
return nil, s.Err()
case <-ctx.Done():
return nil, ctx.Err()
case <-timer.C:
return nil, context.DeadlineExceeded
}
}
@@ -78,6 +120,10 @@ func (s *bidiSyncCommandStream) ConnectStream(ctx context.Context, stream syncCo
// Perform ECDH key exchange on the raw transport before starting the pump.
transport, err := establishEncryption(stream)
if err != nil {
// Signal termination so any goroutine parked in ReceiveWithinDuration
// (e.g. runSync waiting for the handshake reply) wakes up immediately
// instead of waiting out its full timeout.
s.SendErrorAndTerminate(err)
return err
}
@@ -110,8 +156,8 @@ func (s *bidiSyncCommandStream) ConnectStream(ctx context.Context, stream syncCo
s.SendErrorAndTerminate(err)
return err
}
case err := <-s.terminateWithErrChan:
return err
case <-s.done:
return s.Err()
case <-ctx.Done():
return ctx.Err()
}
+5 -2
View File
@@ -46,9 +46,12 @@ func runSync(
commandStream.Send(handshakePacket)
// Wait for the handshake packet to be acknowledged by the peer.
handshake := commandStream.ReceiveWithinDuration(15 * time.Second)
handshake, err := commandStream.ReceiveWithinDuration(ctx, 15*time.Second)
if err != nil {
return NewSyncErrorAuth(fmt.Errorf("waiting for handshake packet from peer: %w", err))
}
if handshake == nil {
return NewSyncErrorAuth(fmt.Errorf("no handshake packet received from peer within timeout"))
return NewSyncErrorAuth(fmt.Errorf("no handshake packet received from peer"))
}
if _, err := verifyHandshakePacket(handshake); err != nil {
return NewSyncErrorAuth(fmt.Errorf("verifying handshake packet: %w", err))