Files

288 lines
10 KiB
Go

package syncapi
import (
"context"
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
"time"
"connectrpc.com/connect"
"github.com/garethgeorge/backrest/gen/go/v1sync"
"github.com/garethgeorge/backrest/internal/cryptoutil"
"go.uber.org/zap"
)
// syncCommandStreamTrait is the minimal send/receive surface of a
// bidirectional stream of SyncStreamItems. Both connect.BidiStream (server
// side) and connect.BidiStreamForClient (client side) satisfy it, so the
// handshake and pump code can run unchanged on either end of a connection.
type syncCommandStreamTrait interface {
	// Receive blocks until the next item arrives from the peer or the
	// underlying stream fails.
	Receive() (*v1sync.SyncStreamItem, error)
	// Send writes a single item to the peer.
	Send(item *v1sync.SyncStreamItem) error
}
// Compile-time checks that both ends of a connect bidi stream of
// SyncStreamItems fit syncCommandStreamTrait.
var (
	_ syncCommandStreamTrait = (*connect.BidiStream[v1sync.SyncStreamItem, v1sync.SyncStreamItem])(nil)
	_ syncCommandStreamTrait = (*connect.BidiStreamForClient[v1sync.SyncStreamItem, v1sync.SyncStreamItem])(nil)
)
// bidiSyncCommandStream decouples the sync logic from the wire: callers
// exchange SyncStreamItems through channels while ConnectStream moves them to
// and from an encrypted transport. Construct it with newBidiSyncCommandStream;
// the zero value has nil channels and is not usable.
type bidiSyncCommandStream struct {
	// sendChan holds outbound items queued by Send and drained by
	// ConnectStream; recvChan carries inbound items produced by
	// ConnectStream's receive goroutine.
	sendChan, recvChan chan *v1sync.SyncStreamItem

	// Termination state. done is closed exactly once (guarded by doneOnce)
	// when the stream is terminated, so readers can select on it; the cause,
	// if there was one, is kept in terminateErr.
	done         chan struct{}
	doneOnce     sync.Once
	terminateErr atomic.Pointer[error]

	// transcript is the post-quantum transport transcript that runSync signs
	// and verifies in its handshake. ConnectStream writes it once, after the
	// KEM exchange succeeds, and then closes transcriptReady to publish it.
	// Read it only through AwaitTranscript.
	transcript      []byte
	transcriptReady chan struct{}
}
// newBidiSyncCommandStream returns a stream that is ready to be passed to
// ConnectStream: all channels are allocated, nothing is terminated, and no
// transcript has been published yet.
func newBidiSyncCommandStream() *bidiSyncCommandStream {
	const (
		// Room for outbound items; Send waits at most 100ms for space
		// before declaring the peer stalled.
		outboundBuffer = 256
		// Inbound items are handed over with a single slot of slack.
		inboundBuffer = 1
	)

	s := &bidiSyncCommandStream{}
	s.sendChan = make(chan *v1sync.SyncStreamItem, outboundBuffer)
	s.recvChan = make(chan *v1sync.SyncStreamItem, inboundBuffer)
	s.done = make(chan struct{})
	s.transcriptReady = make(chan struct{})
	return s
}
// AwaitTranscript blocks until the transport-layer handshake completes and
// returns the transcript bytes that the higher-layer identity exchange must
// sign.
//
// A transcript that has already been published is always returned, even if
// the stream has since been terminated or ctx is done. This keeps the result
// independent of select's random choice among ready cases. Otherwise it
// returns ctx.Err() if ctx is cancelled, or the stream's termination error if
// the connection failed before the transcript could be published.
func (s *bidiSyncCommandStream) AwaitTranscript(ctx context.Context) ([]byte, error) {
	// Fast path: the transcript is already available.
	select {
	case <-s.transcriptReady:
		return s.transcript, nil
	default:
	}

	select {
	case <-s.transcriptReady:
		return s.transcript, nil
	case <-s.done:
		// Termination may have raced with publication; if the transcript
		// made it out first, report it rather than the termination.
		select {
		case <-s.transcriptReady:
			return s.transcript, nil
		default:
		}
		if err := s.Err(); err != nil {
			return nil, err
		}
		return nil, errors.New("stream terminated before transport transcript was available")
	case <-ctx.Done():
		return nil, ctx.Err()
	}
}
// Send queues item for delivery by the ConnectStream send pump.
//
// If the outbound buffer is full, Send waits up to 100ms for room. If no room
// appears in that time, the peer is presumed stalled and the stream is
// terminated with a disconnected error. Once the stream has been terminated,
// nothing drains the buffer any more. An item that does not fit is then
// dropped immediately instead of waiting out the timeout.
func (s *bidiSyncCommandStream) Send(item *v1sync.SyncStreamItem) {
	// Fast path: there is room in the buffer.
	select {
	case s.sendChan <- item:
		return
	default:
	}

	timer := time.NewTimer(100 * time.Millisecond)
	defer timer.Stop()
	select {
	case s.sendChan <- item:
	case <-s.done:
		// Stream already terminated; the send pump has stopped, so waiting
		// for buffer space would only delay the caller.
	case <-timer.C:
		s.SendErrorAndTerminate(NewSyncErrorDisconnected(errors.New("send channel is full, cannot send item")))
	}
}
// SendErrorAndTerminate marks the stream as terminated by closing the channel
// returned by Done. Only the first call has any effect: its err, when non-nil,
// becomes the value reported by Err, and every later call is ignored. It never
// blocks and may be called from any goroutine.
func (s *bidiSyncCommandStream) SendErrorAndTerminate(err error) {
	terminate := func() {
		// Record the cause before closing done so that anyone woken by the
		// close observes it through Err.
		if err != nil {
			s.terminateErr.Store(&err)
		}
		close(s.done)
	}
	s.doneOnce.Do(terminate)
}
// Err reports why the stream was terminated. It returns nil both while the
// stream is still live and when it was terminated without a cause.
func (s *bidiSyncCommandStream) Err() error {
	stored := s.terminateErr.Load()
	if stored == nil {
		return nil
	}
	return *stored
}
// Done returns a channel that becomes closed once the stream has been
// terminated, for use in select statements. Call Err afterwards to learn the
// cause.
func (s *bidiSyncCommandStream) Done() <-chan struct{} {
	return s.done
}
// ReadChannel returns the channel on which items received from the peer are
// delivered. Once ConnectStream has started its receive goroutine, that
// goroutine closes the channel when it stops. A receive that reports !ok
// therefore means no further items will arrive; consult Err for the reason.
func (s *bidiSyncCommandStream) ReadChannel() chan *v1sync.SyncStreamItem {
	return s.recvChan
}
// ReceiveWithinDuration waits at most d for the next item from the peer.
//
// Possible outcomes:
//   - an item arrived: (item, nil)
//   - ctx was cancelled: (nil, ctx.Err())
//   - the stream was terminated, or ReadChannel was closed: (nil, s.Err()),
//     where the error is nil if the stream ended without a cause
//   - d elapsed first: (nil, context.DeadlineExceeded)
func (s *bidiSyncCommandStream) ReceiveWithinDuration(ctx context.Context, d time.Duration) (*v1sync.SyncStreamItem, error) {
	deadline := time.NewTimer(d)
	defer deadline.Stop()

	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-deadline.C:
		return nil, context.DeadlineExceeded
	case <-s.done:
		return nil, s.Err()
	case next, open := <-s.recvChan:
		if open {
			return next, nil
		}
		return nil, s.Err()
	}
}
// ConnectStream bridges the channel-based bidiSyncCommandStream to a real
// transport and blocks until the bridge stops.
//
// It first performs a post-quantum KEM handshake on the raw transport to
// establish an encrypted session. It then publishes the resulting transcript
// for AwaitTranscript and pumps items in both directions over the encrypted
// channel. Items queued via Send are written to the transport; items read from
// the transport are delivered on ReadChannel. isInitiator must be true on the
// side that opens the connection (the client) and false on the side that
// accepts it (the server).
//
// Whatever stops the bridge, the stream is marked terminated before
// ConnectStream returns. The possible causes are a handshake failure, a
// transport send or receive error, SendErrorAndTerminate, or cancellation of
// ctx. Goroutines parked in ReceiveWithinDuration, AwaitTranscript or Send
// therefore wake up promptly and see the cause via Err. The returned error is
// that cause (nil only if the stream was terminated without one).
func (s *bidiSyncCommandStream) ConnectStream(ctx context.Context, stream syncCommandStreamTrait, isInitiator bool) error {
	ctx, cancel := context.WithCancel(ctx)
	// Cancelling on return releases the receive goroutine if it is blocked
	// handing an item to recvChan.
	defer cancel()

	// Perform the PQ KEM handshake on the raw transport before starting the pump.
	// NOTE(review): establishEncryption blocks on the transport and does not
	// observe ctx; cancellation only takes effect once the handshake finishes.
	transport, transcript, err := establishEncryption(stream, isInitiator)
	if err != nil {
		// Signal termination so any goroutine parked in ReceiveWithinDuration
		// or AwaitTranscript (e.g. runSync waiting for the handshake reply)
		// wakes up immediately instead of waiting out its full timeout.
		s.SendErrorAndTerminate(err)
		return err
	}

	// Publish the transcript so runSync can sign / verify the handshake packet.
	// Closing transcriptReady makes the write visible to AwaitTranscript.
	s.transcript = transcript
	close(s.transcriptReady)

	// Receive pump: forwards decrypted items to recvChan and closes it on exit.
	// The stream is always terminated before the close, so a reader that sees
	// recvChan closed also sees the cause through Err.
	go func() {
		defer close(s.recvChan)
		for {
			val, err := transport.Receive()
			if err != nil {
				s.SendErrorAndTerminate(NewSyncErrorDisconnected(fmt.Errorf("receiving item: %w", err)))
				return
			}
			select {
			case s.recvChan <- val:
			case <-ctx.Done():
				// No-op if the send loop already terminated the stream;
				// otherwise the parent ctx was cancelled and that is the cause.
				s.SendErrorAndTerminate(ctx.Err())
				return
			}
		}
	}()

	// Send pump: runs on this goroutine until the stream is terminated, the
	// transport rejects a write, or ctx is cancelled.
	for {
		select {
		case item := <-s.sendChan:
			if item == nil {
				continue
			}
			if err := transport.Send(item); err != nil {
				if errors.Is(err, io.EOF) {
					err = fmt.Errorf("connection failed or dropped: %w", err)
				}
				s.SendErrorAndTerminate(err)
				return err
			}
		case <-s.done:
			return s.Err()
		case <-ctx.Done():
			// Terminate explicitly. Without this, done would stay open while
			// the pumps stop. Readers would then see ReadChannel close with a
			// nil Err, and Send would keep filling a buffer nobody drains.
			err := ctx.Err()
			s.SendErrorAndTerminate(err)
			return err
		}
	}
}
// establishEncryption performs a post-quantum KEM handshake on the raw
// transport and returns two things: an encrypted stream wrapper, and the
// transport transcript. runSync uses the transcript to bind its ed25519
// handshake signature to this specific KEM exchange. The KEM ciphersuite is
// hard-pinned to cryptoutil.TransportProtocolVersion and tied to the wire
// format; a peer advertising any other version is rejected.
//
// Flow: the initiator generates an ephemeral hybrid (ML-KEM-1024 + ECDH-P384)
// HPKE keypair and sends its public key. The responder encapsulates against
// it and replies with the encapsulation. Both sides derive AES-256-GCM
// per-direction session keys via the HPKE Export interface and compute the
// same transcript hash. The application-layer identity handshake (signature
// over the transcript) runs over the encrypted channel afterward.
//
// Errors are classified as follows:
//   - NewSyncErrorDisconnected: a transport send or receive failure.
//   - NewSyncErrorProtocol: an unexpected peer message, a version mismatch, a
//     missing key or encapsulation, or a failed encapsulation/decapsulation.
//   - NewSyncErrorInternal: a failure to generate the initiator's ephemeral
//     key.
//
// isInitiator must be true on the connecting side (client) and false on the
// accepting side (server).
func establishEncryption(stream syncCommandStreamTrait, isInitiator bool) (syncCommandStreamTrait, []byte, error) {
	if isInitiator {
		recipient, pubBytes, err := cryptoutil.NewTransportRecipient()
		if err != nil {
			return nil, nil, NewSyncErrorInternal(fmt.Errorf("generating ephemeral KEM key: %w", err))
		}
		if err := sendKEMExchange(stream, pubBytes, nil); err != nil {
			return nil, nil, NewSyncErrorDisconnected(fmt.Errorf("sending KEM public key: %w", err))
		}
		peerSecret, err := receiveKEMExchange(stream, "KEM encapsulation")
		if err != nil {
			return nil, nil, err
		}
		if len(peerSecret.GetKemEncapsulation()) == 0 {
			return nil, nil, NewSyncErrorProtocol(errors.New("responder did not send KEM encapsulation"))
		}
		sess, err := recipient.Decapsulate(peerSecret.GetKemEncapsulation())
		if err != nil {
			return nil, nil, NewSyncErrorProtocol(fmt.Errorf("decapsulating KEM: %w", err))
		}
		zap.L().Info("encrypted sync session established (initiator)")
		return newEncryptedStream(stream, sess.Send, sess.Recv), sess.Transcript(), nil
	}

	// Responder: wait for the initiator's public key, encapsulate to it, and
	// reply with the encapsulation.
	peerSecret, err := receiveKEMExchange(stream, "KEM public key")
	if err != nil {
		return nil, nil, err
	}
	if len(peerSecret.GetKemPublicKey()) == 0 {
		return nil, nil, NewSyncErrorProtocol(errors.New("initiator did not send KEM public key"))
	}
	enc, sess, err := cryptoutil.EncapsulateToTransport(peerSecret.GetKemPublicKey())
	if err != nil {
		return nil, nil, NewSyncErrorProtocol(fmt.Errorf("encapsulating to KEM public key: %w", err))
	}
	if err := sendKEMExchange(stream, nil, enc); err != nil {
		return nil, nil, NewSyncErrorDisconnected(fmt.Errorf("sending KEM encapsulation: %w", err))
	}
	zap.L().Info("encrypted sync session established (responder)")
	return newEncryptedStream(stream, sess.Send, sess.Recv), sess.Transcript(), nil
}

