diff --git a/server/channels/store/sqlstore/post_store.go b/server/channels/store/sqlstore/post_store.go index 341c0ddbda..61568bc7bd 100644 --- a/server/channels/store/sqlstore/post_store.go +++ b/server/channels/store/sqlstore/post_store.go @@ -2552,6 +2552,7 @@ func (s *SqlPostStore) determineMaxPostSize() int { } // GetMaxPostSize returns the maximum number of runes that may be stored in a post. +// For any changes, accordingly update the markdown maxLen here - markdown/inspect.go. func (s *SqlPostStore) GetMaxPostSize() int { s.maxPostSizeOnce.Do(func() { s.maxPostSizeCached = s.determineMaxPostSize() diff --git a/server/public/shared/markdown/inspect.go b/server/public/shared/markdown/inspect.go index b3b5ead33d..151b959024 100644 --- a/server/public/shared/markdown/inspect.go +++ b/server/public/shared/markdown/inspect.go @@ -3,9 +3,18 @@ package markdown +const ( + // Assuming 64k maxSize of a post which can be stored in DB. + // Allow scanning upto twice(arbitrary value) the post size. + maxLen = 1024 * 64 * 2 +) + // Inspect traverses the markdown tree in depth-first order. If f returns true, Inspect invokes f // recursively for each child of the block or inline, followed by a call of f(nil). func Inspect(markdown string, f func(any) bool) { + if len(markdown) > maxLen { + return + } document, referenceDefinitions := Parse(markdown) InspectBlock(document, func(block Block) bool { if !f(block) { diff --git a/server/public/shared/markdown/inspect_test.go b/server/public/shared/markdown/inspect_test.go index 8443ef6e8c..bda7a6489d 100644 --- a/server/public/shared/markdown/inspect_test.go +++ b/server/public/shared/markdown/inspect_test.go @@ -12,7 +12,8 @@ import ( ) func TestInspect(t *testing.T) { - markdown := ` + t.Run("base", func(t *testing.T) { + markdown := ` [foo]: bar - a > [![]()]() @@ -20,37 +21,76 @@ func TestInspect(t *testing.T) { - d ` - visited := []string{} - level := 0 - Inspect(markdown, func(blockOrInline any) bool { - if blockOrInline == nil { - level-- - } else { - visited = append(visited, strings.Repeat(" ", level*4)+strings.TrimPrefix(fmt.Sprintf("%T", blockOrInline), "*markdown.")) - level++ - } - return true + visited := []string{} + level := 0 + Inspect(markdown, func(blockOrInline any) bool { + if blockOrInline == nil { + level-- + } else { + visited = append(visited, strings.Repeat(" ", level*4)+strings.TrimPrefix(fmt.Sprintf("%T", blockOrInline), "*markdown.")) + level++ + } + return true + }) + + assert.Equal(t, []string{ + "Document", + " Paragraph", + " List", + " ListItem", + " Paragraph", + " Text", + " BlockQuote", + " Paragraph", + " InlineLink", + " InlineImage", + " SoftLineBreak", + " ReferenceLink", + " ReferenceImage", + " Text", + " ListItem", + " Paragraph", + " Text", + }, visited) }) - assert.Equal(t, []string{ - "Document", - " Paragraph", - " List", - " ListItem", - " Paragraph", - " Text", - " BlockQuote", - " Paragraph", - " InlineLink", - " InlineImage", - " SoftLineBreak", - " ReferenceLink", - " ReferenceImage", - " Text", - " ListItem", - " Paragraph", - " Text", - }, visited) + t.Run("visit nodes when len is smaller than maxLen", func(t *testing.T) { + n := maxLen / 5 + markdown := strings.Repeat(`![`, n) + strings.Repeat(`]()`, n) + + visited := []string{} + level := 0 + Inspect(markdown, func(blockOrInline any) bool { + if blockOrInline == nil { + level-- + } else { + visited = append(visited, strings.Repeat(" ", level*4)+strings.TrimPrefix(fmt.Sprintf("%T", blockOrInline), "*markdown.")) + level++ + } + return true + }) + + assert.NotEmpty(t, visited) + }) + + t.Run("do not visit any nodes when len is greater than maxLen", func(t *testing.T) { + n := (maxLen / 5) + 1 + markdown := strings.Repeat(`![`, n) + strings.Repeat(`]()`, n) + + visited := []string{} + level := 0 + Inspect(markdown, func(blockOrInline any) bool { + if blockOrInline == nil { + level-- + } else { + visited = append(visited, strings.Repeat(" ", level*4)+strings.TrimPrefix(fmt.Sprintf("%T", blockOrInline), "*markdown.")) + level++ + } + return true + }) + + assert.Empty(t, visited) + }) } var counterSink int