diff --git a/cmd/backrest/backrest.go b/cmd/backrest/backrest.go index 50ef87b3..34adb64f 100644 --- a/cmd/backrest/backrest.go +++ b/cmd/backrest/backrest.go @@ -19,6 +19,7 @@ import ( v1 "github.com/garethgeorge/backrest/gen/go/v1" "github.com/garethgeorge/backrest/gen/go/v1/v1connect" + "github.com/garethgeorge/backrest/gen/go/v1sync/v1syncconnect" "github.com/garethgeorge/backrest/internal/api" syncapi "github.com/garethgeorge/backrest/internal/api/syncapi" "github.com/garethgeorge/backrest/internal/auth" @@ -68,12 +69,11 @@ func runApp() { go onterm(os.Interrupt, newForceKillHandler()) // Create dependency components - configStore := createConfigStore() - cfg, err := configStore.Get() + configMgr := &config.ConfigManager{Store: createConfigStore()} + cfg, err := configMgr.Get() if err != nil { zap.L().Fatal("error loading config", zap.Error(err)) } - configMgr := &config.ConfigManager{Store: configStore} opLog, opLogStore, err := newOpLog(cfg) if err != nil { @@ -108,13 +108,7 @@ func runApp() { if err != nil { zap.L().Fatal("error creating peer state manager", zap.Error(err)) } -<<<<<<< HEAD syncMgr := syncapi.NewSyncManager(configMgr, opLog, orch, peerStateManager) - -======= - - syncMgr := syncapi.NewSyncManager(configMgr, opLog, logStore, orch, peerStateManager) ->>>>>>> 9041d3c (improve sync api security by using 'Authorization' headers for initial key exchange) authenticator := newAuthenticator(configMgr) // Start background services @@ -251,8 +245,8 @@ func newServer( func newRootMux( apiBackrestHandler v1connect.BackrestHandler, apiAuthenticationHandler v1connect.AuthenticationHandler, - syncHandler v1connect.BackrestSyncServiceHandler, - syncStateHandler v1connect.BackrestSyncStateServiceHandler, + syncHandler v1syncconnect.BackrestSyncServiceHandler, + syncStateHandler v1syncconnect.BackrestSyncStateServiceHandler, downloadHandler http.Handler, authenticator *auth.Authenticator, ) *http.ServeMux { @@ -260,7 +254,7 @@ func newRootMux( authedMux := http.NewServeMux() backrestPath, backrestHandler := v1connect.NewBackrestHandler(apiBackrestHandler) authedMux.Handle(backrestPath, backrestHandler) - syncStatePath, syncStateHandlerUnauthed := v1connect.NewBackrestSyncStateServiceHandler(syncStateHandler) + syncStatePath, syncStateHandlerUnauthed := v1syncconnect.NewBackrestSyncStateServiceHandler(syncStateHandler) authedMux.Handle(syncStatePath, syncStateHandlerUnauthed) authedMux.Handle("/download/", http.StripPrefix("/download", downloadHandler)) authedMux.Handle("/metrics", metric.GetRegistry().Handler()) @@ -269,7 +263,7 @@ func newRootMux( unauthedMux := http.NewServeMux() authPath, authHandler := v1connect.NewAuthenticationHandler(apiAuthenticationHandler) unauthedMux.Handle(authPath, authHandler) - syncPath, syncHandlerUnauthed := v1connect.NewBackrestSyncServiceHandler(syncHandler) + syncPath, syncHandlerUnauthed := v1syncconnect.NewBackrestSyncServiceHandler(syncHandler) unauthedMux.Handle(syncPath, syncHandlerUnauthed) // Root mux to dispatch to authenticated or unauthenticated handlers diff --git a/internal/api/syncapi/syncapi_test.go b/internal/api/syncapi/syncapi_test.go index ef9ee59e..e7d2d25e 100644 --- a/internal/api/syncapi/syncapi_test.go +++ b/internal/api/syncapi/syncapi_test.go @@ -370,24 +370,24 @@ func TestSimpleOperationSync(t *testing.T) { tryExpectExactOperations(t, ctx, peerHost, oplog.Query{}.SetInstanceID(defaultClientID).SetRepoGUID(defaultRepoGUID), testutil.OperationsWithDefaults(basicClientOperationTempl, []*v1.Operation{ { - Id: 3, // b/c of the already inserted host ops the sync'd ops start at 3 - FlowId: 3, + Id: 2, // b/c of the already inserted host ops the sync'd ops start at 3 + FlowId: 2, OriginalId: 1, OriginalFlowId: 1, OriginalInstanceKeyid: identity2.Keyid, DisplayMessage: "clientop1", }, { - Id: 4, - FlowId: 3, + Id: 3, + FlowId: 2, OriginalId: 2, OriginalFlowId: 1, OriginalInstanceKeyid: identity2.Keyid, DisplayMessage: "clientop2", }, { - Id: 5, - FlowId: 5, + Id: 4, + FlowId: 4, OriginalId: 3, OriginalFlowId: 2, OriginalInstanceKeyid: identity2.Keyid, @@ -584,6 +584,7 @@ func tryExpectOperationsSynced(t *testing.T, ctx context.Context, peer1 *peerUnd op.OriginalId = 0 op.OriginalFlowId = 0 op.OriginalInstanceKeyid = "" + op.Modno = 0 } for _, op := range peer2Ops { op.Id = 0 @@ -591,6 +592,7 @@ func tryExpectOperationsSynced(t *testing.T, ctx context.Context, peer1 *peerUnd op.OriginalId = 0 op.OriginalFlowId = 0 op.OriginalInstanceKeyid = "" + op.Modno = 0 } sortFn := func(a, b *v1.Operation) int { diff --git a/internal/api/syncapi/synccommon.go b/internal/api/syncapi/synccommon.go index 05ae527b..471f47b0 100644 --- a/internal/api/syncapi/synccommon.go +++ b/internal/api/syncapi/synccommon.go @@ -290,29 +290,61 @@ func newRemoteOpIDMapper(oplog *oplog.OpLog, cacheSize int) (*remoteOpIDMapper, }, nil } -// translateSingleID translates a single ID (either opID or flowID) using the provided cache and query -func (sh *remoteOpIDMapper) translateSingleID( - originalInstanceKeyid string, - originalID int64, - cache *lru.Cache[remoteOpIdCacheKey, int64], - query oplog.Query, -) (int64, error) { - if originalID == 0 { +// translateOpID translates a remote operation ID to a local one. +func (om *remoteOpIDMapper) translateOpID(originalInstanceKeyid string, originalOpId int64) (int64, error) { + if originalOpId == 0 { return 0, nil } cacheKey := remoteOpIdCacheKey{ OriginalInstanceKeyid: unique.Make(originalInstanceKeyid), - ID: originalID, + ID: originalOpId, } // Check cache first - if translatedID, ok := cache.Get(cacheKey); ok { + if translatedID, ok := om.opIDLru.Get(cacheKey); ok { return translatedID, nil } // Cache miss - query the database - op, err := sh.oplog.FindOneMetadata(query) + op, err := om.oplog.FindOneMetadata(oplog.Query{ + OriginalInstanceKeyid: &originalInstanceKeyid, + OriginalID: &originalOpId, + }) + if err != nil { + if errors.Is(err, oplog.ErrNoResults) { + return 0, nil // No results means the ID is not found + } + return 0, err // Other errors should be propagated + } + + // Cache the result and return + translatedID := op.ID + om.opIDLru.Add(cacheKey, translatedID) + return translatedID, nil +} + +// translateFlowID translates a remote flow ID to a local one. +func (om *remoteOpIDMapper) translateFlowID(originalInstanceKeyid string, originalFlowId int64) (int64, error) { + if originalFlowId == 0 { + return 0, nil + } + + cacheKey := remoteOpIdCacheKey{ + OriginalInstanceKeyid: unique.Make(originalInstanceKeyid), + ID: originalFlowId, + } + + // Check cache first + if translatedID, ok := om.flowIDLru.Get(cacheKey); ok { + return translatedID, nil + } + + // Cache miss - query the database + op, err := om.oplog.FindOneMetadata(oplog.Query{ + OriginalInstanceKeyid: &originalInstanceKeyid, + OriginalFlowID: &originalFlowId, + }) if err != nil { if errors.Is(err, oplog.ErrNoResults) { return 0, nil // No results means the ID is not found @@ -322,7 +354,7 @@ func (sh *remoteOpIDMapper) translateSingleID( // Cache the result and return translatedID := op.FlowID - cache.Add(cacheKey, translatedID) + om.flowIDLru.Add(cacheKey, translatedID) return translatedID, nil } @@ -331,29 +363,13 @@ func (om *remoteOpIDMapper) TranslateOpIdAndFlowID(originalInstanceKeyid string, defer om.opCacheMu.Unlock() // Translate opID - opID, err := om.translateSingleID( - originalInstanceKeyid, - originalOpId, - om.opIDLru, - oplog.Query{ - OriginalInstanceKeyid: &originalInstanceKeyid, - OriginalID: &originalOpId, - }, - ) + opID, err := om.translateOpID(originalInstanceKeyid, originalOpId) if err != nil { return 0, 0, err } // Translate flowID - flowID, err := om.translateSingleID( - originalInstanceKeyid, - originalFlowId, - om.flowIDLru, - oplog.Query{ - OriginalInstanceKeyid: &originalInstanceKeyid, - OriginalFlowID: &originalFlowId, - }, - ) + flowID, err := om.translateFlowID(originalInstanceKeyid, originalFlowId) if err != nil { return 0, 0, err } diff --git a/internal/api/syncapi/syncserver.go b/internal/api/syncapi/syncserver.go index b6fb82bb..ebd617b6 100644 --- a/internal/api/syncapi/syncserver.go +++ b/internal/api/syncapi/syncserver.go @@ -248,7 +248,7 @@ func (h *syncSessionHandlerServer) HandleReceiveResources(ctx context.Context, s func (h *syncSessionHandlerServer) insertOrUpdate(op *v1.Operation, isUpdate bool) error { // Returns a localOpID and localFlowID or 0 if not found in which case a new ID will be assigned by the insert. - localOpID, localFlowID, err := h.mapper.TranslateOpIdAndFlowID(op.OriginalInstanceKeyid, op.Id, op.FlowId) + localOpID, localFlowID, err := h.mapper.TranslateOpIdAndFlowID(h.peer.Keyid, op.Id, op.FlowId) if err != nil { return fmt.Errorf("translating operation ID and flow ID: %w", err) } @@ -257,13 +257,11 @@ func (h *syncSessionHandlerServer) insertOrUpdate(op *v1.Operation, isUpdate boo op.OriginalFlowId = op.FlowId op.Id = localOpID op.FlowId = localFlowID - if (op.Id == 0) != (op.FlowId == 0) { - return fmt.Errorf("inconsistent operation and flow ID mapping results: op.ID=%d, flow.ID=%d expected both to be 0 or both to be non-zero", op.Id, op.FlowId) - } if op.Id == 0 { if isUpdate { h.l.Sugar().Warnf("received update for non-existent operation %+v, inserting instead", op) } + op.Modno = 0 return h.mgr.oplog.Add(op) } else { if !isUpdate { diff --git a/internal/oplog/oplog.go b/internal/oplog/oplog.go index e2c918d9..179c36d4 100644 --- a/internal/oplog/oplog.go +++ b/internal/oplog/oplog.go @@ -135,14 +135,10 @@ func (o *OpLog) Get(opID int64) (*v1.Operation, error) { func (o *OpLog) Add(ops ...*v1.Operation) error { for _, o := range ops { - if o.Id != 0 { - return errors.New("operation already has an ID, OpLog.Add is expected to set the ID") - } - if o.Modno == 0 { - o.Modno = NewRandomModno(0) + if o.Id != 0 || o.Modno != 0 { + return errors.New("operation already has an ID or Modno, OpLog.Add is expected to set the ID/Modno") } } - if err := o.store.Add(ops...); err != nil { return err } @@ -156,7 +152,6 @@ func (o *OpLog) Update(ops ...*v1.Operation) error { if o.Id == 0 { return errors.New("operation does not have an ID, OpLog.Update is expected to have an ID") } - o.Modno = NewRandomModno(o.Modno) } if err := o.store.Update(ops...); err != nil { diff --git a/internal/oplog/randmodno.go b/internal/oplog/randmodno.go deleted file mode 100644 index 5a98b395..00000000 --- a/internal/oplog/randmodno.go +++ /dev/null @@ -1,24 +0,0 @@ -package oplog - -import ( - rand "math/rand/v2" - "sync" - - "github.com/garethgeorge/backrest/internal/cryptoutil" -) - -// setup a fast random number generator seeded with cryptographic randomness. -var mu sync.Mutex -var pgcRand = rand.NewPCG(cryptoutil.MustRandomUint64(), cryptoutil.MustRandomUint64()) -var randGen = rand.New(pgcRand) - -func NewRandomModno(lastModno int64) int64 { - mu.Lock() - defer mu.Unlock() - for { - modno := randGen.Int64() - if modno != lastModno { - return modno - } - } -} diff --git a/internal/oplog/sqlitestore/sqlitestore.go b/internal/oplog/sqlitestore/sqlitestore.go index b8016d99..0b908647 100644 --- a/internal/oplog/sqlitestore/sqlitestore.go +++ b/internal/oplog/sqlitestore/sqlitestore.go @@ -37,14 +37,14 @@ const ( ) type SqliteStore struct { - dbpool *sql.DB - lastIDVal atomic.Int64 - dblock *flock.Flock + dbpool *sql.DB + dblock *flock.Flock ogidCache *lru.TwoQueueCache[opGroupInfo, int64] kvstore kvstore.KvStore highestModno atomic.Int64 + highestOpID atomic.Int64 } var _ oplog.OpStore = (*SqliteStore)(nil) @@ -136,27 +136,15 @@ func (m *SqliteStore) init() error { return err } - var lastID int64 - err := m.dbpool.QueryRowContext(context.Background(), "SELECT operations.id FROM operations ORDER BY operations.id DESC LIMIT 1").Scan(&lastID) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return fmt.Errorf("init sqlite: %v", err) - } - m.lastIDVal.Store(lastID) - - var highestModno int64 - err = m.dbpool.QueryRowContext(context.Background(), "SELECT operations.modno FROM operations ORDER BY operations.modno DESC LIMIT 1").Scan(&highestModno) - if err != nil && !errors.Is(err, sql.ErrNoRows) { - return fmt.Errorf("init sqlite: %v", err) + highestID, highestModno, err := m.GetHighestOpIDAndModno(oplog.Query{}.SetOriginalInstanceKeyid("")) + if err != nil { + return err } + m.highestOpID.Store(highestID) m.highestModno.Store(highestModno) - return nil } -func (m *SqliteStore) nextModno() int64 { - return m.highestModno.Add(1) -} - func (m *SqliteStore) GetHighestOpIDAndModno(q oplog.Query) (int64, int64, error) { var highestID sql.NullInt64 var highestModno sql.NullInt64 @@ -383,8 +371,6 @@ func (m *SqliteStore) Transform(q oplog.Query, f func(*v1.Operation) (*v1.Operat continue } - newOp.Modno = m.nextModno() - if err := m.updateInternal(tx, newOp); err != nil { return err } @@ -431,8 +417,8 @@ func (m *SqliteStore) Add(op ...*v1.Operation) error { defer tx.Rollback() for _, o := range op { - o.Id = m.lastIDVal.Add(1) - o.Modno = m.nextModno() + o.Id = m.highestOpID.Add(1) + o.Modno = m.highestModno.Add(1) if o.FlowId == 0 { o.FlowId = o.Id } @@ -464,7 +450,7 @@ func (m *SqliteStore) Update(op ...*v1.Operation) error { func (m *SqliteStore) updateInternal(tx *sql.Tx, op ...*v1.Operation) error { for _, o := range op { - o.Modno = m.nextModno() + o.Modno = m.highestModno.Add(1) if err := protoutil.ValidateOperation(o); err != nil { return err } @@ -595,7 +581,8 @@ func (m *SqliteStore) ResetForTest(t *testing.T) error { if err != nil { return fmt.Errorf("reset for test: %v", err) } - m.lastIDVal.Store(0) + m.highestOpID.Store(0) + m.highestModno.Store(0) return nil } diff --git a/internal/protoutil/validation.go b/internal/protoutil/validation.go index feea127a..63e2d0b8 100644 --- a/internal/protoutil/validation.go +++ b/internal/protoutil/validation.go @@ -10,6 +10,7 @@ import ( var ( errIDRequired = errors.New("id is required") + errModnoRequired = errors.New("modno is required") errFlowIDRequired = errors.New("flow_id is required") errRepoIDRequired = errors.New("repo_id is required") errRepoGUIDRequired = errors.New("repo_guid is required") @@ -23,6 +24,9 @@ func ValidateOperation(op *v1.Operation) error { if op.Id == 0 { return errIDRequired } + if op.Modno == 0 { + return errModnoRequired + } if op.RepoGuid == "" { return errRepoGUIDRequired }