diff --git a/pkg/services/live/pipeline/pipeline.go b/pkg/services/live/pipeline/pipeline.go index 5ff17d5b086..f6a23eb02f0 100644 --- a/pkg/services/live/pipeline/pipeline.go +++ b/pkg/services/live/pipeline/pipeline.go @@ -2,6 +2,7 @@ package pipeline import ( "context" + "errors" "fmt" "os" @@ -135,7 +136,7 @@ func (p *Pipeline) ProcessInput(ctx context.Context, orgID int64, channelID stri if !ok { return false, nil } - err = p.processChannelFrames(ctx, orgID, channelID, channelFrames) + err = p.processChannelFrames(ctx, orgID, channelID, channelFrames, nil) if err != nil { return false, fmt.Errorf("error processing frame: %w", err) } @@ -170,35 +171,50 @@ func (p *Pipeline) dataToChannelFrames(ctx context.Context, rule LiveChannelRule return frames, true, nil } -func (p *Pipeline) processChannelFrames(ctx context.Context, orgID int64, channelID string, channelFrames []*ChannelFrame) error { +var errChannelRecursion = errors.New("channel recursion") + +func (p *Pipeline) processChannelFrames(ctx context.Context, orgID int64, channelID string, channelFrames []*ChannelFrame, visitedChannels map[string]struct{}) error { + if visitedChannels == nil { + visitedChannels = map[string]struct{}{} + } for _, channelFrame := range channelFrames { var processorChannel = channelID if channelFrame.Channel != "" { processorChannel = channelFrame.Channel } - err := p.processFrame(ctx, orgID, processorChannel, channelFrame.Frame) + if _, ok := visitedChannels[processorChannel]; ok { + return fmt.Errorf("%w: %s", errChannelRecursion, processorChannel) + } + visitedChannels[processorChannel] = struct{}{} + frames, err := p.processFrame(ctx, orgID, processorChannel, channelFrame.Frame) if err != nil { return err } + if len(frames) > 0 { + err := p.processChannelFrames(ctx, orgID, processorChannel, frames, visitedChannels) + if err != nil { + return err + } + } } return nil } -func (p *Pipeline) processFrame(ctx context.Context, orgID int64, channelID string, frame *data.Frame) error { +func (p *Pipeline) processFrame(ctx context.Context, orgID int64, channelID string, frame *data.Frame) ([]*ChannelFrame, error) { rule, ruleOk, err := p.ruleGetter.Get(orgID, channelID) if err != nil { logger.Error("Error getting rule", "error", err) - return err + return nil, err } if !ruleOk { logger.Debug("Rule not found", "channel", channelID) - return nil + return nil, err } ch, err := live.ParseChannel(channelID) if err != nil { logger.Error("Error parsing channel", "error", err, "channel", channelID) - return err + return nil, err } vars := ProcessorVars{ @@ -215,10 +231,10 @@ func (p *Pipeline) processFrame(ctx context.Context, orgID int64, channelID stri frame, err = rule.Processor.Process(ctx, vars, frame) if err != nil { logger.Error("Error processing frame", "error", err) - return err + return nil, err } if frame == nil { - return nil + return nil, nil } } @@ -230,15 +246,10 @@ func (p *Pipeline) processFrame(ctx context.Context, orgID int64, channelID stri frames, err := rule.Outputter.Output(ctx, outputVars, frame) if err != nil { logger.Error("Error outputting frame", "error", err) - return err - } - if len(frames) > 0 { - err := p.processChannelFrames(ctx, vars.OrgID, vars.Channel, frames) - if err != nil { - return err - } + return nil, err } + return frames, nil } - return nil + return nil, nil } diff --git a/pkg/services/live/pipeline/pipeline_test.go b/pkg/services/live/pipeline/pipeline_test.go index 0313a10a430..cf1b0c2e4ed 100644 --- a/pkg/services/live/pipeline/pipeline_test.go +++ b/pkg/services/live/pipeline/pipeline_test.go @@ -105,3 +105,25 @@ func TestPipeline_OutputError(t *testing.T) { _, err = p.ProcessInput(context.Background(), 1, "stream/test/xxx", []byte(`{}`)) require.ErrorIs(t, err, boomErr) } + +func TestPipeline_Recursion(t *testing.T) { + p, err := New(&testRuleGetter{ + rules: map[string]*LiveChannelRule{ + "stream/test/xxx": { + Converter: &testConverter{"", data.NewFrame("test")}, + Outputter: NewRedirectOutput(RedirectOutputConfig{ + Channel: "stream/test/yyy", + }), + }, + "stream/test/yyy": { + Converter: &testConverter{"", data.NewFrame("test")}, + Outputter: NewRedirectOutput(RedirectOutputConfig{ + Channel: "stream/test/xxx", + }), + }, + }, + }) + require.NoError(t, err) + _, err = p.ProcessInput(context.Background(), 1, "stream/test/xxx", []byte(`{}`)) + require.ErrorIs(t, err, errChannelRecursion) +}