mirror of
https://github.com/grafana/grafana.git
synced 2025-02-25 18:55:37 -06:00
live: prevent pipeline recursion (#39366)
This commit is contained in:
@@ -2,6 +2,7 @@ package pipeline
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
@@ -135,7 +136,7 @@ func (p *Pipeline) ProcessInput(ctx context.Context, orgID int64, channelID stri
|
||||
if !ok {
|
||||
return false, nil
|
||||
}
|
||||
err = p.processChannelFrames(ctx, orgID, channelID, channelFrames)
|
||||
err = p.processChannelFrames(ctx, orgID, channelID, channelFrames, nil)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("error processing frame: %w", err)
|
||||
}
|
||||
@@ -170,35 +171,50 @@ func (p *Pipeline) dataToChannelFrames(ctx context.Context, rule LiveChannelRule
|
||||
return frames, true, nil
|
||||
}
|
||||
|
||||
func (p *Pipeline) processChannelFrames(ctx context.Context, orgID int64, channelID string, channelFrames []*ChannelFrame) error {
|
||||
var errChannelRecursion = errors.New("channel recursion")
|
||||
|
||||
func (p *Pipeline) processChannelFrames(ctx context.Context, orgID int64, channelID string, channelFrames []*ChannelFrame, visitedChannels map[string]struct{}) error {
|
||||
if visitedChannels == nil {
|
||||
visitedChannels = map[string]struct{}{}
|
||||
}
|
||||
for _, channelFrame := range channelFrames {
|
||||
var processorChannel = channelID
|
||||
if channelFrame.Channel != "" {
|
||||
processorChannel = channelFrame.Channel
|
||||
}
|
||||
err := p.processFrame(ctx, orgID, processorChannel, channelFrame.Frame)
|
||||
if _, ok := visitedChannels[processorChannel]; ok {
|
||||
return fmt.Errorf("%w: %s", errChannelRecursion, processorChannel)
|
||||
}
|
||||
visitedChannels[processorChannel] = struct{}{}
|
||||
frames, err := p.processFrame(ctx, orgID, processorChannel, channelFrame.Frame)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(frames) > 0 {
|
||||
err := p.processChannelFrames(ctx, orgID, processorChannel, frames, visitedChannels)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pipeline) processFrame(ctx context.Context, orgID int64, channelID string, frame *data.Frame) error {
|
||||
func (p *Pipeline) processFrame(ctx context.Context, orgID int64, channelID string, frame *data.Frame) ([]*ChannelFrame, error) {
|
||||
rule, ruleOk, err := p.ruleGetter.Get(orgID, channelID)
|
||||
if err != nil {
|
||||
logger.Error("Error getting rule", "error", err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if !ruleOk {
|
||||
logger.Debug("Rule not found", "channel", channelID)
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ch, err := live.ParseChannel(channelID)
|
||||
if err != nil {
|
||||
logger.Error("Error parsing channel", "error", err, "channel", channelID)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
vars := ProcessorVars{
|
||||
@@ -215,10 +231,10 @@ func (p *Pipeline) processFrame(ctx context.Context, orgID int64, channelID stri
|
||||
frame, err = rule.Processor.Process(ctx, vars, frame)
|
||||
if err != nil {
|
||||
logger.Error("Error processing frame", "error", err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
if frame == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -230,15 +246,10 @@ func (p *Pipeline) processFrame(ctx context.Context, orgID int64, channelID stri
|
||||
frames, err := rule.Outputter.Output(ctx, outputVars, frame)
|
||||
if err != nil {
|
||||
logger.Error("Error outputting frame", "error", err)
|
||||
return err
|
||||
}
|
||||
if len(frames) > 0 {
|
||||
err := p.processChannelFrames(ctx, vars.OrgID, vars.Channel, frames)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return frames, nil
|
||||
}
|
||||
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
@@ -105,3 +105,25 @@ func TestPipeline_OutputError(t *testing.T) {
|
||||
_, err = p.ProcessInput(context.Background(), 1, "stream/test/xxx", []byte(`{}`))
|
||||
require.ErrorIs(t, err, boomErr)
|
||||
}
|
||||
|
||||
func TestPipeline_Recursion(t *testing.T) {
|
||||
p, err := New(&testRuleGetter{
|
||||
rules: map[string]*LiveChannelRule{
|
||||
"stream/test/xxx": {
|
||||
Converter: &testConverter{"", data.NewFrame("test")},
|
||||
Outputter: NewRedirectOutput(RedirectOutputConfig{
|
||||
Channel: "stream/test/yyy",
|
||||
}),
|
||||
},
|
||||
"stream/test/yyy": {
|
||||
Converter: &testConverter{"", data.NewFrame("test")},
|
||||
Outputter: NewRedirectOutput(RedirectOutputConfig{
|
||||
Channel: "stream/test/xxx",
|
||||
}),
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = p.ProcessInput(context.Background(), 1, "stream/test/xxx", []byte(`{}`))
|
||||
require.ErrorIs(t, err, errChannelRecursion)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user