Files

110 lines
2.9 KiB
Go

package syncapi
import (
"crypto/cipher"
"encoding/binary"
"fmt"
"sync"
"github.com/garethgeorge/backrest/gen/go/v1sync"
"google.golang.org/protobuf/proto"
)
// encryptedStream wraps a syncCommandStreamTrait with AES-256-GCM encryption
// using the per-direction AEADs derived during the transport handshake.
//
// Each direction has an independent key (initiator-to-responder vs
// responder-to-initiator), so a counter-based nonce starting at zero is
// sufficient: there is no shared (key, nonce) space to collide in.
type encryptedStream struct {
inner syncCommandStreamTrait
send cipher.AEAD
recv cipher.AEAD
sendMu sync.Mutex
sendCounter uint64
recvMu sync.Mutex
recvCounter uint64
}
func newEncryptedStream(inner syncCommandStreamTrait, send, recv cipher.AEAD) *encryptedStream {
return &encryptedStream{
inner: inner,
send: send,
recv: recv,
}
}
func (s *encryptedStream) Send(item *v1sync.SyncStreamItem) error {
plaintext, err := proto.Marshal(item)
if err != nil {
return fmt.Errorf("marshal for encryption: %w", err)
}
s.sendMu.Lock()
nonce := makeNonce(s.send.NonceSize(), s.sendCounter)
s.sendCounter++
s.sendMu.Unlock()
ciphertext := s.send.Seal(nil, nonce, plaintext, nil)
return s.inner.Send(&v1sync.SyncStreamItem{
Action: &v1sync.SyncStreamItem_Encrypted{
Encrypted: &v1sync.SyncStreamItem_SyncActionEncrypted{
Nonce: nonce,
Ciphertext: ciphertext,
},
},
})
}
func (s *encryptedStream) Receive() (*v1sync.SyncStreamItem, error) {
envelope, err := s.inner.Receive()
if err != nil {
return nil, err
}
encrypted := envelope.GetEncrypted()
if encrypted == nil {
return nil, fmt.Errorf("expected encrypted message, got %T", envelope.GetAction())
}
s.recvMu.Lock()
expectedNonce := makeNonce(s.recv.NonceSize(), s.recvCounter)
s.recvCounter++
s.recvMu.Unlock()
if len(encrypted.Nonce) != s.recv.NonceSize() {
return nil, fmt.Errorf("invalid nonce size: got %d, want %d", len(encrypted.Nonce), s.recv.NonceSize())
}
// Verify nonce matches expected counter to prevent replay/reorder attacks.
for i := range expectedNonce {
if expectedNonce[i] != encrypted.Nonce[i] {
return nil, fmt.Errorf("nonce mismatch: possible replay or reorder attack")
}
}
plaintext, err := s.recv.Open(nil, encrypted.Nonce, encrypted.Ciphertext, nil)
if err != nil {
return nil, fmt.Errorf("decrypt message: %w", err)
}
var inner v1sync.SyncStreamItem
if err := proto.Unmarshal(plaintext, &inner); err != nil {
return nil, fmt.Errorf("unmarshal decrypted message: %w", err)
}
return &inner, nil
}
// makeNonce builds an N-byte AES-GCM nonce by big-endian-encoding the counter
// in the trailing 8 bytes; the leading bytes are zero. The counter never
// repeats within a single session direction, so the nonce never repeats.
func makeNonce(size int, counter uint64) []byte {
nonce := make([]byte, size)
binary.BigEndian.PutUint64(nonce[size-8:], counter)
return nonce
}