live: prevent pipeline recursion (#39366)

This commit is contained in:
Alexander Emelin
2021-09-17 21:00:00 +03:00
committed by GitHub
parent 474461ba15
commit be2b08798b
2 changed files with 50 additions and 17 deletions

View File

@@ -2,6 +2,7 @@ package pipeline
import (
"context"
"errors"
"fmt"
"os"
@@ -135,7 +136,7 @@ func (p *Pipeline) ProcessInput(ctx context.Context, orgID int64, channelID stri
if !ok {
return false, nil
}
err = p.processChannelFrames(ctx, orgID, channelID, channelFrames)
err = p.processChannelFrames(ctx, orgID, channelID, channelFrames, nil)
if err != nil {
return false, fmt.Errorf("error processing frame: %w", err)
}
@@ -170,35 +171,50 @@ func (p *Pipeline) dataToChannelFrames(ctx context.Context, rule LiveChannelRule
return frames, true, nil
}
func (p *Pipeline) processChannelFrames(ctx context.Context, orgID int64, channelID string, channelFrames []*ChannelFrame) error {
var errChannelRecursion = errors.New("channel recursion")
func (p *Pipeline) processChannelFrames(ctx context.Context, orgID int64, channelID string, channelFrames []*ChannelFrame, visitedChannels map[string]struct{}) error {
if visitedChannels == nil {
visitedChannels = map[string]struct{}{}
}
for _, channelFrame := range channelFrames {
var processorChannel = channelID
if channelFrame.Channel != "" {
processorChannel = channelFrame.Channel
}
err := p.processFrame(ctx, orgID, processorChannel, channelFrame.Frame)
if _, ok := visitedChannels[processorChannel]; ok {
return fmt.Errorf("%w: %s", errChannelRecursion, processorChannel)
}
visitedChannels[processorChannel] = struct{}{}
frames, err := p.processFrame(ctx, orgID, processorChannel, channelFrame.Frame)
if err != nil {
return err
}
if len(frames) > 0 {
err := p.processChannelFrames(ctx, orgID, processorChannel, frames, visitedChannels)
if err != nil {
return err
}
}
}
return nil
}
func (p *Pipeline) processFrame(ctx context.Context, orgID int64, channelID string, frame *data.Frame) error {
func (p *Pipeline) processFrame(ctx context.Context, orgID int64, channelID string, frame *data.Frame) ([]*ChannelFrame, error) {
rule, ruleOk, err := p.ruleGetter.Get(orgID, channelID)
if err != nil {
logger.Error("Error getting rule", "error", err)
return err
return nil, err
}
if !ruleOk {
logger.Debug("Rule not found", "channel", channelID)
return nil
return nil, err
}
ch, err := live.ParseChannel(channelID)
if err != nil {
logger.Error("Error parsing channel", "error", err, "channel", channelID)
return err
return nil, err
}
vars := ProcessorVars{
@@ -215,10 +231,10 @@ func (p *Pipeline) processFrame(ctx context.Context, orgID int64, channelID stri
frame, err = rule.Processor.Process(ctx, vars, frame)
if err != nil {
logger.Error("Error processing frame", "error", err)
return err
return nil, err
}
if frame == nil {
return nil
return nil, nil
}
}
@@ -230,15 +246,10 @@ func (p *Pipeline) processFrame(ctx context.Context, orgID int64, channelID stri
frames, err := rule.Outputter.Output(ctx, outputVars, frame)
if err != nil {
logger.Error("Error outputting frame", "error", err)
return err
}
if len(frames) > 0 {
err := p.processChannelFrames(ctx, vars.OrgID, vars.Channel, frames)
if err != nil {
return err
}
return nil, err
}
return frames, nil
}
return nil
return nil, nil
}

View File

@@ -105,3 +105,25 @@ func TestPipeline_OutputError(t *testing.T) {
_, err = p.ProcessInput(context.Background(), 1, "stream/test/xxx", []byte(`{}`))
require.ErrorIs(t, err, boomErr)
}
func TestPipeline_Recursion(t *testing.T) {
p, err := New(&testRuleGetter{
rules: map[string]*LiveChannelRule{
"stream/test/xxx": {
Converter: &testConverter{"", data.NewFrame("test")},
Outputter: NewRedirectOutput(RedirectOutputConfig{
Channel: "stream/test/yyy",
}),
},
"stream/test/yyy": {
Converter: &testConverter{"", data.NewFrame("test")},
Outputter: NewRedirectOutput(RedirectOutputConfig{
Channel: "stream/test/xxx",
}),
},
},
})
require.NoError(t, err)
_, err = p.ProcessInput(context.Background(), 1, "stream/test/xxx", []byte(`{}`))
require.ErrorIs(t, err, errChannelRecursion)
}