mirror of
https://github.com/garethgeorge/backrest.git
synced 2026-05-04 03:50:30 +00:00
170 lines
4.9 KiB
Go
170 lines
4.9 KiB
Go
package syncapi
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"time"
|
|
|
|
"connectrpc.com/connect"
|
|
"github.com/garethgeorge/backrest/gen/go/v1sync"
|
|
"github.com/garethgeorge/backrest/internal/cryptoutil"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
type syncCommandStreamTrait interface {
|
|
Send(item *v1sync.SyncStreamItem) error
|
|
Receive() (*v1sync.SyncStreamItem, error)
|
|
}
|
|
|
|
var _ syncCommandStreamTrait = (*connect.BidiStream[v1sync.SyncStreamItem, v1sync.SyncStreamItem])(nil)
|
|
var _ syncCommandStreamTrait = (*connect.BidiStreamForClient[v1sync.SyncStreamItem, v1sync.SyncStreamItem])(nil)
|
|
|
|
type bidiSyncCommandStream struct {
|
|
sendChan chan *v1sync.SyncStreamItem
|
|
recvChan chan *v1sync.SyncStreamItem
|
|
terminateWithErrChan chan error
|
|
}
|
|
|
|
func newBidiSyncCommandStream() *bidiSyncCommandStream {
|
|
return &bidiSyncCommandStream{
|
|
sendChan: make(chan *v1sync.SyncStreamItem, 256),
|
|
recvChan: make(chan *v1sync.SyncStreamItem, 1),
|
|
terminateWithErrChan: make(chan error, 1),
|
|
}
|
|
}
|
|
|
|
func (s *bidiSyncCommandStream) Send(item *v1sync.SyncStreamItem) {
|
|
select {
|
|
case s.sendChan <- item:
|
|
default:
|
|
select {
|
|
case s.sendChan <- item:
|
|
case <-time.After(100 * time.Millisecond):
|
|
s.SendErrorAndTerminate(NewSyncErrorDisconnected(errors.New("send channel is full, cannot send item")))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *bidiSyncCommandStream) SendErrorAndTerminate(err error) {
|
|
select {
|
|
case s.terminateWithErrChan <- err:
|
|
default:
|
|
}
|
|
}
|
|
|
|
func (s *bidiSyncCommandStream) ReadChannel() chan *v1sync.SyncStreamItem {
|
|
return s.recvChan
|
|
}
|
|
|
|
func (s *bidiSyncCommandStream) ReceiveWithinDuration(d time.Duration) *v1sync.SyncStreamItem {
|
|
select {
|
|
case item := <-s.recvChan:
|
|
return item
|
|
case <-time.After(d):
|
|
return nil
|
|
}
|
|
}
|
|
|
|
// ConnectStream bridges the channel-based bidiSyncCommandStream to a real transport.
|
|
// It first performs an ECDH key exchange on the raw transport to establish an encrypted
|
|
// session, then starts the send/recv pump loop over the encrypted channel.
|
|
func (s *bidiSyncCommandStream) ConnectStream(ctx context.Context, stream syncCommandStreamTrait) error {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
// Perform ECDH key exchange on the raw transport before starting the pump.
|
|
transport, err := establishEncryption(stream)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
go func() {
|
|
defer close(s.recvChan)
|
|
for {
|
|
val, err := transport.Receive()
|
|
if err != nil {
|
|
s.SendErrorAndTerminate(NewSyncErrorDisconnected(fmt.Errorf("receiving item: %w", err)))
|
|
return
|
|
}
|
|
select {
|
|
case s.recvChan <- val:
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
for {
|
|
select {
|
|
case item := <-s.sendChan:
|
|
if item == nil {
|
|
continue
|
|
}
|
|
if err := transport.Send(item); err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
err = fmt.Errorf("connection failed or dropped: %w", err)
|
|
}
|
|
s.SendErrorAndTerminate(err)
|
|
return err
|
|
}
|
|
case err := <-s.terminateWithErrChan:
|
|
return err
|
|
case <-ctx.Done():
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
}
|
|
|
|
// establishEncryption performs an ECDH key exchange on the raw transport and
|
|
// returns an encrypted stream wrapper. Each side generates an ephemeral ECDH P-256
|
|
// key pair, exchanges public keys, and derives a shared AES-256-GCM session key.
|
|
// The handshake (identity authentication) runs over the encrypted channel afterward.
|
|
func establishEncryption(stream syncCommandStreamTrait) (syncCommandStreamTrait, error) {
|
|
keyPair, err := cryptoutil.GenerateECDHKeyPair()
|
|
if err != nil {
|
|
return nil, NewSyncErrorInternal(fmt.Errorf("generating ephemeral ECDH key: %w", err))
|
|
}
|
|
|
|
// Send our ephemeral ECDH public key
|
|
if err := stream.Send(&v1sync.SyncStreamItem{
|
|
Action: &v1sync.SyncStreamItem_EstablishSharedSecret{
|
|
EstablishSharedSecret: &v1sync.SyncStreamItem_SyncEstablishSharedSecret{
|
|
EcdhPublicKey: keyPair.Public.Bytes(),
|
|
},
|
|
},
|
|
}); err != nil {
|
|
return nil, NewSyncErrorDisconnected(fmt.Errorf("sending ECDH public key: %w", err))
|
|
}
|
|
|
|
// Receive the peer's ephemeral ECDH public key
|
|
peerMsg, err := stream.Receive()
|
|
if err != nil {
|
|
return nil, NewSyncErrorDisconnected(fmt.Errorf("receiving ECDH public key: %w", err))
|
|
}
|
|
peerSecret := peerMsg.GetEstablishSharedSecret()
|
|
if peerSecret == nil {
|
|
return nil, NewSyncErrorProtocol(fmt.Errorf("expected ECDH key exchange, got %T", peerMsg.GetAction()))
|
|
}
|
|
|
|
peerECDHPub, err := cryptoutil.ParseECDHPublicKey(peerSecret.GetEcdhPublicKey())
|
|
if err != nil {
|
|
return nil, NewSyncErrorProtocol(fmt.Errorf("parsing peer ECDH public key: %w", err))
|
|
}
|
|
|
|
// Derive AES-256-GCM session key
|
|
gcm, err := cryptoutil.DeriveSessionKey(keyPair.Private, peerECDHPub)
|
|
if err != nil {
|
|
return nil, NewSyncErrorProtocol(fmt.Errorf("deriving session key: %w", err))
|
|
}
|
|
|
|
// Determine nonce direction: side with smaller public key uses prefix 0x00
|
|
localIsSmaller := bytes.Compare(keyPair.Public.Bytes(), peerECDHPub.Bytes()) < 0
|
|
|
|
zap.L().Info("encrypted sync session established")
|
|
|
|
return newEncryptedStream(stream, gcm, localIsSmaller), nil
|
|
}
|