refactor message batching

This commit is contained in:
Evgeny Poberezkin 2023-12-21 21:55:29 +00:00
parent 5b8cb0743a
commit b26b03c922
4 changed files with 106 additions and 127 deletions

View File

@ -29,7 +29,7 @@ import Data.Bifunctor (bimap, first)
import Data.ByteArray (ScrubbedBytes)
import qualified Data.ByteArray as BA
import qualified Data.ByteString.Base64 as B64
import qualified Data.ByteString.Builder as Builder
import Data.ByteString.Builder (toLazyByteString)
import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB
@ -54,7 +54,7 @@ import Data.Time.Clock.System (systemToUTCTime)
import Data.Word (Word32)
import qualified Database.SQLite.Simple as SQL
import Simplex.Chat.Archive
import Simplex.Chat.ByteStringBatcher (BSBatch (..), batchByteStringObjects, partitionBatches)
import Simplex.Chat.ByteStringBatcher (MsgBatch (..), batchMessages)
import Simplex.Chat.Call
import Simplex.Chat.Controller
import Simplex.Chat.Files
@ -5617,17 +5617,17 @@ sendGroupMemberMessages user conn@Connection {connId} events groupId = do
when (connDisabled conn) $ throwChatError (CEConnectionDisabled conn)
(errs, msgs) <- partitionEithers <$> createSndMessages
unless (null errs) $ toView $ CRChatErrors Nothing errs
forM_ (L.nonEmpty msgs) $ \msgs' -> do
let (largeMsgs, msgBatches) = partitionBatches $ batchByteStringObjects maxChatMsgSize msgs'
-- shouldn't happen, as large messages would have caused createNewSndMessage to throw SELargeMsg
errs' = map (\SndMessage {msgId} -> ChatError $ CEInternalError ("large message " <> show msgId)) largeMsgs
unless (null msgs) $ do
let (errs', msgBatches) = partitionEithers $ batchMessages maxChatMsgSize msgs
-- errs' = map (\SndMessage {msgId} -> ChatError $ CEInternalError ("large message " <> show msgId)) largeMsgs
-- shouldn't happen, as large messages would have caused createNewSndMessage to throw SELargeMsg
unless (null errs') $ toView $ CRChatErrors Nothing errs'
forM_ msgBatches $ \batch ->
processBatch batch `catchChatError` (toView . CRChatError (Just user))
where
processBatch :: BSBatch SndMessage -> m ()
processBatch (BSBatch batchBuilder sndMsgs) = do
let batchBody = LB.toStrict $ Builder.toLazyByteString batchBuilder
-- Sends one encoded batch to the agent over this connection, then stores
-- a delivery record for every message that was included in the batch.
processBatch :: MsgBatch -> m ()
processBatch (MsgBatch msgBuilder batchedMsgs) = do
  agentMsgId <- withAgent $ \a -> sendMessage a (aConnId conn) MsgFlags {notification = True} encodedBatch
  let delivery = SndMsgDelivery {connId, agentMsgId}
  void . withStoreBatch' $ \db -> [createSndMsgDelivery db delivery msgId | SndMessage {msgId} <- batchedMsgs]
  where
    -- the agent API takes a strict ByteString, so the builder is materialized here
    encodedBatch = LB.toStrict (toLazyByteString msgBuilder)

View File

@ -1,97 +1,49 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Chat.ByteStringBatcher
( HasByteString (..),
BSBatch (..),
BSBatcherOutput (..),
batchByteStringObjects,
partitionBatches,
( MsgBatch (..),
batchMessages,
)
where
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.ByteString.Builder (Builder, charUtf8, lazyByteString)
import qualified Data.ByteString.Lazy as LB
import Data.Int (Int64)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as L
import Simplex.Chat.Controller (ChatError (..), ChatErrorType (..))
import Simplex.Chat.Messages
class HasByteString a where
getByteString :: a -> L.ByteString
-- | A batch of messages: the builder that produces the encoded batch body,
-- and the messages that were put into this batch.
data MsgBatch = MsgBatch Builder [SndMessage]
deriving (Show)
instance HasByteString L.ByteString where
getByteString = id
data HasByteString a => BSBatch a = BSBatch BB.Builder [a]
data HasByteString a => BSBatcherOutput a
= BatcherOutputBatch (BSBatch a)
| BatcherOutputLarge a
-- | Batches instances of HasByteString into batches of ByteString builders in form of JSON arrays.
-- Does not check if the resulting batch is a valid JSON. If it is required,
-- getByteString should return ByteString encoded JSON object.
-- If a single element is passed, it is returned in form of JSON object instead.
-- If an element exceeds batchLenLimit, it is returned as BatcherOutputLarge.
batchByteStringObjects :: forall a. HasByteString a => Int64 -> NonEmpty a -> [BSBatcherOutput a]
batchByteStringObjects batchLenLimit = reverse . mkBatch []
-- | Batches [SndMessage] into batches of ByteString builders in the form of JSON arrays.
-- Does not check if the resulting batch is valid JSON.
-- If only a single element is passed, or only one fits within the size limit,
-- it is returned as is (a JSON string), not wrapped in an array.
-- If an element exceeds maxLen, it is returned as ChatError.
batchMessages :: Int64 -> [SndMessage] -> [Either ChatError MsgBatch]
batchMessages maxLen msgs =
-- Right fold over the messages. The accumulator is
-- (finished batches and errors, messages of the batch being built,
--  encoded length of that batch without the brackets, number of messages in it).
let (batches, batch, _, n) = foldr addToBatch ([], [], 0, 0) msgs
-- flush the batch that is still being built, unless it is empty
in if n == 0 then batches else msgBatch batch : batches
where
mkBatch :: [BSBatcherOutput a] -> NonEmpty a -> [BSBatcherOutput a]
mkBatch batches objs =
let (batch, objs_) = encodeBatch mempty 0 0 [] objs
batches' = batch : batches
in maybe batches' (mkBatch batches') objs_
encodeBatch :: BB.Builder -> Int64 -> Int -> [a] -> NonEmpty a -> (BSBatcherOutput a, Maybe (NonEmpty a))
encodeBatch builder len cnt batchedObjs remainingObjs@(obj :| objs_)
-- batched string fits
| len' <= maxSize' =
case L.nonEmpty objs_ of
Just objs' -> encodeBatch builder' len' cnt' batchedObjs' objs'
Nothing -> completeBatchLastStrFits
-- batched string doesn't fit
| cnt == 0 = (BatcherOutputLarge obj, L.nonEmpty objs_)
| otherwise = completeBatchStrDoesntFit
-- wraps the collected messages together with their encoded form
msgBatch batch = Right (MsgBatch (encodeMessages batch) batch)
addToBatch :: SndMessage -> ([Either ChatError MsgBatch], [SndMessage], Int64, Int) -> ([Either ChatError MsgBatch], [SndMessage], Int64, Int)
addToBatch msg@SndMessage {msgBody} (batches, batch, len, n)
-- the message fits into the batch being built: prepend it (the fold goes right to left, so order is kept)
| batchLen <= maxLen = (batches, msg : batch, len', n + 1)
-- the message fits only on its own: close the current batch and start a new one with this message
| msgLen <= maxLen = (batches', [msg], msgLen, 1)
-- the message is too large even on its own: report an error, close the current batch if it has messages, start an empty one
| otherwise = (errLarge msg : (if n == 0 then batches else batches'), [], 0, 0)
where
bStr = getByteString obj
cnt' = cnt + 1
(len', builder')
| cnt' == 1 =
( LB.length bStr, -- initially len = 0
BB.lazyByteString bStr
)
| cnt' == 2 =
( len + LB.length bStr + 2, -- for opening bracket "[" and comma ","
"[" <> builder <> "," <> BB.lazyByteString bStr
)
| otherwise =
( len + LB.length bStr + 1, -- for comma ","
builder <> "," <> BB.lazyByteString bStr
)
maxSize'
| cnt' == 1 = batchLenLimit
| otherwise = batchLenLimit - 1 -- for closing bracket "]"
batchedObjs' :: [a]
batchedObjs' = obj : batchedObjs
completeBatchLastStrFits :: (BSBatcherOutput a, Maybe (NonEmpty a))
completeBatchLastStrFits =
(BatcherOutputBatch $ BSBatch completeBuilder (reverse batchedObjs'), Nothing)
where
completeBuilder
| cnt' == 1 = builder' -- if last string fits, we look at current cnt'
| otherwise = builder' <> "]"
completeBatchStrDoesntFit :: (BSBatcherOutput a, Maybe (NonEmpty a))
completeBatchStrDoesntFit =
(BatcherOutputBatch $ BSBatch completeBuilder (reverse batchedObjs), Just remainingObjs)
where
completeBuilder
| cnt == 1 = builder -- if string doesn't fit, we look at previous cnt
| otherwise = builder <> "]"
msgLen = LB.length msgBody
batches' = msgBatch batch : batches
len' = msgLen + if n == 0 then 0 else len + 1 -- 1 accounts for comma
batchLen = len' + (if n == 0 then 0 else 2) -- 2 accounts for opening and closing brackets
-- error reported for a message whose body alone exceeds maxLen
errLarge SndMessage {msgId} = Left $ ChatError $ CEInternalError ("large message " <> show msgId)
-- | Partitions list of batcher outputs into lists of batches and large objects.
partitionBatches :: forall a. HasByteString a => [BSBatcherOutput a] -> ([a], [BSBatch a])
partitionBatches = foldr partition' ([], [])
-- | Encodes the message bodies of a batch: nothing for an empty list, the bare
-- body for a single message, and a bracketed comma-separated list otherwise.
encodeMessages :: [SndMessage] -> Builder
encodeMessages = \case
  [] -> mempty
  [single] -> encodeMsg single
  headMsg : tailMsgs ->
    let separated = foldMap (\m -> charUtf8 ',' <> encodeMsg m) tailMsgs
     in charUtf8 '[' <> encodeMsg headMsg <> separated <> charUtf8 ']'
where
partition' :: BSBatcherOutput a -> ([a], [BSBatch a]) -> ([a], [BSBatch a])
partition' (BatcherOutputBatch bStrBatch) (largeBStrs, bStrBatches) = (largeBStrs, bStrBatch : bStrBatches)
partition' (BatcherOutputLarge largeBStr) (largeBStrs, bStrBatches) = (largeBStr : largeBStrs, bStrBatches)
encodeMsg SndMessage {msgBody} = lazyByteString msgBody

View File

@ -35,7 +35,6 @@ import Data.Type.Equality
import Data.Typeable (Typeable)
import Database.SQLite.Simple.FromField (FromField (..))
import Database.SQLite.Simple.ToField (ToField (..))
import Simplex.Chat.ByteStringBatcher (HasByteString(..))
import Simplex.Chat.Markdown
import Simplex.Chat.Messages.CIContent
import Simplex.Chat.Protocol
@ -772,9 +771,7 @@ data SndMessage = SndMessage
sharedMsgId :: SharedMsgId,
msgBody :: LazyMsgBody
}
instance HasByteString SndMessage where
getByteString SndMessage {msgBody} = msgBody
deriving (Show)
data NewRcvMessage e = NewRcvMessage
{ chatMsgEvent :: ChatMsgEvent e,

View File

@ -1,13 +1,23 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -Wno-orphans #-}
module ByteStringBatcherTests where
module ByteStringBatcherTests (byteStringBatcherTests) where
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Lazy as L
import Crypto.Number.Serialize (os2ip)
import Data.ByteString.Builder (toLazyByteString)
import qualified Data.ByteString.Lazy as LB
import Data.Either (partitionEithers)
import Data.Int (Int64)
import Data.List.NonEmpty (fromList)
import Data.String (IsString (..))
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Simplex.Chat.ByteStringBatcher
import Simplex.Chat.Protocol (maxChatMsgSize)
import Simplex.Chat.Controller (ChatError (..), ChatErrorType (..))
import Simplex.Chat.Messages (SndMessage (..))
import Simplex.Chat.Protocol (SharedMsgId (..), maxChatMsgSize)
import Test.Hspec
byteStringBatcherTests :: Spec
@ -15,30 +25,46 @@ byteStringBatcherTests = describe "ByteStringBatcher tests" $ do
testBatchingCorrectness
it "image x.msg.new and x.msg.file.descr should fit into single batch" testImageFitsSingleBatch
-- | Test-only string literals for messages: the UTF-8 bytes of the literal
-- become the message body, and the same bytes read as an integer become the message id.
instance IsString SndMessage where
  fromString str =
    SndMessage
      { msgId = fromInteger (os2ip bytes),
        sharedMsgId = SharedMsgId "",
        msgBody = LB.fromStrict bytes
      }
    where
      bytes = encodeUtf8 (T.pack str)

deriving instance Eq SndMessage
-- | Test-only string literals for errors: builds the "large message" internal error
-- whose id is derived from the UTF-8 bytes of the literal read as an integer.
instance IsString ChatError where
  fromString str = ChatError . CEInternalError $ "large message " <> show msgId
    where
      msgId :: Int64
      msgId = fromInteger . os2ip . encodeUtf8 $ T.pack str
testBatchingCorrectness :: Spec
testBatchingCorrectness = describe "correctness tests" $ do
runBatcherTest 8 ["a"] [] ["a"]
runBatcherTest 8 ["a", "b"] [] ["[a,b]"]
runBatcherTest 8 ["a", "b", "c"] [] ["[a,b,c]"]
runBatcherTest 8 ["a", "bb", "c"] [] ["[a,bb,c]"]
runBatcherTest 8 ["a", "b", "c", "d"] [] ["[a,b,c]", "d"]
runBatcherTest 8 ["a", "bb", "c", "d"] [] ["[a,bb,c]", "d"]
runBatcherTest 8 ["a", "bb", "c", "de"] [] ["[a,bb,c]", "de"]
runBatcherTest 8 ["a", "b", "c", "d", "e"] [] ["[a,b,c]", "[d,e]"]
runBatcherTest 8 ["a", "b", "c", "d"] [] ["a", "[b,c,d]"]
runBatcherTest 8 ["a", "bb", "c", "d"] [] ["a", "[bb,c,d]"]
runBatcherTest 8 ["a", "bb", "c", "de"] [] ["[a,bb]", "[c,de]"]
runBatcherTest 8 ["a", "b", "c", "d", "e"] [] ["[a,b]", "[c,d,e]"]
runBatcherTest 8 ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] [] ["a", "[b,c,d]", "[e,f,g]", "[h,i,j]"]
runBatcherTest 8 ["aaaaa"] [] ["aaaaa"]
runBatcherTest 8 ["8aaaaaaa"] [] ["8aaaaaaa"]
runBatcherTest 8 ["aaaa", "bbbb"] [] ["aaaa", "bbbb"]
runBatcherTest 8 ["aa", "bbb", "cc", "dd"] [] ["[aa,bbb]", "[cc,dd]"]
runBatcherTest 8 ["aa", "bbb", "cc", "dd", "eee", "fff", "gg", "hh"] [] ["aa", "[bbb,cc]", "[dd,eee]", "fff", "[gg,hh]"]
runBatcherTest 8 ["9aaaaaaaa"] ["9aaaaaaaa"] []
runBatcherTest 8 ["aaaaa", "bbb", "cc"] [] ["aaaaa", "[bbb,cc]"]
runBatcherTest 8 ["8aaaaaaa", "bbb", "cc"] [] ["8aaaaaaa", "[bbb,cc]"]
runBatcherTest 8 ["9aaaaaaaa", "bbb", "cc"] ["9aaaaaaaa"] ["[bbb,cc]"]
runBatcherTest 8 ["9aaaaaaaa", "bbb", "cc", "dd"] ["9aaaaaaaa"] ["[bbb,cc]", "dd"]
runBatcherTest 8 ["9aaaaaaaa", "bbb", "cc", "dd"] ["9aaaaaaaa"] ["bbb", "[cc,dd]"]
runBatcherTest 8 ["9aaaaaaaa", "bbb", "cc", "dd", "e"] ["9aaaaaaaa"] ["[bbb,cc]", "[dd,e]"]
runBatcherTest 8 ["bbb", "cc", "aaaaa"] [] ["[bbb,cc]", "aaaaa"]
runBatcherTest 8 ["bbb", "cc", "8aaaaaaa"] [] ["[bbb,cc]", "8aaaaaaa"]
runBatcherTest 8 ["bbb", "cc", "9aaaaaaaa"] ["9aaaaaaaa"] ["[bbb,cc]"]
runBatcherTest 8 ["bbb", "cc", "dd", "9aaaaaaaa"] ["9aaaaaaaa"] ["[bbb,cc]", "dd"]
runBatcherTest 8 ["bbb", "cc", "dd", "9aaaaaaaa"] ["9aaaaaaaa"] ["bbb", "[cc,dd]"]
runBatcherTest 8 ["bbb", "cc", "dd", "e", "9aaaaaaaa"] ["9aaaaaaaa"] ["[bbb,cc]", "[dd,e]"]
runBatcherTest 8 ["bbb", "cc", "aaaaa", "dd"] [] ["[bbb,cc]", "aaaaa", "dd"]
runBatcherTest 8 ["bbb", "cc", "aaaaa", "dd", "e"] [] ["[bbb,cc]", "aaaaa", "[dd,e]"]
@ -51,13 +77,13 @@ testBatchingCorrectness = describe "correctness tests" $ do
runBatcherTest 8 ["8aaaaaaa", "9aaaaaaaa", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["8aaaaaaa"]
runBatcherTest 8 ["9aaaaaaaa", "8aaaaaaa", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["8aaaaaaa"]
runBatcherTest 8 ["9aaaaaaaa", "10aaaaaaaa", "8aaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["8aaaaaaa"]
runBatcherTest 8 ["bb", "cc", "dd", "9aaaaaaaa", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["[bb,cc]", "dd"]
runBatcherTest 8 ["bb", "cc", "dd", "9aaaaaaaa", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["bb", "[cc,dd]"]
runBatcherTest 8 ["bb", "cc", "9aaaaaaaa", "dd", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["[bb,cc]", "dd"]
runBatcherTest 8 ["bb", "9aaaaaaaa", "cc", "dd", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["bb", "[cc,dd]"]
runBatcherTest 8 ["bb", "9aaaaaaaa", "cc", "10aaaaaaaa", "dd"] ["9aaaaaaaa", "10aaaaaaaa"] ["bb", "cc", "dd"]
runBatcherTest 8 ["9aaaaaaaa", "bb", "cc", "dd", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["[bb,cc]", "dd"]
runBatcherTest 8 ["9aaaaaaaa", "bb", "cc", "dd", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["bb", "[cc,dd]"]
runBatcherTest 8 ["9aaaaaaaa", "bb", "10aaaaaaaa", "cc", "dd"] ["9aaaaaaaa", "10aaaaaaaa"] ["bb", "[cc,dd]"]
runBatcherTest 8 ["9aaaaaaaa", "10aaaaaaaa", "bb", "cc", "dd"] ["9aaaaaaaa", "10aaaaaaaa"] ["[bb,cc]", "dd"]
runBatcherTest 8 ["9aaaaaaaa", "10aaaaaaaa", "bb", "cc", "dd"] ["9aaaaaaaa", "10aaaaaaaa"] ["bb", "[cc,dd]"]
testImageFitsSingleBatch :: IO ()
testImageFitsSingleBatch = do
@ -68,23 +94,27 @@ testImageFitsSingleBatch = do
-- 261_120 bytes (MAX_IMAGE_SIZE in UI), rounded up, example was 743
let descrRoundedSize = 800
let xMsgNewStr = L.replicate xMsgNewRoundedSize 1
descrStr = L.replicate descrRoundedSize 2
let xMsgNewStr = LB.replicate xMsgNewRoundedSize 1
descrStr = LB.replicate descrRoundedSize 2
msg s = SndMessage {msgId = 0, sharedMsgId = SharedMsgId "", msgBody = s}
batched = "[" <> xMsgNewStr <> "," <> descrStr <> "]"
runBatcherTest' maxChatMsgSize [msg xMsgNewStr, msg descrStr] [] [batched]
runBatcherTest' maxChatMsgSize [xMsgNewStr, descrStr] [] ["[" <> xMsgNewStr <> "," <> descrStr <> "]"]
runBatcherTest :: Int64 -> [L.ByteString] -> [L.ByteString] -> [L.ByteString] -> SpecWith ()
runBatcherTest batchLenLimit bStrs expectedLargeStrs expectedBatchedStrs =
-- | Builds a spec item for one batching case. The item description lists the message
-- bodies, the size limit and the expected numbers of errors and batches;
-- the check itself is delegated to runBatcherTest'.
runBatcherTest :: Int64 -> [SndMessage] -> [ChatError] -> [LB.ByteString] -> Spec
runBatcherTest maxLen msgs expectedErrors expectedBatches =
it
( (show bStrs <> ", limit " <> show batchLenLimit <> ": should return ")
<> (show (length expectedLargeStrs) <> " large, ")
<> (show (length expectedBatchedStrs) <> " batches")
( (show (map (\SndMessage {msgBody} -> msgBody) msgs) <> ", limit " <> show maxLen <> ": should return ")
<> (show (length expectedErrors) <> " large, ")
<> (show (length expectedBatches) <> " batches")
)
(runBatcherTest' batchLenLimit bStrs expectedLargeStrs expectedBatchedStrs)
(runBatcherTest' maxLen msgs expectedErrors expectedBatches)
runBatcherTest' :: Int64 -> [L.ByteString] -> [L.ByteString] -> [L.ByteString] -> IO ()
runBatcherTest' batchLenLimit bStrs expectedLargeStrs expectedBatchedStrs = do
let (largeStrs, batches) = partitionBatches $ batchByteStringObjects batchLenLimit (fromList bStrs)
batchedStrs = map (\(BSBatch batchBuilder _) -> BB.toLazyByteString batchBuilder) batches
largeStrs `shouldBe` expectedLargeStrs
batchedStrs `shouldBe` expectedBatchedStrs
-- | Runs the batcher on the given messages and compares the result with the expectation:
-- errors are compared by their internal error text, batches by their encoded bytes.
runBatcherTest' :: Int64 -> [SndMessage] -> [ChatError] -> [LB.ByteString] -> IO ()
runBatcherTest' maxLen msgs expectedErrors expectedBatches = do
  map internalErrorText errs `shouldBe` map internalErrorText expectedErrors
  [toLazyByteString b | MsgBatch b _ <- okBatches] `shouldBe` expectedBatches
  where
    (errs, okBatches) = partitionEithers (batchMessages maxLen msgs)
    -- only internal errors carry a text to compare; any other error maps to Nothing
    internalErrorText = \case
      ChatError (CEInternalError s) -> Just s
      _ -> Nothing