mirror of
https://github.com/garethgeorge/backrest.git
synced 2026-05-06 04:50:35 +00:00
fix: implement an encryption layer around the syncapi transport
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package syncapi
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
@@ -9,6 +10,8 @@ import (
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/garethgeorge/backrest/gen/go/v1sync"
|
||||
"github.com/garethgeorge/backrest/internal/cryptoutil"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type syncCommandStreamTrait interface {
|
||||
@@ -16,8 +19,8 @@ type syncCommandStreamTrait interface {
|
||||
Receive() (*v1sync.SyncStreamItem, error)
|
||||
}
|
||||
|
||||
var _ syncCommandStreamTrait = (*connect.BidiStream[v1sync.SyncStreamItem, v1sync.SyncStreamItem])(nil) // Ensure that connect.BidiStream implements syncCommandStreamTrait
|
||||
var _ syncCommandStreamTrait = (*connect.BidiStreamForClient[v1sync.SyncStreamItem, v1sync.SyncStreamItem])(nil) // Ensure that connect.BidiStreamForClient implements syncCommandStreamTrait
|
||||
var _ syncCommandStreamTrait = (*connect.BidiStream[v1sync.SyncStreamItem, v1sync.SyncStreamItem])(nil)
|
||||
var _ syncCommandStreamTrait = (*connect.BidiStreamForClient[v1sync.SyncStreamItem, v1sync.SyncStreamItem])(nil)
|
||||
|
||||
type bidiSyncCommandStream struct {
|
||||
sendChan chan *v1sync.SyncStreamItem
|
||||
@@ -27,7 +30,7 @@ type bidiSyncCommandStream struct {
|
||||
|
||||
func newBidiSyncCommandStream() *bidiSyncCommandStream {
|
||||
return &bidiSyncCommandStream{
|
||||
sendChan: make(chan *v1sync.SyncStreamItem, 256), // Buffered channel to allow sending items without blocking
|
||||
sendChan: make(chan *v1sync.SyncStreamItem, 256),
|
||||
recvChan: make(chan *v1sync.SyncStreamItem, 1),
|
||||
terminateWithErrChan: make(chan error, 1),
|
||||
}
|
||||
@@ -37,7 +40,6 @@ func (s *bidiSyncCommandStream) Send(item *v1sync.SyncStreamItem) {
|
||||
select {
|
||||
case s.sendChan <- item:
|
||||
default:
|
||||
// Try again with a timeout, if it fails, send an error to terminate the stream
|
||||
select {
|
||||
case s.sendChan <- item:
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
@@ -46,14 +48,10 @@ func (s *bidiSyncCommandStream) Send(item *v1sync.SyncStreamItem) {
|
||||
}
|
||||
}
|
||||
|
||||
// SendErrorAndTerminate sends an error to the termination channel.
|
||||
// If the error is nil, it terminates only.
|
||||
func (s *bidiSyncCommandStream) SendErrorAndTerminate(err error) {
|
||||
select {
|
||||
case s.terminateWithErrChan <- err:
|
||||
default:
|
||||
// If the channel is full, we can't send the error, so we just ignore it.
|
||||
// This is a best-effort termination.
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,17 +64,27 @@ func (s *bidiSyncCommandStream) ReceiveWithinDuration(d time.Duration) *v1sync.S
|
||||
case item := <-s.recvChan:
|
||||
return item
|
||||
case <-time.After(d):
|
||||
return nil // Return nil if no item is received within the duration
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ConnectStream bridges the channel-based bidiSyncCommandStream to a real transport.
|
||||
// It first performs an ECDH key exchange on the raw transport to establish an encrypted
|
||||
// session, then starts the send/recv pump loop over the encrypted channel.
|
||||
func (s *bidiSyncCommandStream) ConnectStream(ctx context.Context, stream syncCommandStreamTrait) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Perform ECDH key exchange on the raw transport before starting the pump.
|
||||
transport, err := establishEncryption(stream)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer close(s.recvChan)
|
||||
for {
|
||||
val, err := stream.Receive()
|
||||
val, err := transport.Receive()
|
||||
if err != nil {
|
||||
s.SendErrorAndTerminate(NewSyncErrorDisconnected(fmt.Errorf("receiving item: %w", err)))
|
||||
return
|
||||
@@ -95,7 +103,7 @@ func (s *bidiSyncCommandStream) ConnectStream(ctx context.Context, stream syncCo
|
||||
if item == nil {
|
||||
continue
|
||||
}
|
||||
if err := stream.Send(item); err != nil {
|
||||
if err := transport.Send(item); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = fmt.Errorf("connection failed or dropped: %w", err)
|
||||
}
|
||||
@@ -103,10 +111,59 @@ func (s *bidiSyncCommandStream) ConnectStream(ctx context.Context, stream syncCo
|
||||
return err
|
||||
}
|
||||
case err := <-s.terminateWithErrChan:
|
||||
return err // Terminate the stream with the error or nil if no error was sent
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
// Context is done, we should stop processing.
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// establishEncryption performs an ECDH key exchange on the raw transport and
|
||||
// returns an encrypted stream wrapper. Each side generates an ephemeral ECDH P-256
|
||||
// key pair, exchanges public keys, and derives a shared AES-256-GCM session key.
|
||||
// The handshake (identity authentication) runs over the encrypted channel afterward.
|
||||
func establishEncryption(stream syncCommandStreamTrait) (syncCommandStreamTrait, error) {
|
||||
keyPair, err := cryptoutil.GenerateECDHKeyPair()
|
||||
if err != nil {
|
||||
return nil, NewSyncErrorInternal(fmt.Errorf("generating ephemeral ECDH key: %w", err))
|
||||
}
|
||||
|
||||
// Send our ephemeral ECDH public key
|
||||
if err := stream.Send(&v1sync.SyncStreamItem{
|
||||
Action: &v1sync.SyncStreamItem_EstablishSharedSecret{
|
||||
EstablishSharedSecret: &v1sync.SyncStreamItem_SyncEstablishSharedSecret{
|
||||
EcdhPublicKey: keyPair.Public.Bytes(),
|
||||
},
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, NewSyncErrorProtocol(fmt.Errorf("sending ECDH public key: %w", err))
|
||||
}
|
||||
|
||||
// Receive the peer's ephemeral ECDH public key
|
||||
peerMsg, err := stream.Receive()
|
||||
if err != nil {
|
||||
return nil, NewSyncErrorProtocol(fmt.Errorf("receiving ECDH public key: %w", err))
|
||||
}
|
||||
peerSecret := peerMsg.GetEstablishSharedSecret()
|
||||
if peerSecret == nil {
|
||||
return nil, NewSyncErrorProtocol(fmt.Errorf("expected ECDH key exchange, got %T", peerMsg.GetAction()))
|
||||
}
|
||||
|
||||
peerECDHPub, err := cryptoutil.ParseECDHPublicKey(peerSecret.GetEcdhPublicKey())
|
||||
if err != nil {
|
||||
return nil, NewSyncErrorProtocol(fmt.Errorf("parsing peer ECDH public key: %w", err))
|
||||
}
|
||||
|
||||
// Derive AES-256-GCM session key
|
||||
gcm, err := cryptoutil.DeriveSessionKey(keyPair.Private, peerECDHPub)
|
||||
if err != nil {
|
||||
return nil, NewSyncErrorProtocol(fmt.Errorf("deriving session key: %w", err))
|
||||
}
|
||||
|
||||
// Determine nonce direction: side with smaller public key uses prefix 0x00
|
||||
localIsSmaller := bytes.Compare(keyPair.Public.Bytes(), peerECDHPub.Bytes()) < 0
|
||||
|
||||
zap.L().Info("encrypted sync session established")
|
||||
|
||||
return newEncryptedStream(stream, gcm, localIsSmaller), nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
package syncapi
|
||||
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/garethgeorge/backrest/gen/go/v1sync"
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
// encryptedStream wraps a syncCommandStreamTrait with AES-256-GCM encryption.
|
||||
// Outgoing SyncStreamItems are serialized, encrypted, and sent as SyncActionEncrypted.
|
||||
// Incoming SyncActionEncrypted messages are decrypted and deserialized back to SyncStreamItems.
|
||||
//
|
||||
// To avoid nonce reuse (since both sides share the same key), each direction
|
||||
// uses a different nonce prefix byte: the side with the lexicographically smaller
|
||||
// ECDH public key uses prefix 0x00 for sending and expects 0x01 for receiving,
|
||||
// and vice versa.
|
||||
type encryptedStream struct {
|
||||
inner syncCommandStreamTrait
|
||||
gcm cipher.AEAD
|
||||
|
||||
sendPrefix byte
|
||||
recvPrefix byte
|
||||
|
||||
sendMu sync.Mutex
|
||||
sendCounter uint64
|
||||
|
||||
recvMu sync.Mutex
|
||||
recvCounter uint64
|
||||
}
|
||||
|
||||
func newEncryptedStream(inner syncCommandStreamTrait, gcm cipher.AEAD, localIsSmaller bool) *encryptedStream {
|
||||
var sendPrefix, recvPrefix byte
|
||||
if localIsSmaller {
|
||||
sendPrefix, recvPrefix = 0x00, 0x01
|
||||
} else {
|
||||
sendPrefix, recvPrefix = 0x01, 0x00
|
||||
}
|
||||
return &encryptedStream{
|
||||
inner: inner,
|
||||
gcm: gcm,
|
||||
sendPrefix: sendPrefix,
|
||||
recvPrefix: recvPrefix,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *encryptedStream) Send(item *v1sync.SyncStreamItem) error {
|
||||
plaintext, err := proto.Marshal(item)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshal for encryption: %w", err)
|
||||
}
|
||||
|
||||
s.sendMu.Lock()
|
||||
nonce := s.makeNonce(s.sendPrefix, s.sendCounter)
|
||||
s.sendCounter++
|
||||
s.sendMu.Unlock()
|
||||
|
||||
ciphertext := s.gcm.Seal(nil, nonce, plaintext, nil)
|
||||
|
||||
return s.inner.Send(&v1sync.SyncStreamItem{
|
||||
Action: &v1sync.SyncStreamItem_Encrypted{
|
||||
Encrypted: &v1sync.SyncStreamItem_SyncActionEncrypted{
|
||||
Nonce: nonce,
|
||||
Ciphertext: ciphertext,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func (s *encryptedStream) Receive() (*v1sync.SyncStreamItem, error) {
|
||||
envelope, err := s.inner.Receive()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encrypted := envelope.GetEncrypted()
|
||||
if encrypted == nil {
|
||||
return nil, fmt.Errorf("expected encrypted message, got %T", envelope.GetAction())
|
||||
}
|
||||
|
||||
s.recvMu.Lock()
|
||||
expectedNonce := s.makeNonce(s.recvPrefix, s.recvCounter)
|
||||
s.recvCounter++
|
||||
s.recvMu.Unlock()
|
||||
|
||||
if len(encrypted.Nonce) != s.gcm.NonceSize() {
|
||||
return nil, fmt.Errorf("invalid nonce size: got %d, want %d", len(encrypted.Nonce), s.gcm.NonceSize())
|
||||
}
|
||||
|
||||
// Verify nonce matches expected counter to prevent replay/reorder attacks
|
||||
for i := range expectedNonce {
|
||||
if expectedNonce[i] != encrypted.Nonce[i] {
|
||||
return nil, fmt.Errorf("nonce mismatch: possible replay or reorder attack")
|
||||
}
|
||||
}
|
||||
|
||||
plaintext, err := s.gcm.Open(nil, encrypted.Nonce, encrypted.Ciphertext, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decrypt message: %w", err)
|
||||
}
|
||||
|
||||
var inner v1sync.SyncStreamItem
|
||||
if err := proto.Unmarshal(plaintext, &inner); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal decrypted message: %w", err)
|
||||
}
|
||||
|
||||
return &inner, nil
|
||||
}
|
||||
|
||||
// makeNonce creates a 12-byte GCM nonce. The first byte is the direction prefix
|
||||
// (0x00 or 0x01), bytes 1-3 are zero, and bytes 4-11 are the counter in big-endian.
|
||||
func (s *encryptedStream) makeNonce(prefix byte, counter uint64) []byte {
|
||||
nonce := make([]byte, s.gcm.NonceSize()) // 12 bytes for GCM
|
||||
nonce[0] = prefix
|
||||
binary.BigEndian.PutUint64(nonce[4:], counter)
|
||||
return nonce
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
package syncapi
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/garethgeorge/backrest/gen/go/v1sync"
|
||||
"github.com/garethgeorge/backrest/internal/cryptoutil"
|
||||
)
|
||||
|
||||
// fakeStream is a pair of in-memory channels simulating a bidirectional transport.
|
||||
type fakeStream struct {
|
||||
sendCh chan *v1sync.SyncStreamItem
|
||||
recvCh chan *v1sync.SyncStreamItem
|
||||
}
|
||||
|
||||
func (f *fakeStream) Send(item *v1sync.SyncStreamItem) error {
|
||||
f.sendCh <- item
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeStream) Receive() (*v1sync.SyncStreamItem, error) {
|
||||
item := <-f.recvCh
|
||||
return item, nil
|
||||
}
|
||||
|
||||
// newFakeStreamPair creates two connected fakeStreams (A's send is B's recv and vice versa).
|
||||
func newFakeStreamPair() (*fakeStream, *fakeStream) {
|
||||
ab := make(chan *v1sync.SyncStreamItem, 16)
|
||||
ba := make(chan *v1sync.SyncStreamItem, 16)
|
||||
return &fakeStream{sendCh: ab, recvCh: ba}, &fakeStream{sendCh: ba, recvCh: ab}
|
||||
}
|
||||
|
||||
func TestEncryptedStream_RoundTrip(t *testing.T) {
|
||||
alice, err := cryptoutil.GenerateECDHKeyPair()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
bob, err := cryptoutil.GenerateECDHKeyPair()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
gcm, err := cryptoutil.DeriveSessionKey(alice.Private, bob.Public)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
aliceIsSmaller := string(alice.Public.Bytes()) < string(bob.Public.Bytes())
|
||||
|
||||
transportA, transportB := newFakeStreamPair()
|
||||
encA := newEncryptedStream(transportA, gcm, aliceIsSmaller)
|
||||
encB := newEncryptedStream(transportB, gcm, !aliceIsSmaller)
|
||||
|
||||
// Send a heartbeat from A to B
|
||||
sendItem := &v1sync.SyncStreamItem{
|
||||
Action: &v1sync.SyncStreamItem_Heartbeat{
|
||||
Heartbeat: &v1sync.SyncStreamItem_SyncActionHeartbeat{},
|
||||
},
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := encA.Send(sendItem); err != nil {
|
||||
t.Errorf("send: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
recvItem, err := encB.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("receive: %v", err)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if recvItem.GetHeartbeat() == nil {
|
||||
t.Fatalf("expected heartbeat, got %T", recvItem.GetAction())
|
||||
}
|
||||
}
|
||||
|
||||
func TestEncryptedStream_BidirectionalMultiMessage(t *testing.T) {
|
||||
alice, _ := cryptoutil.GenerateECDHKeyPair()
|
||||
bob, _ := cryptoutil.GenerateECDHKeyPair()
|
||||
gcm, _ := cryptoutil.DeriveSessionKey(alice.Private, bob.Public)
|
||||
|
||||
aliceIsSmaller := string(alice.Public.Bytes()) < string(bob.Public.Bytes())
|
||||
|
||||
transportA, transportB := newFakeStreamPair()
|
||||
encA := newEncryptedStream(transportA, gcm, aliceIsSmaller)
|
||||
encB := newEncryptedStream(transportB, gcm, !aliceIsSmaller)
|
||||
|
||||
heartbeat := &v1sync.SyncStreamItem{
|
||||
Action: &v1sync.SyncStreamItem_Heartbeat{
|
||||
Heartbeat: &v1sync.SyncStreamItem_SyncActionHeartbeat{},
|
||||
},
|
||||
}
|
||||
|
||||
// Send 5 messages A→B sequentially, then 5 messages B→A sequentially
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// A→B direction
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 5; i++ {
|
||||
if err := encA.Send(heartbeat); err != nil {
|
||||
t.Errorf("A send %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
for i := 0; i < 5; i++ {
|
||||
if _, err := encB.Receive(); err != nil {
|
||||
t.Fatalf("B receive %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
// B→A direction
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < 5; i++ {
|
||||
if err := encB.Send(heartbeat); err != nil {
|
||||
t.Errorf("B send %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
for i := 0; i < 5; i++ {
|
||||
if _, err := encA.Receive(); err != nil {
|
||||
t.Fatalf("A receive %d: %v", i, err)
|
||||
}
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestEstablishEncryption_Integration(t *testing.T) {
|
||||
transportA, transportB := newFakeStreamPair()
|
||||
|
||||
var encA, encB syncCommandStreamTrait
|
||||
var errA, errB error
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
encA, errA = establishEncryption(transportA)
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
encB, errB = establishEncryption(transportB)
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
if errA != nil {
|
||||
t.Fatalf("establish A: %v", errA)
|
||||
}
|
||||
if errB != nil {
|
||||
t.Fatalf("establish B: %v", errB)
|
||||
}
|
||||
|
||||
// Verify encrypted communication works
|
||||
heartbeat := &v1sync.SyncStreamItem{
|
||||
Action: &v1sync.SyncStreamItem_Heartbeat{
|
||||
Heartbeat: &v1sync.SyncStreamItem_SyncActionHeartbeat{},
|
||||
},
|
||||
}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := encA.Send(heartbeat); err != nil {
|
||||
t.Errorf("send: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
recv, err := encB.Receive()
|
||||
if err != nil {
|
||||
t.Fatalf("receive: %v", err)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if recv.GetHeartbeat() == nil {
|
||||
t.Fatalf("expected heartbeat, got %T", recv.GetAction())
|
||||
}
|
||||
}
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
"google.golang.org/protobuf/proto"
|
||||
)
|
||||
|
||||
const SyncProtocolVersion = 1
|
||||
const SyncProtocolVersion = 2
|
||||
|
||||
type BackrestSyncHandler struct {
|
||||
v1syncconnect.UnimplementedBackrestSyncServiceHandler
|
||||
|
||||
Reference in New Issue
Block a user