fix: implement an encryption layer around the syncapi transport

This commit is contained in:
Gareth George
2026-05-02 23:01:21 -07:00
parent b98b1dc3cc
commit 9f5c75431d
9 changed files with 836 additions and 149 deletions
+70 -13
View File
@@ -1,6 +1,7 @@
package syncapi
import (
"bytes"
"context"
"errors"
"fmt"
@@ -9,6 +10,8 @@ import (
"connectrpc.com/connect"
"github.com/garethgeorge/backrest/gen/go/v1sync"
"github.com/garethgeorge/backrest/internal/cryptoutil"
"go.uber.org/zap"
)
type syncCommandStreamTrait interface {
@@ -16,8 +19,8 @@ type syncCommandStreamTrait interface {
Receive() (*v1sync.SyncStreamItem, error)
}
var _ syncCommandStreamTrait = (*connect.BidiStream[v1sync.SyncStreamItem, v1sync.SyncStreamItem])(nil) // Ensure that connect.BidiStream implements syncCommandStreamTrait
var _ syncCommandStreamTrait = (*connect.BidiStreamForClient[v1sync.SyncStreamItem, v1sync.SyncStreamItem])(nil) // Ensure that connect.BidiStreamForClient implements syncCommandStreamTrait
var _ syncCommandStreamTrait = (*connect.BidiStream[v1sync.SyncStreamItem, v1sync.SyncStreamItem])(nil)
var _ syncCommandStreamTrait = (*connect.BidiStreamForClient[v1sync.SyncStreamItem, v1sync.SyncStreamItem])(nil)
type bidiSyncCommandStream struct {
sendChan chan *v1sync.SyncStreamItem
@@ -27,7 +30,7 @@ type bidiSyncCommandStream struct {
func newBidiSyncCommandStream() *bidiSyncCommandStream {
return &bidiSyncCommandStream{
sendChan: make(chan *v1sync.SyncStreamItem, 256), // Buffered channel to allow sending items without blocking
sendChan: make(chan *v1sync.SyncStreamItem, 256),
recvChan: make(chan *v1sync.SyncStreamItem, 1),
terminateWithErrChan: make(chan error, 1),
}
@@ -37,7 +40,6 @@ func (s *bidiSyncCommandStream) Send(item *v1sync.SyncStreamItem) {
select {
case s.sendChan <- item:
default:
// Try again with a timeout, if it fails, send an error to terminate the stream
select {
case s.sendChan <- item:
case <-time.After(100 * time.Millisecond):
@@ -46,14 +48,10 @@ func (s *bidiSyncCommandStream) Send(item *v1sync.SyncStreamItem) {
}
}
// SendErrorAndTerminate sends an error to the termination channel.
// If the error is nil, it terminates only.
func (s *bidiSyncCommandStream) SendErrorAndTerminate(err error) {
select {
case s.terminateWithErrChan <- err:
default:
// If the channel is full, we can't send the error, so we just ignore it.
// This is a best-effort termination.
}
}
@@ -66,17 +64,27 @@ func (s *bidiSyncCommandStream) ReceiveWithinDuration(d time.Duration) *v1sync.S
case item := <-s.recvChan:
return item
case <-time.After(d):
return nil // Return nil if no item is received within the duration
return nil
}
}
// ConnectStream bridges the channel-based bidiSyncCommandStream to a real transport.
// It first performs an ECDH key exchange on the raw transport to establish an encrypted
// session, then starts the send/recv pump loop over the encrypted channel.
func (s *bidiSyncCommandStream) ConnectStream(ctx context.Context, stream syncCommandStreamTrait) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// Perform ECDH key exchange on the raw transport before starting the pump.
transport, err := establishEncryption(stream)
if err != nil {
return err
}
go func() {
defer close(s.recvChan)
for {
val, err := stream.Receive()
val, err := transport.Receive()
if err != nil {
s.SendErrorAndTerminate(NewSyncErrorDisconnected(fmt.Errorf("receiving item: %w", err)))
return
@@ -95,7 +103,7 @@ func (s *bidiSyncCommandStream) ConnectStream(ctx context.Context, stream syncCo
if item == nil {
continue
}
if err := stream.Send(item); err != nil {
if err := transport.Send(item); err != nil {
if errors.Is(err, io.EOF) {
err = fmt.Errorf("connection failed or dropped: %w", err)
}
@@ -103,10 +111,59 @@ func (s *bidiSyncCommandStream) ConnectStream(ctx context.Context, stream syncCo
return err
}
case err := <-s.terminateWithErrChan:
return err // Terminate the stream with the error or nil if no error was sent
return err
case <-ctx.Done():
// Context is done, we should stop processing.
return ctx.Err()
}
}
}
// establishEncryption performs an ECDH key exchange on the raw transport and
// returns an encrypted stream wrapper. Each side generates an ephemeral ECDH P-256
// key pair, exchanges public keys, and derives a shared AES-256-GCM session key.
// The handshake (identity authentication) runs over the encrypted channel afterward.
func establishEncryption(stream syncCommandStreamTrait) (syncCommandStreamTrait, error) {
keyPair, err := cryptoutil.GenerateECDHKeyPair()
if err != nil {
return nil, NewSyncErrorInternal(fmt.Errorf("generating ephemeral ECDH key: %w", err))
}
// Send our ephemeral ECDH public key
if err := stream.Send(&v1sync.SyncStreamItem{
Action: &v1sync.SyncStreamItem_EstablishSharedSecret{
EstablishSharedSecret: &v1sync.SyncStreamItem_SyncEstablishSharedSecret{
EcdhPublicKey: keyPair.Public.Bytes(),
},
},
}); err != nil {
return nil, NewSyncErrorProtocol(fmt.Errorf("sending ECDH public key: %w", err))
}
// Receive the peer's ephemeral ECDH public key
peerMsg, err := stream.Receive()
if err != nil {
return nil, NewSyncErrorProtocol(fmt.Errorf("receiving ECDH public key: %w", err))
}
peerSecret := peerMsg.GetEstablishSharedSecret()
if peerSecret == nil {
return nil, NewSyncErrorProtocol(fmt.Errorf("expected ECDH key exchange, got %T", peerMsg.GetAction()))
}
peerECDHPub, err := cryptoutil.ParseECDHPublicKey(peerSecret.GetEcdhPublicKey())
if err != nil {
return nil, NewSyncErrorProtocol(fmt.Errorf("parsing peer ECDH public key: %w", err))
}
// Derive AES-256-GCM session key
gcm, err := cryptoutil.DeriveSessionKey(keyPair.Private, peerECDHPub)
if err != nil {
return nil, NewSyncErrorProtocol(fmt.Errorf("deriving session key: %w", err))
}
// Determine nonce direction: side with smaller public key uses prefix 0x00
localIsSmaller := bytes.Compare(keyPair.Public.Bytes(), peerECDHPub.Bytes()) < 0
zap.L().Info("encrypted sync session established")
return newEncryptedStream(stream, gcm, localIsSmaller), nil
}
+120
View File
@@ -0,0 +1,120 @@
package syncapi
import (
"crypto/cipher"
"encoding/binary"
"fmt"
"sync"
"github.com/garethgeorge/backrest/gen/go/v1sync"
"google.golang.org/protobuf/proto"
)
// encryptedStream wraps a syncCommandStreamTrait with AES-256-GCM encryption.
// Outgoing SyncStreamItems are serialized, encrypted, and sent as SyncActionEncrypted.
// Incoming SyncActionEncrypted messages are decrypted and deserialized back to SyncStreamItems.
//
// To avoid nonce reuse (since both sides share the same key), each direction
// uses a different nonce prefix byte: the side with the lexicographically smaller
// ECDH public key uses prefix 0x00 for sending and expects 0x01 for receiving,
// and vice versa.
type encryptedStream struct {
inner syncCommandStreamTrait
gcm cipher.AEAD
sendPrefix byte
recvPrefix byte
sendMu sync.Mutex
sendCounter uint64
recvMu sync.Mutex
recvCounter uint64
}
func newEncryptedStream(inner syncCommandStreamTrait, gcm cipher.AEAD, localIsSmaller bool) *encryptedStream {
var sendPrefix, recvPrefix byte
if localIsSmaller {
sendPrefix, recvPrefix = 0x00, 0x01
} else {
sendPrefix, recvPrefix = 0x01, 0x00
}
return &encryptedStream{
inner: inner,
gcm: gcm,
sendPrefix: sendPrefix,
recvPrefix: recvPrefix,
}
}
func (s *encryptedStream) Send(item *v1sync.SyncStreamItem) error {
plaintext, err := proto.Marshal(item)
if err != nil {
return fmt.Errorf("marshal for encryption: %w", err)
}
s.sendMu.Lock()
nonce := s.makeNonce(s.sendPrefix, s.sendCounter)
s.sendCounter++
s.sendMu.Unlock()
ciphertext := s.gcm.Seal(nil, nonce, plaintext, nil)
return s.inner.Send(&v1sync.SyncStreamItem{
Action: &v1sync.SyncStreamItem_Encrypted{
Encrypted: &v1sync.SyncStreamItem_SyncActionEncrypted{
Nonce: nonce,
Ciphertext: ciphertext,
},
},
})
}
func (s *encryptedStream) Receive() (*v1sync.SyncStreamItem, error) {
envelope, err := s.inner.Receive()
if err != nil {
return nil, err
}
encrypted := envelope.GetEncrypted()
if encrypted == nil {
return nil, fmt.Errorf("expected encrypted message, got %T", envelope.GetAction())
}
s.recvMu.Lock()
expectedNonce := s.makeNonce(s.recvPrefix, s.recvCounter)
s.recvCounter++
s.recvMu.Unlock()
if len(encrypted.Nonce) != s.gcm.NonceSize() {
return nil, fmt.Errorf("invalid nonce size: got %d, want %d", len(encrypted.Nonce), s.gcm.NonceSize())
}
// Verify nonce matches expected counter to prevent replay/reorder attacks
for i := range expectedNonce {
if expectedNonce[i] != encrypted.Nonce[i] {
return nil, fmt.Errorf("nonce mismatch: possible replay or reorder attack")
}
}
plaintext, err := s.gcm.Open(nil, encrypted.Nonce, encrypted.Ciphertext, nil)
if err != nil {
return nil, fmt.Errorf("decrypt message: %w", err)
}
var inner v1sync.SyncStreamItem
if err := proto.Unmarshal(plaintext, &inner); err != nil {
return nil, fmt.Errorf("unmarshal decrypted message: %w", err)
}
return &inner, nil
}
// makeNonce creates a 12-byte GCM nonce. The first byte is the direction prefix
// (0x00 or 0x01), bytes 1-3 are zero, and bytes 4-11 are the counter in big-endian.
func (s *encryptedStream) makeNonce(prefix byte, counter uint64) []byte {
nonce := make([]byte, s.gcm.NonceSize()) // 12 bytes for GCM
nonce[0] = prefix
binary.BigEndian.PutUint64(nonce[4:], counter)
return nonce
}
+186
View File
@@ -0,0 +1,186 @@
package syncapi
import (
"sync"
"testing"
"github.com/garethgeorge/backrest/gen/go/v1sync"
"github.com/garethgeorge/backrest/internal/cryptoutil"
)
// fakeStream is a pair of in-memory channels simulating a bidirectional transport.
type fakeStream struct {
sendCh chan *v1sync.SyncStreamItem
recvCh chan *v1sync.SyncStreamItem
}
func (f *fakeStream) Send(item *v1sync.SyncStreamItem) error {
f.sendCh <- item
return nil
}
func (f *fakeStream) Receive() (*v1sync.SyncStreamItem, error) {
item := <-f.recvCh
return item, nil
}
// newFakeStreamPair creates two connected fakeStreams (A's send is B's recv and vice versa).
func newFakeStreamPair() (*fakeStream, *fakeStream) {
ab := make(chan *v1sync.SyncStreamItem, 16)
ba := make(chan *v1sync.SyncStreamItem, 16)
return &fakeStream{sendCh: ab, recvCh: ba}, &fakeStream{sendCh: ba, recvCh: ab}
}
func TestEncryptedStream_RoundTrip(t *testing.T) {
alice, err := cryptoutil.GenerateECDHKeyPair()
if err != nil {
t.Fatal(err)
}
bob, err := cryptoutil.GenerateECDHKeyPair()
if err != nil {
t.Fatal(err)
}
gcm, err := cryptoutil.DeriveSessionKey(alice.Private, bob.Public)
if err != nil {
t.Fatal(err)
}
aliceIsSmaller := string(alice.Public.Bytes()) < string(bob.Public.Bytes())
transportA, transportB := newFakeStreamPair()
encA := newEncryptedStream(transportA, gcm, aliceIsSmaller)
encB := newEncryptedStream(transportB, gcm, !aliceIsSmaller)
// Send a heartbeat from A to B
sendItem := &v1sync.SyncStreamItem{
Action: &v1sync.SyncStreamItem_Heartbeat{
Heartbeat: &v1sync.SyncStreamItem_SyncActionHeartbeat{},
},
}
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
if err := encA.Send(sendItem); err != nil {
t.Errorf("send: %v", err)
}
}()
recvItem, err := encB.Receive()
if err != nil {
t.Fatalf("receive: %v", err)
}
wg.Wait()
if recvItem.GetHeartbeat() == nil {
t.Fatalf("expected heartbeat, got %T", recvItem.GetAction())
}
}
func TestEncryptedStream_BidirectionalMultiMessage(t *testing.T) {
alice, _ := cryptoutil.GenerateECDHKeyPair()
bob, _ := cryptoutil.GenerateECDHKeyPair()
gcm, _ := cryptoutil.DeriveSessionKey(alice.Private, bob.Public)
aliceIsSmaller := string(alice.Public.Bytes()) < string(bob.Public.Bytes())
transportA, transportB := newFakeStreamPair()
encA := newEncryptedStream(transportA, gcm, aliceIsSmaller)
encB := newEncryptedStream(transportB, gcm, !aliceIsSmaller)
heartbeat := &v1sync.SyncStreamItem{
Action: &v1sync.SyncStreamItem_Heartbeat{
Heartbeat: &v1sync.SyncStreamItem_SyncActionHeartbeat{},
},
}
// Send 5 messages A→B sequentially, then 5 messages B→A sequentially
var wg sync.WaitGroup
// A→B direction
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 5; i++ {
if err := encA.Send(heartbeat); err != nil {
t.Errorf("A send %d: %v", i, err)
}
}
}()
for i := 0; i < 5; i++ {
if _, err := encB.Receive(); err != nil {
t.Fatalf("B receive %d: %v", i, err)
}
}
wg.Wait()
// B→A direction
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < 5; i++ {
if err := encB.Send(heartbeat); err != nil {
t.Errorf("B send %d: %v", i, err)
}
}
}()
for i := 0; i < 5; i++ {
if _, err := encA.Receive(); err != nil {
t.Fatalf("A receive %d: %v", i, err)
}
}
wg.Wait()
}
func TestEstablishEncryption_Integration(t *testing.T) {
transportA, transportB := newFakeStreamPair()
var encA, encB syncCommandStreamTrait
var errA, errB error
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
encA, errA = establishEncryption(transportA)
}()
go func() {
defer wg.Done()
encB, errB = establishEncryption(transportB)
}()
wg.Wait()
if errA != nil {
t.Fatalf("establish A: %v", errA)
}
if errB != nil {
t.Fatalf("establish B: %v", errB)
}
// Verify encrypted communication works
heartbeat := &v1sync.SyncStreamItem{
Action: &v1sync.SyncStreamItem_Heartbeat{
Heartbeat: &v1sync.SyncStreamItem_SyncActionHeartbeat{},
},
}
wg.Add(1)
go func() {
defer wg.Done()
if err := encA.Send(heartbeat); err != nil {
t.Errorf("send: %v", err)
}
}()
recv, err := encB.Receive()
if err != nil {
t.Fatalf("receive: %v", err)
}
wg.Wait()
if recv.GetHeartbeat() == nil {
t.Fatalf("expected heartbeat, got %T", recv.GetAction())
}
}
+1 -1
View File
@@ -18,7 +18,7 @@ import (
"google.golang.org/protobuf/proto"
)
const SyncProtocolVersion = 1
const SyncProtocolVersion = 2
type BackrestSyncHandler struct {
v1syncconnect.UnimplementedBackrestSyncServiceHandler