diff --git a/dag/set.go b/dag/set.go index 3929c9d0e9..92b42151d7 100644 --- a/dag/set.go +++ b/dag/set.go @@ -81,6 +81,20 @@ func (s *Set) Difference(other *Set) *Set { return result } +// Filter returns a set that contains the elements from the receiver +// where the given callback returns true. +func (s *Set) Filter(cb func(interface{}) bool) *Set { + result := new(Set) + + for _, v := range s.m { + if cb(v) { + result.Add(v) + } + } + + return result +} + // Len is the number of items in the set. func (s *Set) Len() int { if s == nil { diff --git a/dag/set_test.go b/dag/set_test.go index 8aeae70732..c70da475eb 100644 --- a/dag/set_test.go +++ b/dag/set_test.go @@ -54,3 +54,45 @@ func TestSetDifference(t *testing.T) { }) } } + +func TestSetFilter(t *testing.T) { + cases := []struct { + Input []interface{} + Expected []interface{} + }{ + { + []interface{}{1, 2, 3}, + []interface{}{1, 2, 3}, + }, + + { + []interface{}{4, 5, 6}, + []interface{}{4}, + }, + + { + []interface{}{7, 8, 9}, + []interface{}{}, + }, + } + + for i, tc := range cases { + t.Run(fmt.Sprintf("%d-%#v", i, tc.Input), func(t *testing.T) { + var input, expected Set + for _, v := range tc.Input { + input.Add(v) + } + for _, v := range tc.Expected { + expected.Add(v) + } + + actual := input.Filter(func(v interface{}) bool { + return v.(int) < 5 + }) + match := actual.Intersection(&expected) + if match.Len() != expected.Len() { + t.Fatalf("bad: %#v", actual.List()) + } + }) + } +}