diff --git a/internal/api/syncapi/peerstate.go b/internal/api/syncapi/peerstate.go index 10d8bd4f..daa52c91 100644 --- a/internal/api/syncapi/peerstate.go +++ b/internal/api/syncapi/peerstate.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "maps" + "slices" "sync" "time" @@ -67,9 +68,9 @@ func peerStateToProto(state *PeerState) *v1sync.PeerState { LastHeartbeatMillis: state.LastHeartbeat.UnixMilli(), State: state.ConnectionState, StatusMessage: state.ConnectionStateMessage, - // KnownRepos: slices.Collect(maps.Keys(state.KnownRepos)), - // KnownPlans: slices.Collect(maps.Keys(state.KnownPlans)), - RemoteConfig: state.Config, + KnownRepos: slices.Collect(maps.Values(state.KnownRepos)), + KnownPlans: slices.Collect(maps.Values(state.KnownPlans)), + RemoteConfig: state.Config, } } @@ -234,12 +235,10 @@ func (m *SqlitePeerStateManager) SetPeerState(keyID string, state *PeerState) { stateBytes, err := proto.Marshal(stateProto) if err != nil { zap.S().Warnf("error marshalling peer state for key %s: %v", keyID, err) - return } if err := m.kvstore.Set(keyID, stateBytes); err != nil { zap.S().Warnf("error setting peer state for key %s: %v", keyID, err) - return } m.onStateChanged.Emit(state.Clone()) } diff --git a/internal/api/syncapi/peerstate_test.go b/internal/api/syncapi/peerstate_test.go index 3f8bc18b..0740f93a 100644 --- a/internal/api/syncapi/peerstate_test.go +++ b/internal/api/syncapi/peerstate_test.go @@ -7,6 +7,7 @@ import ( "github.com/garethgeorge/backrest/gen/go/v1sync" "github.com/garethgeorge/backrest/internal/kvstore" "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/testing/protocmp" ) func PeerStateManagersForTest(t testing.TB) map[string]PeerStateManager { @@ -31,9 +32,11 @@ func TestPeerStateManager_GetSet(t *testing.T) { t.Parallel() keyID := "testKey" state := &PeerState{ - InstanceID: "testInstance", - KeyID: keyID, - LastHeartbeat: time.Now().Round(time.Millisecond), + InstanceID: "testInstance", + KeyID: keyID, + LastHeartbeat: time.Now().Round(time.Millisecond), + ConnectionState: v1sync.ConnectionState_CONNECTION_STATE_CONNECTED, + ConnectionStateMessage: "hello world!", KnownRepos: map[string]*v1sync.RepoMetadata{ "repo1": { Id: "repo1", @@ -52,7 +55,7 @@ func TestPeerStateManager_GetSet(t *testing.T) { } psm.SetPeerState(keyID, state) gotState := psm.GetPeerState(keyID) - if diff := cmp.Diff(state, gotState, cmp.AllowUnexported(PeerState{})); diff != "" { + if diff := cmp.Diff(state, gotState, cmp.AllowUnexported(PeerState{}), protocmp.Transform()); diff != "" { t.Errorf("unexpected diff: %v", diff) } }) diff --git a/internal/api/syncapi/syncapi_test.go b/internal/api/syncapi/syncapi_test.go index 32f7a493..ef9ee59e 100644 --- a/internal/api/syncapi/syncapi_test.go +++ b/internal/api/syncapi/syncapi_test.go @@ -284,6 +284,7 @@ func TestSimpleOperationSync(t *testing.T) { { Id: defaultRepoID, Guid: defaultRepoGUID, + Uri: "test-uri", }, }, Multihost: &v1.Multihost{ @@ -336,12 +337,12 @@ func TestSimpleOperationSync(t *testing.T) { DisplayMessage: "hostop1", }, })...) - peerHost.oplog.Add(testutil.OperationsWithDefaults(basicClientOperationTempl, []*v1.Operation{ - { - DisplayMessage: "clientop-missing", - OriginalId: 1234, // must be an ID that doesn't exist remotely - }, - })...) + // peerHost.oplog.Add(testutil.OperationsWithDefaults(basicClientOperationTempl, []*v1.Operation{ + // { + // DisplayMessage: "clientop-deleted", + // OriginalId: 1234, // must be an ID that doesn't exist remotely + // }, + // })...) if err := peerClient.oplog.Add(testutil.OperationsWithDefaults(basicClientOperationTempl, []*v1.Operation{ { @@ -410,6 +411,7 @@ func TestSyncMutations(t *testing.T) { { Id: defaultRepoID, Guid: defaultRepoGUID, + Uri: "test-uri", }, }, Multihost: &v1.Multihost{ @@ -608,7 +610,7 @@ func tryExpectOperationsSynced(t *testing.T, ctx context.Context, peer1 *peerUnd return errors.New("no operations found in peer2") } if diff := cmp.Diff(peer1Ops, peer2Ops, protocmp.Transform()); diff != "" { - return fmt.Errorf("unexpected diff: %v", diff) + return fmt.Errorf("%s: unexpected diff: %v", message, diff) } return nil diff --git a/internal/api/syncapi/syncclient.go b/internal/api/syncapi/syncclient.go index c92f0165..83553b8b 100644 --- a/internal/api/syncapi/syncclient.go +++ b/internal/api/syncapi/syncclient.go @@ -196,7 +196,7 @@ func (c *syncSessionHandlerClient) applyPermissions() { } } for _, repo := range c.syncConfigSnapshot.config.Repos { - if c.permissions.CheckPermissionForRepo(repo.Guid, v1.Multihost_Permission_PERMISSION_READ_OPERATIONS) { + if c.permissions.CheckPermissionForRepo(repo.Id, v1.Multihost_Permission_PERMISSION_READ_OPERATIONS) { c.canForwardReposSet[repo.Guid] = struct{}{} } } @@ -429,7 +429,7 @@ func (c *syncSessionHandlerClient) HandleReceiveResources(ctx context.Context, s // Note unused: there isn't a situation where the host would send its config for information, the host will only call 'SetConfig' to update the config. func (c *syncSessionHandlerClient) HandleReceiveConfig(ctx context.Context, stream *bidiSyncCommandStream, item *v1sync.SyncStreamItem_SyncActionReceiveConfig) error { - c.l.Sugar().Debugf("received remote config update", zap.Any("config", item.GetConfig())) + c.l.Sugar().Debugf("received remote config update") peerState := c.mgr.peerStateManager.GetPeerState(c.peer.Keyid).Clone() if peerState == nil { return NewSyncErrorInternal(fmt.Errorf("peer state for %q not found", c.peer.Keyid)) @@ -506,20 +506,20 @@ func (c *syncSessionHandlerClient) HandleSetConfig(ctx context.Context, stream * } } - for _, repo := range item.GetReposToDelete() { - c.l.Sugar().Debugf("received repo deletion request: %s", repo) - if !c.permissions.CheckPermissionForRepo(repo, permissions.PermsCanWriteConfiguration...) { - return NewSyncErrorAuth(fmt.Errorf("peer %q is not allowed to delete repo %q", c.peer.InstanceId, repo)) + for _, repoID := range item.GetReposToDelete() { + c.l.Sugar().Debugf("received repo deletion request: %s", repoID) + if !c.permissions.CheckPermissionForRepo(repoID, permissions.PermsCanWriteConfiguration...) { + return NewSyncErrorAuth(fmt.Errorf("peer %q is not allowed to delete repo %q", c.peer.InstanceId, repoID)) } // Remove the repo from the local config idx := slices.IndexFunc(latestConfig.Repos, func(r *v1.Repo) bool { - return r.Id == repo + return r.Id == repoID }) if idx >= 0 { latestConfig.Repos = append(latestConfig.Repos[:idx], latestConfig.Repos[idx+1:]...) } else { - c.l.Sugar().Warnf("received repo deletion request for non-existent repo %q, ignoring", repo) + c.l.Sugar().Warnf("received repo deletion request for non-existent repo %q, ignoring", repoID) } } @@ -567,7 +567,7 @@ func (c *syncSessionHandlerClient) sendResourceList(ctx context.Context, stream planMetadatas := []*v1sync.PlanMetadata{} for _, repo := range c.syncConfigSnapshot.config.Repos { - if c.permissions.CheckPermissionForRepo(repo.Guid, permissions.PermsCanViewResources...) { + if c.permissions.CheckPermissionForRepo(repo.Id, permissions.PermsCanViewResources...) { repoMetadatas = append(repoMetadatas, &v1sync.RepoMetadata{ Id: repo.Id, Guid: repo.Guid, diff --git a/internal/api/syncapi/syncserver.go b/internal/api/syncapi/syncserver.go index 4a8b428d..b6fb82bb 100644 --- a/internal/api/syncapi/syncserver.go +++ b/internal/api/syncapi/syncserver.go @@ -335,7 +335,7 @@ func (h *syncSessionHandlerServer) sendConfigToClient(stream *bidiSyncCommandStr } func (h *syncSessionHandlerServer) sendOperationSyncRequest(stream *bidiSyncCommandStream) error { - highestID, highestModno, err := h.mgr.oplog.GetHighestOpIDAndModno() + highestID, highestModno, err := h.mgr.oplog.GetHighestOpIDAndModno(oplog.Query{}.SetOriginalInstanceKeyid(h.peer.Keyid)) if err != nil { return fmt.Errorf("getting highest opid and modno: %w", err) } diff --git a/internal/kvstore/sqlitedb.go b/internal/kvstore/sqlitedb.go index d01f4221..19cac03e 100644 --- a/internal/kvstore/sqlitedb.go +++ b/internal/kvstore/sqlitedb.go @@ -47,5 +47,12 @@ func NewInMemorySqliteDbForKvStore(t testing.TB) *sql.DB { if err != nil { t.Fatalf("failed to open db: %v", err) } + _, err = dbpool.ExecContext(context.Background(), ` + PRAGMA journal_mode = WAL; + PRAGMA synchronous = NORMAL; + `) + if err != nil { + t.Fatalf("failed to set pragmas: %v", err) + } return dbpool } diff --git a/internal/oplog/memstore/memstore.go b/internal/oplog/memstore/memstore.go index a8a8050b..6ca00c16 100644 --- a/internal/oplog/memstore/memstore.go +++ b/internal/oplog/memstore/memstore.go @@ -188,12 +188,15 @@ func (m *MemStore) Update(op ...*v1.Operation) error { return nil } -func (m *MemStore) GetHighestOpIDAndModno() (int64, int64, error) { +func (m *MemStore) GetHighestOpIDAndModno(q oplog.Query) (int64, int64, error) { m.mu.Lock() defer m.mu.Unlock() var highestID int64 var highestModno int64 for id, op := range m.operations { + if !q.Match(op) { + continue + } if id > highestID { highestID = id } diff --git a/internal/oplog/oplog.go b/internal/oplog/oplog.go index c92e39df..e2c918d9 100644 --- a/internal/oplog/oplog.go +++ b/internal/oplog/oplog.go @@ -181,8 +181,8 @@ func (o *OpLog) Transform(q Query, f func(*v1.Operation) (*v1.Operation, error)) return o.store.Transform(q, f) } -func (o *OpLog) GetHighestOpIDAndModno() (int64, int64, error) { - return o.store.GetHighestOpIDAndModno() +func (o *OpLog) GetHighestOpIDAndModno(q Query) (int64, int64, error) { + return o.store.GetHighestOpIDAndModno(q) } type OpStore interface { @@ -194,7 +194,7 @@ type OpStore interface { // Get returns the operation with the given ID. Get(opID int64) (*v1.Operation, error) // GetHighestOpIDAndModno returns the highest operation ID and modno in the store, used for synchronization. - GetHighestOpIDAndModno() (int64, int64, error) + GetHighestOpIDAndModno(q Query) (int64, int64, error) // Add adds the given operations to the store. Add(op ...*v1.Operation) error // Update updates the given operations in the store. diff --git a/internal/oplog/sqlitestore/sqlitestore.go b/internal/oplog/sqlitestore/sqlitestore.go index 73e7c4cc..b8016d99 100644 --- a/internal/oplog/sqlitestore/sqlitestore.go +++ b/internal/oplog/sqlitestore/sqlitestore.go @@ -157,8 +157,15 @@ func (m *SqliteStore) nextModno() int64 { return m.highestModno.Add(1) } -func (m *SqliteStore) GetHighestOpIDAndModno() (int64, int64, error) { - return m.lastIDVal.Load(), m.highestModno.Load(), nil +func (m *SqliteStore) GetHighestOpIDAndModno(q oplog.Query) (int64, int64, error) { + var highestID sql.NullInt64 + var highestModno sql.NullInt64 + where, args := m.buildQueryWhereClause(q, false) + row := m.dbpool.QueryRowContext(context.Background(), "SELECT MAX(operations.id), MAX(operations.modno) FROM operations JOIN operation_groups ON operations.ogid = operation_groups.ogid WHERE "+where, args...) + if err := row.Scan(&highestID, &highestModno); err != nil { + return 0, 0, err + } + return highestID.Int64, highestModno.Int64, nil } func (m *SqliteStore) Version() (int64, error) { diff --git a/internal/oplog/storetests/storecontract_test.go b/internal/oplog/storetests/storecontract_test.go index 1041c0b4..57f835af 100644 --- a/internal/oplog/storetests/storecontract_test.go +++ b/internal/oplog/storetests/storecontract_test.go @@ -796,6 +796,117 @@ func TestQueryMetadata(t *testing.T) { } } +func TestGetHighestOpIDAndModno(t *testing.T) { + t.Parallel() + for name, store := range StoresForTest(t) { + t.Run(name, func(t *testing.T) { + log, err := oplog.NewOpLog(store) + if err != nil { + t.Fatalf("error creating oplog: %v", err) + } + + t.Run("empty store", func(t *testing.T) { + highestID, highestModno, err := store.GetHighestOpIDAndModno(oplog.Query{}) + if err != nil { + t.Fatalf("error getting highest ID and modno: %v", err) + } + if highestID != 0 { + t.Errorf("expected highest ID 0, got %d", highestID) + } + if highestModno != 0 { + t.Errorf("expected highest modno 0, got %d", highestModno) + } + }) + + // Add operations with different plans and repos + ops := []*v1.Operation{ + { + UnixTimeStartMs: 1000, + PlanId: "plan1", + RepoId: "repo1", + RepoGuid: "repo1-guid", + InstanceId: "instance1", + Op: &v1.Operation_OperationBackup{}, + }, + { + UnixTimeStartMs: 2000, + PlanId: "plan1", + RepoId: "repo1", + RepoGuid: "repo1-guid", + InstanceId: "instance1", + Op: &v1.Operation_OperationBackup{}, + }, + { + UnixTimeStartMs: 3000, + PlanId: "plan2", + RepoId: "repo2", + RepoGuid: "repo2-guid", + InstanceId: "instance2", + Op: &v1.Operation_OperationBackup{}, + }, + } + + for _, op := range ops { + if err := log.Add(op); err != nil { + t.Fatalf("error adding operation: %v", err) + } + } + + t.Run("all operations", func(t *testing.T) { + highestID, highestModno, err := store.GetHighestOpIDAndModno(oplog.Query{}) + if err != nil { + t.Fatalf("error getting highest ID and modno: %v", err) + } + if highestID != ops[2].Id { + t.Errorf("expected highest ID %d, got %d", ops[2].Id, highestID) + } + if highestModno != ops[2].Modno { + t.Errorf("expected highest modno %d, got %d", ops[2].Modno, highestModno) + } + }) + + t.Run("filtered by plan", func(t *testing.T) { + highestID, highestModno, err := store.GetHighestOpIDAndModno(oplog.Query{}.SetPlanID("plan1")) + if err != nil { + t.Fatalf("error getting highest ID and modno: %v", err) + } + if highestID != ops[1].Id { + t.Errorf("expected highest ID %d, got %d", ops[1].Id, highestID) + } + if highestModno != ops[1].Modno { + t.Errorf("expected highest modno %d, got %d", ops[1].Modno, highestModno) + } + }) + + t.Run("filtered by repo", func(t *testing.T) { + highestID, highestModno, err := store.GetHighestOpIDAndModno(oplog.Query{}.SetRepoGUID("repo2-guid")) + if err != nil { + t.Fatalf("error getting highest ID and modno: %v", err) + } + if highestID != ops[2].Id { + t.Errorf("expected highest ID %d, got %d", ops[2].Id, highestID) + } + if highestModno != ops[2].Modno { + t.Errorf("expected highest modno %d, got %d", ops[2].Modno, highestModno) + } + }) + + t.Run("no matching operations", func(t *testing.T) { + highestID, highestModno, err := store.GetHighestOpIDAndModno(oplog.Query{}.SetPlanID("nonexistent")) + if err != nil { + t.Fatalf("error getting highest ID and modno: %v", err) + } + if highestID != 0 { + t.Errorf("expected highest ID 0, got %d", highestID) + } + if highestModno != 0 { + t.Errorf("expected highest modno 0, got %d", highestModno) + } + }) + }) + } +} + func collectMessages(ops []*v1.Operation) []string { var messages []string for _, op := range ops {