// Files
// backrest/internal/api/syncapi/synchandler.go
//
// 384 lines
// 14 KiB
// Go

package syncapi
import (
"context"
"errors"
"fmt"
"slices"
"sort"
"connectrpc.com/connect"
v1 "github.com/garethgeorge/backrest/gen/go/v1"
"github.com/garethgeorge/backrest/gen/go/v1/v1connect"
"github.com/garethgeorge/backrest/internal/config"
"github.com/garethgeorge/backrest/internal/oplog"
"github.com/garethgeorge/backrest/internal/protoutil"
lru "github.com/hashicorp/golang-lru/v2"
"go.uber.org/zap"
)
// SyncProtocolVersion is the protocol version the sync server advertises in the
// handshake packet it sends to a client when a sync stream is opened.
const SyncProtocolVersion = 1
// BackrestSyncHandler is the server-side handler for the Backrest sync service.
// Its Sync method drives the bidirectional stream with a remote client instance.
type BackrestSyncHandler struct {
// Generated handler embedded by value; presumably supplies fallbacks for any
// service RPCs this type does not define itself (verify against v1connect).
v1connect.UnimplementedBackrestSyncServiceHandler
// mgr gives Sync access to the config manager (configMgr) and the operation
// log (oplog).
mgr *SyncManager
}
// Compile-time assertion that *BackrestSyncHandler satisfies the generated
// BackrestSyncServiceHandler interface.
var _ v1connect.BackrestSyncServiceHandler = (*BackrestSyncHandler)(nil)
// NewBackrestSyncHandler builds a sync handler that serves requests using the
// given SyncManager. The manager is stored as-is; no validation is performed.
func NewBackrestSyncHandler(mgr *SyncManager) *BackrestSyncHandler {
	handler := new(BackrestSyncHandler)
	handler.mgr = mgr
	return handler
}
// Sync serves one bidirectional sync stream between this instance (acting as
// the server) and a single remote client instance.
//
// The exchange implemented here is:
//  1. The server sends a handshake carrying SyncProtocolVersion and its own
//     instance ID.
//  2. The first packet from the client must be a handshake with a non-empty
//     instance ID that matches one of the configured authorized clients.
//  3. The server sends a RemoteConfig containing only the repos that allowlist
//     the client, and sends it again whenever the local config changes.
//  4. DiffOperations and SendOperations packets from the client are applied to
//     the local oplog until the stream ends.
//
// Sync returns nil once the receive side of the stream ends, ctx.Err() when ctx
// is done, and otherwise the first error encountered while handling a packet or
// writing to the stream.
func (h *BackrestSyncHandler) Sync(ctx context.Context, stream *connect.BidiStream[v1.SyncStreamItem, v1.SyncStreamItem]) error {
	// TODO: this request can be very long lived, we must periodically refresh the config
	// e.g. to disconnect a client if its access is revoked.
	initialConfig, err := h.mgr.configMgr.Get()
	if err != nil {
		return err
	}

	// currentConfig is the most recent config sent to the client. Repo access
	// checks use it (rather than initialConfig) so that they agree with what the
	// client was last told. It is only touched from this goroutine.
	currentConfig := initialConfig

	// Cancelled when Sync returns so the reader goroutine below can never stay
	// blocked on a channel send that nobody will receive.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	receive := make(chan *v1.SyncStreamItem, 1)
	send := make(chan *v1.SyncStreamItem, 1)
	go func() {
		// Closing receive tells the main loop that no more packets will arrive.
		defer close(receive)
		for {
			item, err := stream.Receive()
			if err != nil {
				return
			}
			select {
			case receive <- item:
			case <-ctx.Done():
				return
			}
		}
	}()

	// Broadcast initial packet containing the protocol version and instance ID.
	zap.S().Debugf("syncserver a client connected, broadcast handshake as %v", initialConfig.Instance)
	if err := stream.Send(&v1.SyncStreamItem{
		Action: &v1.SyncStreamItem_Handshake{
			Handshake: &v1.SyncStreamItem_SyncActionHandshake{
				ProtocolVersion: SyncProtocolVersion,
				InstanceId: &v1.SignedMessage{
					Payload:   []byte(initialConfig.Instance),
					Signature: []byte("TODO: inject a valid signature"),
					Keyid:     "TODO: inject a valid key ID",
				},
			},
		},
	}); err != nil {
		return err
	}

	// Try to read the handshake packet from the client.
	// TODO: perform this handshake in a header as a pre-flight before opening the stream.
	msg, ok := <-receive
	if !ok {
		return connect.NewError(connect.CodeInvalidArgument, errors.New("no packets received"))
	}
	handshake := msg.GetHandshake()
	if handshake == nil {
		return connect.NewError(connect.CodeInvalidArgument, errors.New("handshake packet must be sent first"))
	}
	clientInstanceID := string(handshake.GetInstanceId().GetPayload())
	if clientInstanceID == "" {
		return connect.NewError(connect.CodeInvalidArgument, errors.New("instance ID is required"))
	}

	// Only clients listed as authorized peers in the config may proceed.
	authorizedClients := initialConfig.Multihost.GetAuthorizedClients()
	authorizedClientPeerIdx := slices.IndexFunc(authorizedClients, func(peer *v1.Multihost_Peer) bool {
		return peer.InstanceId == clientInstanceID
	})
	if authorizedClientPeerIdx == -1 {
		// TODO: check the key signature of the handshake message here.
		zap.S().Warnf("syncserver rejected a connection from client instance ID %q because it is not authorized", clientInstanceID)
		return connect.NewError(connect.CodePermissionDenied, errors.New("client is not an authorized peer"))
	}
	authorizedClientPeer := authorizedClients[authorizedClientPeerIdx]
	zap.S().Infof("syncserver accepted a connection from client instance ID %q", authorizedClientPeer.InstanceId)

	// Errors from lru.New are ignored because the sizes are positive constants.
	opIDLru, _ := lru.New[int64, int64](128)   // original ID -> local ID
	flowIDLru, _ := lru.New[int64, int64](128) // original flow ID -> local flow ID

	// insertOrUpdate stores an operation received from the client. The client's
	// ID and flow ID are preserved in OriginalId / OriginalFlowId and replaced
	// with the local IDs of a previously stored copy when one exists; when none
	// exists the IDs are left at zero before the operation is handed to the oplog.
	insertOrUpdate := func(op *v1.Operation) error {
		op.OriginalId = op.Id
		op.OriginalFlowId = op.FlowId

		var ok bool
		if op.Id, ok = opIDLru.Get(op.OriginalId); !ok {
			var foundOp *v1.Operation
			if err := h.mgr.oplog.Query(oplog.Query{}.
				SetOriginalID(op.OriginalId).
				SetInstanceID(op.InstanceId), func(o *v1.Operation) error {
				foundOp = o
				return nil
			}); err != nil {
				return fmt.Errorf("mapping remote ID to local ID: %w", err)
			}
			if foundOp != nil {
				op.Id = foundOp.Id
				// Key by the remote ID: that is what the lookup above uses.
				opIDLru.Add(op.OriginalId, foundOp.Id)
			}
		}

		if op.FlowId, ok = flowIDLru.Get(op.OriginalFlowId); !ok {
			var flowOp *v1.Operation
			if err := h.mgr.oplog.Query(oplog.Query{}.
				SetOriginalFlowID(op.OriginalFlowId).
				SetInstanceID(op.InstanceId), func(o *v1.Operation) error {
				flowOp = o
				return nil
			}); err != nil {
				return fmt.Errorf("mapping remote flow ID to local ID: %w", err)
			}
			if flowOp != nil {
				op.FlowId = flowOp.FlowId
				flowIDLru.Add(op.OriginalFlowId, flowOp.FlowId)
			}
		}

		return h.mgr.oplog.Set(op)
	}

	// deleteByOriginalID removes the local copy of a client operation. The
	// lookup is scoped to the connected client's instance ID so that a delete
	// can never match an operation that another instance synced under the same
	// original ID. Deleting an unknown operation is not an error.
	deleteByOriginalID := func(originalID int64) error {
		var foundOp *v1.Operation
		if err := h.mgr.oplog.Query(oplog.Query{}.
			SetOriginalID(originalID).
			SetInstanceID(clientInstanceID), func(o *v1.Operation) error {
			foundOp = o
			return nil
		}); err != nil {
			return fmt.Errorf("mapping remote ID to local ID: %w", err)
		}
		if foundOp == nil {
			zap.S().Debugf("syncserver received delete for non-existent operation %v", originalID)
			return nil
		}
		return h.mgr.oplog.Delete(foundOp.Id)
	}

	// sendConfigToClient sends the subset of cfg's repos that allowlist the
	// connected client.
	sendConfigToClient := func(cfg *v1.Config) error {
		remoteConfig := &v1.RemoteConfig{}
		var allowedRepoIDs []string
		for _, repo := range cfg.Repos {
			if slices.Contains(repo.AllowedPeerInstanceIds, clientInstanceID) {
				allowedRepoIDs = append(allowedRepoIDs, repo.Id)
				remoteConfig.Repos = append(remoteConfig.Repos, protoutil.RepoToRemoteRepo(repo))
			}
		}

		zap.S().Debugf("syncserver determined client %v is allowlisted for repos %v", clientInstanceID, allowedRepoIDs)

		// Send the config, this is the first meaningful packet the client will receive.
		// Once configuration is received, the client will start sending diffs.
		if err := stream.Send(&v1.SyncStreamItem{
			Action: &v1.SyncStreamItem_SendConfig{
				SendConfig: &v1.SyncStreamItem_SyncActionSendConfig{
					Config: remoteConfig,
				},
			},
		}); err != nil {
			return fmt.Errorf("sending config to client: %w", err)
		}
		return nil
	}

	// handleSyncCommand processes one packet received from the client after the
	// handshake. Any returned error terminates the stream.
	handleSyncCommand := func(item *v1.SyncStreamItem) error {
		switch action := item.Action.(type) {
		case *v1.SyncStreamItem_SendConfig:
			return errors.New("clients can not push configs to server")

		case *v1.SyncStreamItem_DiffOperations:
			diffSel := action.DiffOperations.GetHaveOperationsSelector()
			if diffSel == nil {
				return connect.NewError(connect.CodeInvalidArgument, errors.New("action DiffOperations: selector is required"))
			}

			// The diff selector _must_ be scoped to the instance ID of the client.
			if diffSel.GetInstanceId() != clientInstanceID {
				return connect.NewError(connect.CodePermissionDenied, errors.New("action DiffOperations: instance ID mismatch in diff selector"))
			}

			// The diff selector _must_ specify a repo the client has access to
			repo := config.FindRepoByGUID(currentConfig, diffSel.GetRepoGuid())
			if repo == nil {
				zap.S().Warnf("syncserver action DiffOperations: client %q tried to diff with repo %q that does not exist", clientInstanceID, diffSel.GetRepoGuid())
				return connect.NewError(connect.CodePermissionDenied, fmt.Errorf("action DiffOperations: repo %q not found", diffSel.GetRepoGuid()))
			}
			if !slices.Contains(repo.GetAllowedPeerInstanceIds(), clientInstanceID) {
				zap.S().Warnf("syncserver action DiffOperations: client %q tried to diff with repo %q that they are not allowed to access", clientInstanceID, repo.Id)
				return connect.NewError(connect.CodePermissionDenied, fmt.Errorf("action DiffOperations: client is not allowed to access repo %q", repo.Id))
			}

			// These are required to be the same length for a pairwise zip.
			if len(action.DiffOperations.HaveOperationIds) != len(action.DiffOperations.HaveOperationModnos) {
				return connect.NewError(connect.CodeInvalidArgument, errors.New("action DiffOperations: operation IDs and modnos must be the same length"))
			}

			diffSelQuery, err := protoutil.OpSelectorToQuery(diffSel)
			if err != nil {
				return fmt.Errorf("action DiffOperations: converting diff selector to query: %w", err)
			}

			// Local view: metadata of stored operations matching the selector,
			// ordered by the ID they have on the client.
			var localMetadata []oplog.OpMetadata
			if err := h.mgr.oplog.QueryMetadata(diffSelQuery, func(metadata oplog.OpMetadata) error {
				if metadata.OriginalID == 0 {
					return nil // skip operations that didn't come from a remote
				}
				localMetadata = append(localMetadata, metadata)
				return nil
			}); err != nil {
				return fmt.Errorf("action DiffOperations: querying local metadata: %w", err)
			}
			sort.Slice(localMetadata, func(i, j int) bool {
				return localMetadata[i].OriginalID < localMetadata[j].OriginalID
			})

			// Remote view: the (ID, modno) pairs the client says it has.
			remoteMetadata := make([]oplog.OpMetadata, len(action.DiffOperations.HaveOperationIds))
			for i, id := range action.DiffOperations.HaveOperationIds {
				remoteMetadata[i] = oplog.OpMetadata{
					ID:    id,
					Modno: action.DiffOperations.HaveOperationModnos[i],
				}
			}
			sort.Slice(remoteMetadata, func(i, j int) bool {
				return remoteMetadata[i].ID < remoteMetadata[j].ID
			})

			requestDueToModno := 0
			requestMissingRemote := 0
			requestMissingLocal := 0
			var requestIDs []int64

			// Single merge-style pass over both sorted vectors, comparing the
			// local OriginalID with the remote ID.
			localIndex := 0
			remoteIndex := 0
			for localIndex < len(localMetadata) && remoteIndex < len(remoteMetadata) {
				local := localMetadata[localIndex]
				remote := remoteMetadata[remoteIndex]

				switch {
				case local.OriginalID == remote.ID:
					// Present on both sides: only re-request when the modno differs.
					if local.Modno != remote.Modno {
						requestIDs = append(requestIDs, local.OriginalID)
						requestDueToModno++
					}
					localIndex++
					remoteIndex++
				case local.OriginalID < remote.ID:
					// the ID is found locally not remotely, request it and see if we get a delete event back
					// from the client indicating that the operation was deleted.
					requestIDs = append(requestIDs, local.OriginalID)
					localIndex++
					requestMissingLocal++
				default:
					// the ID is found remotely not locally, request it for initial sync.
					requestIDs = append(requestIDs, remote.ID)
					remoteIndex++
					requestMissingRemote++
				}
			}
			// Whatever remains on either side has no counterpart on the other.
			for ; localIndex < len(localMetadata); localIndex++ {
				requestIDs = append(requestIDs, localMetadata[localIndex].OriginalID)
				requestMissingLocal++
			}
			for ; remoteIndex < len(remoteMetadata); remoteIndex++ {
				requestIDs = append(requestIDs, remoteMetadata[remoteIndex].ID)
				requestMissingRemote++
			}

			zap.L().Debug("syncserver diff operations with client metadata",
				zap.String("client_instance_id", clientInstanceID),
				zap.Any("query", diffSelQuery),
				zap.Int("request_due_to_modno", requestDueToModno),
				zap.Int("request_local_but_not_remote", requestMissingLocal),
				zap.Int("request_remote_but_not_local", requestMissingRemote),
				zap.Int("request_ids_total", len(requestIDs)),
			)

			if len(requestIDs) > 0 {
				zap.L().Debug("syncserver sending request operations to client", zap.String("client_instance_id", clientInstanceID), zap.Any("request_ids", requestIDs))
				if err := stream.Send(&v1.SyncStreamItem{
					Action: &v1.SyncStreamItem_DiffOperations{
						DiffOperations: &v1.SyncStreamItem_SyncActionDiffOperations{
							RequestOperations: requestIDs,
						},
					},
				}); err != nil {
					return fmt.Errorf("sending request operations: %w", err)
				}
			}
			return nil

		case *v1.SyncStreamItem_SendOperations:
			// The event is client-controlled and may be absent; reject it instead
			// of dereferencing a nil pointer.
			opEvent := action.SendOperations.GetEvent()
			if opEvent == nil {
				return connect.NewError(connect.CodeInvalidArgument, errors.New("action SendOperations: event is required"))
			}
			switch event := opEvent.Event.(type) {
			case *v1.OperationEvent_CreatedOperations:
				zap.L().Debug("syncserver received created operations", zap.Any("operations", event.CreatedOperations.GetOperations()))
				for _, op := range event.CreatedOperations.GetOperations() {
					if err := insertOrUpdate(op); err != nil {
						return fmt.Errorf("action SendOperations: operation event create: %w", err)
					}
				}
			case *v1.OperationEvent_UpdatedOperations:
				zap.L().Debug("syncserver received update operations", zap.Any("operations", event.UpdatedOperations.GetOperations()))
				for _, op := range event.UpdatedOperations.GetOperations() {
					if err := insertOrUpdate(op); err != nil {
						return fmt.Errorf("action SendOperations: operation event update: %w", err)
					}
				}
			case *v1.OperationEvent_DeletedOperations:
				zap.L().Debug("syncserver received delete operations", zap.Any("operations", event.DeletedOperations.GetValues()))
				for _, id := range event.DeletedOperations.GetValues() {
					if err := deleteByOriginalID(id); err != nil {
						return fmt.Errorf("action SendOperations: operation event delete %d: %w", id, err)
					}
				}
			case *v1.OperationEvent_KeepAlive:
				// Nothing to apply for a keep-alive.
			default:
				return connect.NewError(connect.CodeInvalidArgument, errors.New("action SendOperations: unknown event type"))
			}

		default:
			return connect.NewError(connect.CodeInvalidArgument, errors.New("unknown action type"))
		}

		return nil
	}

	// subscribe to our own configuration for changes
	configWatchCh := h.mgr.configMgr.Watch()
	defer h.mgr.configMgr.StopWatching(configWatchCh)

	if err := sendConfigToClient(initialConfig); err != nil {
		return err
	}

	for {
		select {
		case item, ok := <-receive:
			if !ok {
				return nil
			}
			if err := handleSyncCommand(item); err != nil {
				return err
			}
		case sendItem, ok := <-send: // note: send channel should only be used when sending from a different goroutine than the main loop
			if !ok {
				return nil
			}
			if err := stream.Send(sendItem); err != nil {
				return err
			}
		case <-configWatchCh:
			newConfig, err := h.mgr.configMgr.Get()
			if err != nil {
				zap.S().Warnf("syncserver failed to get the newest config: %v", err)
				continue
			}
			// Keep repo access checks in step with what the client is told.
			currentConfig = newConfig
			if err := sendConfigToClient(newConfig); err != nil {
				return err
			}
		case <-ctx.Done():
			zap.S().Infof("syncserver client %q disconnected", authorizedClientPeer.InstanceId)
			return ctx.Err()
		}
	}
}