mirror of
https://github.com/garethgeorge/backrest.git
synced 2026-05-04 03:50:30 +00:00
cancellation behavior fixes
This commit is contained in:
@@ -6,6 +6,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
@@ -23,16 +25,22 @@ var _ syncCommandStreamTrait = (*connect.BidiStream[v1sync.SyncStreamItem, v1syn
|
||||
var _ syncCommandStreamTrait = (*connect.BidiStreamForClient[v1sync.SyncStreamItem, v1sync.SyncStreamItem])(nil)
|
||||
|
||||
type bidiSyncCommandStream struct {
|
||||
sendChan chan *v1sync.SyncStreamItem
|
||||
recvChan chan *v1sync.SyncStreamItem
|
||||
terminateWithErrChan chan error
|
||||
sendChan chan *v1sync.SyncStreamItem
|
||||
recvChan chan *v1sync.SyncStreamItem
|
||||
|
||||
// done is closed exactly once when the stream is terminated. Readers can
|
||||
// observe termination by selecting on it; the cause (if any) is stored in
|
||||
// terminateErr.
|
||||
done chan struct{}
|
||||
doneOnce sync.Once
|
||||
terminateErr atomic.Pointer[error]
|
||||
}
|
||||
|
||||
func newBidiSyncCommandStream() *bidiSyncCommandStream {
|
||||
return &bidiSyncCommandStream{
|
||||
sendChan: make(chan *v1sync.SyncStreamItem, 256),
|
||||
recvChan: make(chan *v1sync.SyncStreamItem, 1),
|
||||
terminateWithErrChan: make(chan error, 1),
|
||||
sendChan: make(chan *v1sync.SyncStreamItem, 256),
|
||||
recvChan: make(chan *v1sync.SyncStreamItem, 1),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -48,23 +56,57 @@ func (s *bidiSyncCommandStream) Send(item *v1sync.SyncStreamItem) {
|
||||
}
|
||||
}
|
||||
|
||||
// SendErrorAndTerminate marks the stream as terminated. The first call wins:
|
||||
// its err (if non-nil) is the one returned by Err. Subsequent calls are no-ops.
|
||||
// Safe to call from any goroutine; non-blocking.
|
||||
func (s *bidiSyncCommandStream) SendErrorAndTerminate(err error) {
|
||||
select {
|
||||
case s.terminateWithErrChan <- err:
|
||||
default:
|
||||
s.doneOnce.Do(func() {
|
||||
if err != nil {
|
||||
errCopy := err
|
||||
s.terminateErr.Store(&errCopy)
|
||||
}
|
||||
close(s.done)
|
||||
})
|
||||
}
|
||||
|
||||
// Err returns the termination error, or nil if the stream has not been
|
||||
// terminated or was terminated without an error.
|
||||
func (s *bidiSyncCommandStream) Err() error {
|
||||
if errPtr := s.terminateErr.Load(); errPtr != nil {
|
||||
return *errPtr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Done returns a channel that is closed when the stream is terminated.
|
||||
func (s *bidiSyncCommandStream) Done() <-chan struct{} {
|
||||
return s.done
|
||||
}
|
||||
|
||||
func (s *bidiSyncCommandStream) ReadChannel() chan *v1sync.SyncStreamItem {
|
||||
return s.recvChan
|
||||
}
|
||||
|
||||
func (s *bidiSyncCommandStream) ReceiveWithinDuration(d time.Duration) *v1sync.SyncStreamItem {
|
||||
// ReceiveWithinDuration waits up to d for the next stream item. The returned
|
||||
// error explains why no item arrived: ctx.Err() if ctx is cancelled, the
|
||||
// stream's termination error (which may itself be nil) if the stream is
|
||||
// terminated, or context.DeadlineExceeded if d elapses first. A nil item with
|
||||
// a nil error means the stream was terminated cleanly with no cause.
|
||||
func (s *bidiSyncCommandStream) ReceiveWithinDuration(ctx context.Context, d time.Duration) (*v1sync.SyncStreamItem, error) {
|
||||
timer := time.NewTimer(d)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case item := <-s.recvChan:
|
||||
return item
|
||||
case <-time.After(d):
|
||||
return nil
|
||||
case item, ok := <-s.recvChan:
|
||||
if !ok {
|
||||
return nil, s.Err()
|
||||
}
|
||||
return item, nil
|
||||
case <-s.done:
|
||||
return nil, s.Err()
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil, context.DeadlineExceeded
|
||||
}
|
||||
}
|
||||
|
||||
@@ -78,6 +120,10 @@ func (s *bidiSyncCommandStream) ConnectStream(ctx context.Context, stream syncCo
|
||||
// Perform ECDH key exchange on the raw transport before starting the pump.
|
||||
transport, err := establishEncryption(stream)
|
||||
if err != nil {
|
||||
// Signal termination so any goroutine parked in ReceiveWithinDuration
|
||||
// (e.g. runSync waiting for the handshake reply) wakes up immediately
|
||||
// instead of waiting out its full timeout.
|
||||
s.SendErrorAndTerminate(err)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -110,8 +156,8 @@ func (s *bidiSyncCommandStream) ConnectStream(ctx context.Context, stream syncCo
|
||||
s.SendErrorAndTerminate(err)
|
||||
return err
|
||||
}
|
||||
case err := <-s.terminateWithErrChan:
|
||||
return err
|
||||
case <-s.done:
|
||||
return s.Err()
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
@@ -46,9 +46,12 @@ func runSync(
|
||||
commandStream.Send(handshakePacket)
|
||||
|
||||
// Wait for the handshake packet to be acknowledged by the peer.
|
||||
handshake := commandStream.ReceiveWithinDuration(15 * time.Second)
|
||||
handshake, err := commandStream.ReceiveWithinDuration(ctx, 15*time.Second)
|
||||
if err != nil {
|
||||
return NewSyncErrorAuth(fmt.Errorf("waiting for handshake packet from peer: %w", err))
|
||||
}
|
||||
if handshake == nil {
|
||||
return NewSyncErrorAuth(fmt.Errorf("no handshake packet received from peer within timeout"))
|
||||
return NewSyncErrorAuth(fmt.Errorf("no handshake packet received from peer"))
|
||||
}
|
||||
if _, err := verifyHandshakePacket(handshake); err != nil {
|
||||
return NewSyncErrorAuth(fmt.Errorf("verifying handshake packet: %w", err))
|
||||
|
||||
Reference in New Issue
Block a user