mirror of
https://github.com/garethgeorge/backrest.git
synced 2025-12-14 17:45:36 +00:00
384 lines
14 KiB
Go
384 lines
14 KiB
Go
package syncapi
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"slices"
|
|
"sort"
|
|
|
|
"connectrpc.com/connect"
|
|
v1 "github.com/garethgeorge/backrest/gen/go/v1"
|
|
"github.com/garethgeorge/backrest/gen/go/v1/v1connect"
|
|
"github.com/garethgeorge/backrest/internal/config"
|
|
"github.com/garethgeorge/backrest/internal/oplog"
|
|
"github.com/garethgeorge/backrest/internal/protoutil"
|
|
lru "github.com/hashicorp/golang-lru/v2"
|
|
"go.uber.org/zap"
|
|
)
|
|
|
|
// SyncProtocolVersion is the protocol version the server advertises in the
// handshake packet it sends at the start of every sync stream.
const SyncProtocolVersion = 1
|
|
|
|
type BackrestSyncHandler struct {
|
|
v1connect.UnimplementedBackrestSyncServiceHandler
|
|
mgr *SyncManager
|
|
}
|
|
|
|
// Compile-time check that *BackrestSyncHandler satisfies the generated service interface.
var _ v1connect.BackrestSyncServiceHandler = (*BackrestSyncHandler)(nil)
|
|
|
|
func NewBackrestSyncHandler(mgr *SyncManager) *BackrestSyncHandler {
|
|
return &BackrestSyncHandler{
|
|
mgr: mgr,
|
|
}
|
|
}
|
|
|
|
func (h *BackrestSyncHandler) Sync(ctx context.Context, stream *connect.BidiStream[v1.SyncStreamItem, v1.SyncStreamItem]) error {
|
|
// TODO: this request can be very long lived, we must periodically refresh the config
|
|
// e.g. to disconnect a client if its access is revoked.
|
|
initialConfig, err := h.mgr.configMgr.Get()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
receive := make(chan *v1.SyncStreamItem, 1)
|
|
send := make(chan *v1.SyncStreamItem, 1)
|
|
go func() {
|
|
for {
|
|
item, err := stream.Receive()
|
|
if err != nil {
|
|
break
|
|
}
|
|
receive <- item
|
|
}
|
|
close(receive)
|
|
}()
|
|
|
|
// Broadcast initial packet containing the protocol version and instance ID.
|
|
zap.S().Debugf("syncserver a client connected, broadcast handshake as %v", initialConfig.Instance)
|
|
if err := stream.Send(&v1.SyncStreamItem{
|
|
Action: &v1.SyncStreamItem_Handshake{
|
|
Handshake: &v1.SyncStreamItem_SyncActionHandshake{
|
|
ProtocolVersion: SyncProtocolVersion,
|
|
InstanceId: &v1.SignedMessage{
|
|
Payload: []byte(initialConfig.Instance),
|
|
Signature: []byte("TODO: inject a valid signature"),
|
|
Keyid: "TODO: inject a valid key ID",
|
|
},
|
|
},
|
|
},
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
|
|
// Try to read the handshake packet from the client.
|
|
// TODO: perform this handshake in a header as a pre-flight before opening the stream.
|
|
clientInstanceID := ""
|
|
if msg, ok := <-receive; ok {
|
|
handshake := msg.GetHandshake()
|
|
if handshake == nil {
|
|
return connect.NewError(connect.CodeInvalidArgument, errors.New("handshake packet must be sent first"))
|
|
}
|
|
|
|
clientInstanceID = string(handshake.GetInstanceId().GetPayload())
|
|
if clientInstanceID == "" {
|
|
return connect.NewError(connect.CodeInvalidArgument, errors.New("instance ID is required"))
|
|
}
|
|
} else {
|
|
return connect.NewError(connect.CodeInvalidArgument, errors.New("no packets received"))
|
|
}
|
|
|
|
var authorizedClientPeer *v1.Multihost_Peer
|
|
authorizedClientPeerIdx := slices.IndexFunc(initialConfig.Multihost.GetAuthorizedClients(), func(peer *v1.Multihost_Peer) bool {
|
|
return peer.InstanceId == clientInstanceID
|
|
})
|
|
if authorizedClientPeerIdx == -1 {
|
|
// TODO: check the key signature of the handshake message here.
|
|
zap.S().Warnf("syncserver rejected a connection from client instance ID %q because it is not authorized", clientInstanceID)
|
|
return connect.NewError(connect.CodePermissionDenied, errors.New("client is not an authorized peer"))
|
|
} else {
|
|
authorizedClientPeer = initialConfig.Multihost.AuthorizedClients[authorizedClientPeerIdx]
|
|
}
|
|
zap.S().Infof("syncserver accepted a connection from client instance ID %q", authorizedClientPeer.InstanceId)
|
|
|
|
opIDLru, _ := lru.New[int64, int64](128) // original ID -> local ID
|
|
flowIDLru, _ := lru.New[int64, int64](128) // original flow ID -> local flow ID
|
|
|
|
insertOrUpdate := func(op *v1.Operation) error {
|
|
op.OriginalId = op.Id
|
|
op.OriginalFlowId = op.FlowId
|
|
var ok bool
|
|
if op.Id, ok = opIDLru.Get(op.OriginalId); !ok {
|
|
var foundOp *v1.Operation
|
|
if err := h.mgr.oplog.Query(oplog.Query{}.
|
|
SetOriginalID(op.OriginalId).
|
|
SetInstanceID(op.InstanceId), func(o *v1.Operation) error {
|
|
foundOp = o
|
|
return nil
|
|
}); err != nil {
|
|
return fmt.Errorf("mapping remote ID to local ID: %w", err)
|
|
}
|
|
if foundOp != nil {
|
|
op.Id = foundOp.Id
|
|
opIDLru.Add(foundOp.Id, foundOp.Id)
|
|
}
|
|
}
|
|
if op.FlowId, ok = flowIDLru.Get(op.OriginalFlowId); !ok {
|
|
var flowOp *v1.Operation
|
|
if err := h.mgr.oplog.Query(oplog.Query{}.
|
|
SetOriginalFlowID(op.OriginalFlowId).
|
|
SetInstanceID(op.InstanceId), func(o *v1.Operation) error {
|
|
flowOp = o
|
|
return nil
|
|
}); err != nil {
|
|
return fmt.Errorf("mapping remote flow ID to local ID: %w", err)
|
|
}
|
|
if flowOp != nil {
|
|
op.FlowId = flowOp.FlowId
|
|
flowIDLru.Add(op.OriginalFlowId, flowOp.FlowId)
|
|
}
|
|
}
|
|
|
|
return h.mgr.oplog.Set(op)
|
|
}
|
|
|
|
deleteByOriginalID := func(originalID int64) error {
|
|
var foundOp *v1.Operation
|
|
if err := h.mgr.oplog.Query(oplog.Query{}.SetOriginalID(originalID), func(o *v1.Operation) error {
|
|
foundOp = o
|
|
return nil
|
|
}); err != nil {
|
|
return fmt.Errorf("mapping remote ID to local ID: %w", err)
|
|
}
|
|
|
|
if foundOp == nil {
|
|
zap.S().Debugf("syncserver received delete for non-existent operation %v", originalID)
|
|
return nil
|
|
}
|
|
|
|
return h.mgr.oplog.Delete(foundOp.Id)
|
|
}
|
|
|
|
sendConfigToClient := func(config *v1.Config) error {
|
|
remoteConfig := &v1.RemoteConfig{}
|
|
var allowedRepoIDs []string
|
|
for _, repo := range config.Repos {
|
|
if slices.Contains(repo.AllowedPeerInstanceIds, clientInstanceID) {
|
|
allowedRepoIDs = append(allowedRepoIDs, repo.Id)
|
|
remoteConfig.Repos = append(remoteConfig.Repos, protoutil.RepoToRemoteRepo(repo))
|
|
}
|
|
}
|
|
|
|
zap.S().Debugf("syncserver determined client %v is allowlisted for repos %v", clientInstanceID, allowedRepoIDs)
|
|
|
|
// Send the config, this is the first meaningful packet the client will receive.
|
|
// Once configuration is received, the client will start sending diffs.
|
|
if err := stream.Send(&v1.SyncStreamItem{
|
|
Action: &v1.SyncStreamItem_SendConfig{
|
|
SendConfig: &v1.SyncStreamItem_SyncActionSendConfig{
|
|
Config: remoteConfig,
|
|
},
|
|
},
|
|
}); err != nil {
|
|
return fmt.Errorf("sending config to client: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
handleSyncCommand := func(item *v1.SyncStreamItem) error {
|
|
switch action := item.Action.(type) {
|
|
case *v1.SyncStreamItem_SendConfig:
|
|
return errors.New("clients can not push configs to server")
|
|
case *v1.SyncStreamItem_DiffOperations:
|
|
diffSel := action.DiffOperations.GetHaveOperationsSelector()
|
|
|
|
if diffSel == nil {
|
|
return connect.NewError(connect.CodeInvalidArgument, errors.New("action DiffOperations: selector is required"))
|
|
}
|
|
|
|
// The diff selector _must_ be scoped to the instance ID of the client.
|
|
if diffSel.GetInstanceId() != clientInstanceID {
|
|
return connect.NewError(connect.CodePermissionDenied, errors.New("action DiffOperations: instance ID mismatch in diff selector"))
|
|
}
|
|
|
|
// The diff selector _must_ specify a repo the client has access to
|
|
repo := config.FindRepoByGUID(initialConfig, diffSel.GetRepoGuid())
|
|
if repo == nil {
|
|
zap.S().Warnf("syncserver action DiffOperations: client %q tried to diff with repo %q that does not exist", clientInstanceID, diffSel.GetRepoGuid())
|
|
return connect.NewError(connect.CodePermissionDenied, fmt.Errorf("action DiffOperations: repo %q not found", diffSel.GetRepoGuid()))
|
|
}
|
|
if !slices.Contains(repo.GetAllowedPeerInstanceIds(), clientInstanceID) {
|
|
zap.S().Warnf("syncserver action DiffOperations: client %q tried to diff with repo %q that they are not allowed to access", clientInstanceID, repo.Id)
|
|
return connect.NewError(connect.CodePermissionDenied, fmt.Errorf("action DiffOperations: client is not allowed to access repo %q", repo.Id))
|
|
}
|
|
|
|
// These are required to be the same length for a pairwise zip.
|
|
if len(action.DiffOperations.HaveOperationIds) != len(action.DiffOperations.HaveOperationModnos) {
|
|
return connect.NewError(connect.CodeInvalidArgument, errors.New("action DiffOperations: operation IDs and modnos must be the same length"))
|
|
}
|
|
|
|
diffSelQuery, err := protoutil.OpSelectorToQuery(diffSel)
|
|
if err != nil {
|
|
return fmt.Errorf("action DiffOperations: converting diff selector to query: %w", err)
|
|
}
|
|
|
|
localMetadata := []oplog.OpMetadata{}
|
|
if err := h.mgr.oplog.QueryMetadata(diffSelQuery, func(metadata oplog.OpMetadata) error {
|
|
if metadata.OriginalID == 0 {
|
|
return nil // skip operations that didn't come from a remote
|
|
}
|
|
localMetadata = append(localMetadata, metadata)
|
|
return nil
|
|
}); err != nil {
|
|
return fmt.Errorf("action DiffOperations: querying local metadata: %w", err)
|
|
}
|
|
sort.Slice(localMetadata, func(i, j int) bool {
|
|
return localMetadata[i].OriginalID < localMetadata[j].OriginalID
|
|
})
|
|
|
|
remoteMetadata := make([]oplog.OpMetadata, len(action.DiffOperations.HaveOperationIds))
|
|
for i, id := range action.DiffOperations.HaveOperationIds {
|
|
remoteMetadata[i] = oplog.OpMetadata{
|
|
ID: id,
|
|
Modno: action.DiffOperations.HaveOperationModnos[i],
|
|
}
|
|
}
|
|
sort.Slice(remoteMetadata, func(i, j int) bool {
|
|
return remoteMetadata[i].ID < remoteMetadata[j].ID
|
|
})
|
|
|
|
requestDueToModno := 0
|
|
requestMissingRemote := 0
|
|
requestMissingLocal := 0
|
|
requestIDs := []int64{}
|
|
|
|
// This is a simple O(n) diff algorithm that compares the local and remote metadata vectors.
|
|
localIndex := 0
|
|
remoteIndex := 0
|
|
for localIndex < len(localMetadata) && remoteIndex < len(remoteMetadata) {
|
|
local := localMetadata[localIndex]
|
|
remote := remoteMetadata[remoteIndex]
|
|
|
|
if local.OriginalID == remote.ID {
|
|
if local.Modno != remote.Modno {
|
|
requestIDs = append(requestIDs, local.OriginalID)
|
|
requestDueToModno++
|
|
}
|
|
localIndex++
|
|
remoteIndex++
|
|
} else if local.OriginalID < remote.ID {
|
|
// the ID is found locally not remotely, request it and see if we get a delete event back
|
|
// from the client indicating that the operation was deleted.
|
|
requestIDs = append(requestIDs, local.OriginalID)
|
|
localIndex++
|
|
requestMissingLocal++
|
|
} else {
|
|
// the ID is found remotely not locally, request it for initial sync.
|
|
requestIDs = append(requestIDs, remote.ID)
|
|
remoteIndex++
|
|
requestMissingRemote++
|
|
}
|
|
}
|
|
for localIndex < len(localMetadata) {
|
|
requestIDs = append(requestIDs, localMetadata[localIndex].OriginalID)
|
|
localIndex++
|
|
requestMissingLocal++
|
|
}
|
|
for remoteIndex < len(remoteMetadata) {
|
|
requestIDs = append(requestIDs, remoteMetadata[remoteIndex].ID)
|
|
remoteIndex++
|
|
requestMissingRemote++
|
|
}
|
|
|
|
zap.L().Debug("syncserver diff operations with client metadata",
|
|
zap.String("client_instance_id", clientInstanceID),
|
|
zap.Any("query", diffSelQuery),
|
|
zap.Int("request_due_to_modno", requestDueToModno),
|
|
zap.Int("request_local_but_not_remote", requestMissingLocal),
|
|
zap.Int("request_remote_but_not_local", requestMissingRemote),
|
|
zap.Int("request_ids_total", len(requestIDs)),
|
|
)
|
|
if len(requestIDs) > 0 {
|
|
zap.L().Debug("syncserver sending request operations to client", zap.String("client_instance_id", clientInstanceID), zap.Any("request_ids", requestIDs))
|
|
if err := stream.Send(&v1.SyncStreamItem{
|
|
Action: &v1.SyncStreamItem_DiffOperations{
|
|
DiffOperations: &v1.SyncStreamItem_SyncActionDiffOperations{
|
|
RequestOperations: requestIDs,
|
|
},
|
|
},
|
|
}); err != nil {
|
|
return fmt.Errorf("sending request operations: %w", err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
case *v1.SyncStreamItem_SendOperations:
|
|
switch event := action.SendOperations.GetEvent().Event.(type) {
|
|
case *v1.OperationEvent_CreatedOperations:
|
|
zap.L().Debug("syncserver received created operations", zap.Any("operations", event.CreatedOperations.GetOperations()))
|
|
for _, op := range event.CreatedOperations.GetOperations() {
|
|
if err := insertOrUpdate(op); err != nil {
|
|
return fmt.Errorf("action SendOperations: operation event create: %w", err)
|
|
}
|
|
}
|
|
case *v1.OperationEvent_UpdatedOperations:
|
|
zap.L().Debug("syncserver received update operations", zap.Any("operations", event.UpdatedOperations.GetOperations()))
|
|
for _, op := range event.UpdatedOperations.GetOperations() {
|
|
if err := insertOrUpdate(op); err != nil {
|
|
return fmt.Errorf("action SendOperations: operation event update: %w", err)
|
|
}
|
|
}
|
|
case *v1.OperationEvent_DeletedOperations:
|
|
zap.L().Debug("syncserver received delete operations", zap.Any("operations", event.DeletedOperations.GetValues()))
|
|
for _, id := range event.DeletedOperations.GetValues() {
|
|
if err := deleteByOriginalID(id); err != nil {
|
|
return fmt.Errorf("action SendOperations: operation event delete %d: %w", id, err)
|
|
}
|
|
}
|
|
case *v1.OperationEvent_KeepAlive:
|
|
default:
|
|
return connect.NewError(connect.CodeInvalidArgument, errors.New("action SendOperations: unknown event type"))
|
|
}
|
|
default:
|
|
return connect.NewError(connect.CodeInvalidArgument, errors.New("unknown action type"))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// subscribe to our own configuration for changes
|
|
configWatchCh := h.mgr.configMgr.Watch()
|
|
defer h.mgr.configMgr.StopWatching(configWatchCh)
|
|
sendConfigToClient(initialConfig)
|
|
|
|
for {
|
|
select {
|
|
case item, ok := <-receive:
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
if err := handleSyncCommand(item); err != nil {
|
|
return err
|
|
}
|
|
case sendItem, ok := <-send: // note: send channel should only be used when sending from a different goroutine than the main loop
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
if err := stream.Send(sendItem); err != nil {
|
|
return err
|
|
}
|
|
case <-configWatchCh:
|
|
newConfig, err := h.mgr.configMgr.Get()
|
|
if err != nil {
|
|
zap.S().Warnf("syncserver failed to get the newest config: %v", err)
|
|
continue
|
|
}
|
|
sendConfigToClient(newConfig)
|
|
case <-ctx.Done():
|
|
zap.S().Infof("syncserver client %q disconnected", authorizedClientPeer.InstanceId)
|
|
return ctx.Err()
|
|
}
|
|
}
|
|
}
|