diff --git a/internal/oplog/indexutil/indexutil.go b/internal/oplog/indexutil/indexutil.go index 12dbd96..ce84f66 100644 --- a/internal/oplog/indexutil/indexutil.go +++ b/internal/oplog/indexutil/indexutil.go @@ -22,17 +22,21 @@ func IndexRemoveByteValue(b *bolt.Bucket, value []byte, recordId int64) error { } // IndexSearchByteValue searches the index given a value and returns an iterator over the associated recordIds. -func IndexSearchByteValue(b *bolt.Bucket, value []byte) *IndexSearchIterator { +func IndexSearchByteValue(b *bolt.Bucket, value []byte) IndexIterator { return newSearchIterator(b, serializationutil.BytesToKey(value)) } +type IndexIterator interface { + Next() (int64, bool) +} + type IndexSearchIterator struct { c *bolt.Cursor k []byte prefix []byte } -func newSearchIterator(b *bolt.Bucket, prefix []byte) *IndexSearchIterator { +func newSearchIterator(b *bolt.Bucket, prefix []byte) IndexIterator { c := b.Cursor() k, _ := c.Seek(prefix) return &IndexSearchIterator{ @@ -55,24 +59,73 @@ func (i *IndexSearchIterator) Next() (int64, bool) { return id, true } -func (i *IndexSearchIterator) ToSlice() []int64 { - var ids []int64 - for id, ok := i.Next(); ok; id, ok = i.Next() { - ids = append(ids, id) - } - return ids +type JoinIterator struct { + iters []IndexIterator } -type Collector func(*IndexSearchIterator) []int64 +func NewJoinIterator(iters ...IndexIterator) *JoinIterator { + return &JoinIterator{ + iters: iters, + } +} + +func (j *JoinIterator) Next() (int64, bool) { + if len(j.iters) == 0 { + return 0, false + } + + nexts := make([]int64, len(j.iters)) + for idx, iter := range j.iters { + id, ok := iter.Next() + if !ok { + return 0, false + } + nexts[idx] = id + } + + for { + var ok bool + maxIdx := 0 + allSame := true + for idx, id := range nexts { + if id > nexts[maxIdx] { + maxIdx = idx + } + if id != nexts[0] { + allSame = false + } + } + + if allSame { + return nexts[0], true + } + + for idx, id := range nexts { + if id == nexts[maxIdx] { + continue + } + nexts[idx], ok = j.iters[idx].Next() + if !ok { + return 0, false + } + } + } +} + +type Collector func(IndexIterator) []int64 func CollectAll() Collector { - return func(iter *IndexSearchIterator) []int64 { - return iter.ToSlice() + return func(iter IndexIterator) []int64 { + ids := make([]int64, 0, 100) + for id, ok := iter.Next(); ok; id, ok = iter.Next() { + ids = append(ids, id) + } + return ids } } func CollectFirstN(firstN int) Collector { - return func(iter *IndexSearchIterator) []int64 { + return func(iter IndexIterator) []int64 { ids := make([]int64, 0, firstN) for id, ok := iter.Next(); ok && len(ids) < firstN; id, ok = iter.Next() { ids = append(ids, id) @@ -85,7 +138,7 @@ func CollectFirstN(firstN int) Collector { } func CollectLastN(lastN int) Collector { - return func(iter *IndexSearchIterator) []int64 { + return func(iter IndexIterator) []int64 { ids := make([]int64, lastN) count := 0 for id, ok := iter.Next(); ok; id, ok = iter.Next() { diff --git a/internal/oplog/indexutil/indexutil_test.go b/internal/oplog/indexutil/indexutil_test.go index d04eb07..6784c38 100644 --- a/internal/oplog/indexutil/indexutil_test.go +++ b/internal/oplog/indexutil/indexutil_test.go @@ -2,13 +2,14 @@ package indexutil import ( "fmt" + "reflect" "testing" "go.etcd.io/bbolt" ) func TestIndexing(t *testing.T) { - db, err := bbolt.Open(t.TempDir() + "/test.boltdb", 0600, nil) + db, err := bbolt.Open(t.TempDir()+"/test.boltdb", 0600, nil) if err != nil { t.Fatalf("error opening database: %s", err) } @@ -27,14 +28,14 @@ func TestIndexing(t *testing.T) { }); err != nil { t.Fatalf("db.Update error: %v", err) } - + if err := db.View(func(tx *bbolt.Tx) error { b := tx.Bucket([]byte("test")) - ids := IndexSearchByteValue(b, []byte("document")).ToSlice() + ids := CollectAll()(IndexSearchByteValue(b, []byte("document"))) if len(ids) != 100 { t.Errorf("want 100 ids, got %d", len(ids)) } - ids = IndexSearchByteValue(b, []byte("other")).ToSlice() + ids = CollectAll()(IndexSearchByteValue(b, []byte("other"))) if len(ids) != 0 { t.Errorf("want 0 ids, got %d", len(ids)) } @@ -43,3 +44,57 @@ func TestIndexing(t *testing.T) { t.Fatalf("db.View error: %v", err) } } + +func TestIndexJoin(t *testing.T) { + // Arrange + db, err := bbolt.Open(t.TempDir()+"/test.boltdb", 0600, nil) + if err != nil { + t.Fatalf("error opening database: %s", err) + } + + if err := db.Update(func(tx *bbolt.Tx) error { + b, err := tx.CreateBucket([]byte("test")) + if err != nil { + return fmt.Errorf("error creating bucket: %s", err) + } + for id := 0; id < 150; id += 1 { + if err := IndexByteValue(b, []byte("document"), int64(id)); err != nil { + return err + } + } + + for id := 0; id < 100; id += 2 { + if err := IndexByteValue(b, []byte("other"), int64(id)); err != nil { + return err + } + } + + return nil + }); err != nil { + t.Fatalf("db.Update error: %v", err) + } + + if err := db.View(func(tx *bbolt.Tx) error { + // Act + b := tx.Bucket([]byte("test")) + ids := CollectAll()(NewJoinIterator(IndexSearchByteValue(b, []byte("document")), IndexSearchByteValue(b, []byte("other")))) + + // Assert + if len(ids) != 50 { + t.Errorf("want 50 ids, got %d", len(ids)) + } + + wantIds := []int64{} + for id := 0; id < 100; id += 2 { + wantIds = append(wantIds, int64(id)) + } + + if !reflect.DeepEqual(ids, wantIds) { + t.Errorf("want %v, got %v", wantIds, ids) + } + + return nil + }); err != nil { + t.Fatalf("db.View error: %v", err) + } +}