// sendKEMExchange sends one EstablishSharedSecret message stamped with
// cryptoutil.TransportProtocolVersion. The initiator passes its public key and
// a nil encapsulation; the responder does the reverse. A nil bytes field is
// left unset on the wire. The raw transport error is returned unwrapped so the
// caller can add context.
func sendKEMExchange(stream syncCommandStreamTrait, publicKey, encapsulation []byte) error {
	return stream.Send(&v1sync.SyncStreamItem{
		Action: &v1sync.SyncStreamItem_EstablishSharedSecret{
			EstablishSharedSecret: &v1sync.SyncStreamItem_SyncEstablishSharedSecret{
				ProtocolVersion:  cryptoutil.TransportProtocolVersion,
				KemPublicKey:     publicKey,
				KemEncapsulation: encapsulation,
			},
		},
	})
}

// receiveKEMExchange reads the peer's next message. It checks that the message
// is an EstablishSharedSecret and carries the pinned transport protocol
// version. expecting names the payload the caller is waiting for; it is used
// only in the disconnect error message. Returned errors are already
// classified (disconnected or protocol).
func receiveKEMExchange(stream syncCommandStreamTrait, expecting string) (*v1sync.SyncStreamItem_SyncEstablishSharedSecret, error) {
	peerMsg, err := stream.Receive()
	if err != nil {
		return nil, NewSyncErrorDisconnected(fmt.Errorf("receiving %s: %w", expecting, err))
	}
	peerSecret := peerMsg.GetEstablishSharedSecret()
	if peerSecret == nil {
		return nil, NewSyncErrorProtocol(fmt.Errorf("expected KEM key exchange, got %T", peerMsg.GetAction()))
	}
	if version := peerSecret.GetProtocolVersion(); version != cryptoutil.TransportProtocolVersion {
		return nil, NewSyncErrorProtocol(fmt.Errorf("unsupported transport protocol version %d (this build requires v%d, post-quantum)", version, cryptoutil.TransportProtocolVersion))
	}
	return peerSecret, nil
}