refactor message batching

This commit is contained in:
Evgeny Poberezkin 2023-12-21 21:55:29 +00:00
parent 5b8cb0743a
commit b26b03c922
4 changed files with 106 additions and 127 deletions

View File

@ -29,7 +29,7 @@ import Data.Bifunctor (bimap, first)
import Data.ByteArray (ScrubbedBytes) import Data.ByteArray (ScrubbedBytes)
import qualified Data.ByteArray as BA import qualified Data.ByteArray as BA
import qualified Data.ByteString.Base64 as B64 import qualified Data.ByteString.Base64 as B64
import qualified Data.ByteString.Builder as Builder import Data.ByteString.Builder (toLazyByteString)
import Data.ByteString.Char8 (ByteString) import Data.ByteString.Char8 (ByteString)
import qualified Data.ByteString.Char8 as B import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy.Char8 as LB import qualified Data.ByteString.Lazy.Char8 as LB
@ -54,7 +54,7 @@ import Data.Time.Clock.System (systemToUTCTime)
import Data.Word (Word32) import Data.Word (Word32)
import qualified Database.SQLite.Simple as SQL import qualified Database.SQLite.Simple as SQL
import Simplex.Chat.Archive import Simplex.Chat.Archive
import Simplex.Chat.ByteStringBatcher (BSBatch (..), batchByteStringObjects, partitionBatches) import Simplex.Chat.ByteStringBatcher (MsgBatch (..), batchMessages)
import Simplex.Chat.Call import Simplex.Chat.Call
import Simplex.Chat.Controller import Simplex.Chat.Controller
import Simplex.Chat.Files import Simplex.Chat.Files
@ -5617,17 +5617,17 @@ sendGroupMemberMessages user conn@Connection {connId} events groupId = do
when (connDisabled conn) $ throwChatError (CEConnectionDisabled conn) when (connDisabled conn) $ throwChatError (CEConnectionDisabled conn)
(errs, msgs) <- partitionEithers <$> createSndMessages (errs, msgs) <- partitionEithers <$> createSndMessages
unless (null errs) $ toView $ CRChatErrors Nothing errs unless (null errs) $ toView $ CRChatErrors Nothing errs
forM_ (L.nonEmpty msgs) $ \msgs' -> do unless (null msgs) $ do
let (largeMsgs, msgBatches) = partitionBatches $ batchByteStringObjects maxChatMsgSize msgs' let (errs', msgBatches) = partitionEithers $ batchMessages maxChatMsgSize msgs
-- shouldn't happen, as large messages would have caused createNewSndMessage to throw SELargeMsg -- errs' = map (\SndMessage {msgId} -> ChatError $ CEInternalError ("large message " <> show msgId)) largeMsgs
errs' = map (\SndMessage {msgId} -> ChatError $ CEInternalError ("large message " <> show msgId)) largeMsgs -- shouldn't happen, as large messages would have caused createNewSndMessage to throw SELargeMsg
unless (null errs') $ toView $ CRChatErrors Nothing errs' unless (null errs') $ toView $ CRChatErrors Nothing errs'
forM_ msgBatches $ \batch -> forM_ msgBatches $ \batch ->
processBatch batch `catchChatError` (toView . CRChatError (Just user)) processBatch batch `catchChatError` (toView . CRChatError (Just user))
where where
processBatch :: BSBatch SndMessage -> m () processBatch :: MsgBatch -> m ()
processBatch (BSBatch batchBuilder sndMsgs) = do processBatch (MsgBatch builder sndMsgs) = do
let batchBody = LB.toStrict $ Builder.toLazyByteString batchBuilder let batchBody = LB.toStrict $ toLazyByteString builder
agentMsgId <- withAgent $ \a -> sendMessage a (aConnId conn) MsgFlags {notification = True} batchBody agentMsgId <- withAgent $ \a -> sendMessage a (aConnId conn) MsgFlags {notification = True} batchBody
let sndMsgDelivery = SndMsgDelivery {connId, agentMsgId} let sndMsgDelivery = SndMsgDelivery {connId, agentMsgId}
void . withStoreBatch' $ \db -> map (\SndMessage {msgId} -> createSndMsgDelivery db sndMsgDelivery msgId) sndMsgs void . withStoreBatch' $ \db -> map (\SndMessage {msgId} -> createSndMsgDelivery db sndMsgDelivery msgId) sndMsgs

View File

@ -1,97 +1,49 @@
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
module Simplex.Chat.ByteStringBatcher module Simplex.Chat.ByteStringBatcher
( HasByteString (..), ( MsgBatch (..),
BSBatch (..), batchMessages,
BSBatcherOutput (..),
batchByteStringObjects,
partitionBatches,
) )
where where
import qualified Data.ByteString.Builder as BB import Data.ByteString.Builder (Builder, charUtf8, lazyByteString)
import qualified Data.ByteString.Lazy as L import qualified Data.ByteString.Lazy as LB
import qualified Data.ByteString.Lazy.Char8 as LB
import Data.Int (Int64) import Data.Int (Int64)
import Data.List.NonEmpty (NonEmpty (..)) import Simplex.Chat.Controller (ChatError (..), ChatErrorType (..))
import qualified Data.List.NonEmpty as L import Simplex.Chat.Messages
class HasByteString a where data MsgBatch = MsgBatch Builder [SndMessage]
getByteString :: a -> L.ByteString deriving (Show)
instance HasByteString L.ByteString where -- | Batches [SndMessage] into batches of ByteString builders in form of JSON arrays.
getByteString = id -- Does not check if the resulting batch is a valid JSON.
-- If a single element is passed of fits the size, it is returned as is (a JSON string).
data HasByteString a => BSBatch a = BSBatch BB.Builder [a] -- If an element exceeds maxLen, it is returned as ChatError.
batchMessages :: Int64 -> [SndMessage] -> [Either ChatError MsgBatch]
data HasByteString a => BSBatcherOutput a batchMessages maxLen msgs =
= BatcherOutputBatch (BSBatch a) let (batches, batch, _, n) = foldr addToBatch ([], [], 0, 0) msgs
| BatcherOutputLarge a in if n == 0 then batches else msgBatch batch : batches
-- | Batches instances of HasByteString into batches of ByteString builders in form of JSON arrays.
-- Does not check if the resulting batch is a valid JSON. If it is required,
-- getByteString should return ByteString encoded JSON object.
-- If a single element is passed, it is returned in form of JSON object instead.
-- If an element exceeds batchLenLimit, it is returned as BatcherOutputLarge.
batchByteStringObjects :: forall a. HasByteString a => Int64 -> NonEmpty a -> [BSBatcherOutput a]
batchByteStringObjects batchLenLimit = reverse . mkBatch []
where where
mkBatch :: [BSBatcherOutput a] -> NonEmpty a -> [BSBatcherOutput a] msgBatch batch = Right (MsgBatch (encodeMessages batch) batch)
mkBatch batches objs = addToBatch :: SndMessage -> ([Either ChatError MsgBatch], [SndMessage], Int64, Int) -> ([Either ChatError MsgBatch], [SndMessage], Int64, Int)
let (batch, objs_) = encodeBatch mempty 0 0 [] objs addToBatch msg@SndMessage {msgBody} (batches, batch, len, n)
batches' = batch : batches | batchLen <= maxLen = (batches, msg : batch, len', n + 1)
in maybe batches' (mkBatch batches') objs_ | msgLen <= maxLen = (batches', [msg], msgLen, 1)
encodeBatch :: BB.Builder -> Int64 -> Int -> [a] -> NonEmpty a -> (BSBatcherOutput a, Maybe (NonEmpty a)) | otherwise = (errLarge msg : (if n == 0 then batches else batches'), [], 0, 0)
encodeBatch builder len cnt batchedObjs remainingObjs@(obj :| objs_)
-- batched string fits
| len' <= maxSize' =
case L.nonEmpty objs_ of
Just objs' -> encodeBatch builder' len' cnt' batchedObjs' objs'
Nothing -> completeBatchLastStrFits
-- batched string doesn't fit
| cnt == 0 = (BatcherOutputLarge obj, L.nonEmpty objs_)
| otherwise = completeBatchStrDoesntFit
where where
bStr = getByteString obj msgLen = LB.length msgBody
cnt' = cnt + 1 batches' = msgBatch batch : batches
(len', builder') len' = msgLen + if n == 0 then 0 else len + 1 -- 1 accounts for comma
| cnt' == 1 = batchLen = len' + (if n == 0 then 0 else 2) -- 2 accounts for opening and closing brackets
( LB.length bStr, -- initially len = 0 errLarge SndMessage {msgId} = Left $ ChatError $ CEInternalError ("large message " <> show msgId)
BB.lazyByteString bStr
)
| cnt' == 2 =
( len + LB.length bStr + 2, -- for opening bracket "[" and comma ","
"[" <> builder <> "," <> BB.lazyByteString bStr
)
| otherwise =
( len + LB.length bStr + 1, -- for comma ","
builder <> "," <> BB.lazyByteString bStr
)
maxSize'
| cnt' == 1 = batchLenLimit
| otherwise = batchLenLimit - 1 -- for closing bracket "]"
batchedObjs' :: [a]
batchedObjs' = obj : batchedObjs
completeBatchLastStrFits :: (BSBatcherOutput a, Maybe (NonEmpty a))
completeBatchLastStrFits =
(BatcherOutputBatch $ BSBatch completeBuilder (reverse batchedObjs'), Nothing)
where
completeBuilder
| cnt' == 1 = builder' -- if last string fits, we look at current cnt'
| otherwise = builder' <> "]"
completeBatchStrDoesntFit :: (BSBatcherOutput a, Maybe (NonEmpty a))
completeBatchStrDoesntFit =
(BatcherOutputBatch $ BSBatch completeBuilder (reverse batchedObjs), Just remainingObjs)
where
completeBuilder
| cnt == 1 = builder -- if string doesn't fit, we look at previous cnt
| otherwise = builder <> "]"
-- | Partitions list of batcher outputs into lists of batches and large objects. encodeMessages :: [SndMessage] -> Builder
partitionBatches :: forall a. HasByteString a => [BSBatcherOutput a] -> ([a], [BSBatch a]) encodeMessages = \case
partitionBatches = foldr partition' ([], []) [] -> mempty
[msg] -> encodeMsg msg
(msg : msgs) -> charUtf8 '[' <> encodeMsg msg <> mconcat [charUtf8 ',' <> encodeMsg msg' | msg' <- msgs] <> charUtf8 ']'
where where
partition' :: BSBatcherOutput a -> ([a], [BSBatch a]) -> ([a], [BSBatch a]) encodeMsg SndMessage {msgBody} = lazyByteString msgBody
partition' (BatcherOutputBatch bStrBatch) (largeBStrs, bStrBatches) = (largeBStrs, bStrBatch : bStrBatches)
partition' (BatcherOutputLarge largeBStr) (largeBStrs, bStrBatches) = (largeBStr : largeBStrs, bStrBatches)

View File

@ -35,7 +35,6 @@ import Data.Type.Equality
import Data.Typeable (Typeable) import Data.Typeable (Typeable)
import Database.SQLite.Simple.FromField (FromField (..)) import Database.SQLite.Simple.FromField (FromField (..))
import Database.SQLite.Simple.ToField (ToField (..)) import Database.SQLite.Simple.ToField (ToField (..))
import Simplex.Chat.ByteStringBatcher (HasByteString(..))
import Simplex.Chat.Markdown import Simplex.Chat.Markdown
import Simplex.Chat.Messages.CIContent import Simplex.Chat.Messages.CIContent
import Simplex.Chat.Protocol import Simplex.Chat.Protocol
@ -772,9 +771,7 @@ data SndMessage = SndMessage
sharedMsgId :: SharedMsgId, sharedMsgId :: SharedMsgId,
msgBody :: LazyMsgBody msgBody :: LazyMsgBody
} }
deriving (Show)
instance HasByteString SndMessage where
getByteString SndMessage {msgBody} = msgBody
data NewRcvMessage e = NewRcvMessage data NewRcvMessage e = NewRcvMessage
{ chatMsgEvent :: ChatMsgEvent e, { chatMsgEvent :: ChatMsgEvent e,

View File

@ -1,13 +1,23 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -Wno-orphans #-}
module ByteStringBatcherTests where module ByteStringBatcherTests (byteStringBatcherTests) where
import qualified Data.ByteString.Builder as BB import Crypto.Number.Serialize (os2ip)
import qualified Data.ByteString.Lazy as L import Data.ByteString.Builder (toLazyByteString)
import qualified Data.ByteString.Lazy as LB
import Data.Either (partitionEithers)
import Data.Int (Int64) import Data.Int (Int64)
import Data.List.NonEmpty (fromList) import Data.String (IsString (..))
import qualified Data.Text as T
import Data.Text.Encoding (encodeUtf8)
import Simplex.Chat.ByteStringBatcher import Simplex.Chat.ByteStringBatcher
import Simplex.Chat.Protocol (maxChatMsgSize) import Simplex.Chat.Controller (ChatError (..), ChatErrorType (..))
import Simplex.Chat.Messages (SndMessage (..))
import Simplex.Chat.Protocol (SharedMsgId (..), maxChatMsgSize)
import Test.Hspec import Test.Hspec
byteStringBatcherTests :: Spec byteStringBatcherTests :: Spec
@ -15,30 +25,46 @@ byteStringBatcherTests = describe "ByteStringBatcher tests" $ do
testBatchingCorrectness testBatchingCorrectness
it "image x.msg.new and x.msg.file.descr should fit into single batch" testImageFitsSingleBatch it "image x.msg.new and x.msg.file.descr should fit into single batch" testImageFitsSingleBatch
instance IsString SndMessage where
fromString s = SndMessage {msgId, sharedMsgId = SharedMsgId "", msgBody = LB.fromStrict s'}
where
s' = encodeUtf8 $ T.pack s
msgId = fromInteger $ os2ip s'
deriving instance Eq SndMessage
instance IsString ChatError where
fromString s = ChatError $ CEInternalError ("large message " <> show msgId)
where
s' = encodeUtf8 $ T.pack s
msgId = fromInteger (os2ip s') :: Int64
testBatchingCorrectness :: Spec testBatchingCorrectness :: Spec
testBatchingCorrectness = describe "correctness tests" $ do testBatchingCorrectness = describe "correctness tests" $ do
runBatcherTest 8 ["a"] [] ["a"] runBatcherTest 8 ["a"] [] ["a"]
runBatcherTest 8 ["a", "b"] [] ["[a,b]"] runBatcherTest 8 ["a", "b"] [] ["[a,b]"]
runBatcherTest 8 ["a", "b", "c"] [] ["[a,b,c]"] runBatcherTest 8 ["a", "b", "c"] [] ["[a,b,c]"]
runBatcherTest 8 ["a", "bb", "c"] [] ["[a,bb,c]"] runBatcherTest 8 ["a", "bb", "c"] [] ["[a,bb,c]"]
runBatcherTest 8 ["a", "b", "c", "d"] [] ["[a,b,c]", "d"] runBatcherTest 8 ["a", "b", "c", "d"] [] ["a", "[b,c,d]"]
runBatcherTest 8 ["a", "bb", "c", "d"] [] ["[a,bb,c]", "d"] runBatcherTest 8 ["a", "bb", "c", "d"] [] ["a", "[bb,c,d]"]
runBatcherTest 8 ["a", "bb", "c", "de"] [] ["[a,bb,c]", "de"] runBatcherTest 8 ["a", "bb", "c", "de"] [] ["[a,bb]", "[c,de]"]
runBatcherTest 8 ["a", "b", "c", "d", "e"] [] ["[a,b,c]", "[d,e]"] runBatcherTest 8 ["a", "b", "c", "d", "e"] [] ["[a,b]", "[c,d,e]"]
runBatcherTest 8 ["a", "b", "c", "d", "e", "f", "g", "h", "i", "j"] [] ["a", "[b,c,d]", "[e,f,g]", "[h,i,j]"]
runBatcherTest 8 ["aaaaa"] [] ["aaaaa"] runBatcherTest 8 ["aaaaa"] [] ["aaaaa"]
runBatcherTest 8 ["8aaaaaaa"] [] ["8aaaaaaa"] runBatcherTest 8 ["8aaaaaaa"] [] ["8aaaaaaa"]
runBatcherTest 8 ["aaaa", "bbbb"] [] ["aaaa", "bbbb"] runBatcherTest 8 ["aaaa", "bbbb"] [] ["aaaa", "bbbb"]
runBatcherTest 8 ["aa", "bbb", "cc", "dd"] [] ["[aa,bbb]", "[cc,dd]"] runBatcherTest 8 ["aa", "bbb", "cc", "dd"] [] ["[aa,bbb]", "[cc,dd]"]
runBatcherTest 8 ["aa", "bbb", "cc", "dd", "eee", "fff", "gg", "hh"] [] ["aa", "[bbb,cc]", "[dd,eee]", "fff", "[gg,hh]"]
runBatcherTest 8 ["9aaaaaaaa"] ["9aaaaaaaa"] [] runBatcherTest 8 ["9aaaaaaaa"] ["9aaaaaaaa"] []
runBatcherTest 8 ["aaaaa", "bbb", "cc"] [] ["aaaaa", "[bbb,cc]"] runBatcherTest 8 ["aaaaa", "bbb", "cc"] [] ["aaaaa", "[bbb,cc]"]
runBatcherTest 8 ["8aaaaaaa", "bbb", "cc"] [] ["8aaaaaaa", "[bbb,cc]"] runBatcherTest 8 ["8aaaaaaa", "bbb", "cc"] [] ["8aaaaaaa", "[bbb,cc]"]
runBatcherTest 8 ["9aaaaaaaa", "bbb", "cc"] ["9aaaaaaaa"] ["[bbb,cc]"] runBatcherTest 8 ["9aaaaaaaa", "bbb", "cc"] ["9aaaaaaaa"] ["[bbb,cc]"]
runBatcherTest 8 ["9aaaaaaaa", "bbb", "cc", "dd"] ["9aaaaaaaa"] ["[bbb,cc]", "dd"] runBatcherTest 8 ["9aaaaaaaa", "bbb", "cc", "dd"] ["9aaaaaaaa"] ["bbb", "[cc,dd]"]
runBatcherTest 8 ["9aaaaaaaa", "bbb", "cc", "dd", "e"] ["9aaaaaaaa"] ["[bbb,cc]", "[dd,e]"] runBatcherTest 8 ["9aaaaaaaa", "bbb", "cc", "dd", "e"] ["9aaaaaaaa"] ["[bbb,cc]", "[dd,e]"]
runBatcherTest 8 ["bbb", "cc", "aaaaa"] [] ["[bbb,cc]", "aaaaa"] runBatcherTest 8 ["bbb", "cc", "aaaaa"] [] ["[bbb,cc]", "aaaaa"]
runBatcherTest 8 ["bbb", "cc", "8aaaaaaa"] [] ["[bbb,cc]", "8aaaaaaa"] runBatcherTest 8 ["bbb", "cc", "8aaaaaaa"] [] ["[bbb,cc]", "8aaaaaaa"]
runBatcherTest 8 ["bbb", "cc", "9aaaaaaaa"] ["9aaaaaaaa"] ["[bbb,cc]"] runBatcherTest 8 ["bbb", "cc", "9aaaaaaaa"] ["9aaaaaaaa"] ["[bbb,cc]"]
runBatcherTest 8 ["bbb", "cc", "dd", "9aaaaaaaa"] ["9aaaaaaaa"] ["[bbb,cc]", "dd"] runBatcherTest 8 ["bbb", "cc", "dd", "9aaaaaaaa"] ["9aaaaaaaa"] ["bbb", "[cc,dd]"]
runBatcherTest 8 ["bbb", "cc", "dd", "e", "9aaaaaaaa"] ["9aaaaaaaa"] ["[bbb,cc]", "[dd,e]"] runBatcherTest 8 ["bbb", "cc", "dd", "e", "9aaaaaaaa"] ["9aaaaaaaa"] ["[bbb,cc]", "[dd,e]"]
runBatcherTest 8 ["bbb", "cc", "aaaaa", "dd"] [] ["[bbb,cc]", "aaaaa", "dd"] runBatcherTest 8 ["bbb", "cc", "aaaaa", "dd"] [] ["[bbb,cc]", "aaaaa", "dd"]
runBatcherTest 8 ["bbb", "cc", "aaaaa", "dd", "e"] [] ["[bbb,cc]", "aaaaa", "[dd,e]"] runBatcherTest 8 ["bbb", "cc", "aaaaa", "dd", "e"] [] ["[bbb,cc]", "aaaaa", "[dd,e]"]
@ -51,13 +77,13 @@ testBatchingCorrectness = describe "correctness tests" $ do
runBatcherTest 8 ["8aaaaaaa", "9aaaaaaaa", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["8aaaaaaa"] runBatcherTest 8 ["8aaaaaaa", "9aaaaaaaa", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["8aaaaaaa"]
runBatcherTest 8 ["9aaaaaaaa", "8aaaaaaa", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["8aaaaaaa"] runBatcherTest 8 ["9aaaaaaaa", "8aaaaaaa", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["8aaaaaaa"]
runBatcherTest 8 ["9aaaaaaaa", "10aaaaaaaa", "8aaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["8aaaaaaa"] runBatcherTest 8 ["9aaaaaaaa", "10aaaaaaaa", "8aaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["8aaaaaaa"]
runBatcherTest 8 ["bb", "cc", "dd", "9aaaaaaaa", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["[bb,cc]", "dd"] runBatcherTest 8 ["bb", "cc", "dd", "9aaaaaaaa", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["bb", "[cc,dd]"]
runBatcherTest 8 ["bb", "cc", "9aaaaaaaa", "dd", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["[bb,cc]", "dd"] runBatcherTest 8 ["bb", "cc", "9aaaaaaaa", "dd", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["[bb,cc]", "dd"]
runBatcherTest 8 ["bb", "9aaaaaaaa", "cc", "dd", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["bb", "[cc,dd]"] runBatcherTest 8 ["bb", "9aaaaaaaa", "cc", "dd", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["bb", "[cc,dd]"]
runBatcherTest 8 ["bb", "9aaaaaaaa", "cc", "10aaaaaaaa", "dd"] ["9aaaaaaaa", "10aaaaaaaa"] ["bb", "cc", "dd"] runBatcherTest 8 ["bb", "9aaaaaaaa", "cc", "10aaaaaaaa", "dd"] ["9aaaaaaaa", "10aaaaaaaa"] ["bb", "cc", "dd"]
runBatcherTest 8 ["9aaaaaaaa", "bb", "cc", "dd", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["[bb,cc]", "dd"] runBatcherTest 8 ["9aaaaaaaa", "bb", "cc", "dd", "10aaaaaaaa"] ["9aaaaaaaa", "10aaaaaaaa"] ["bb", "[cc,dd]"]
runBatcherTest 8 ["9aaaaaaaa", "bb", "10aaaaaaaa", "cc", "dd"] ["9aaaaaaaa", "10aaaaaaaa"] ["bb", "[cc,dd]"] runBatcherTest 8 ["9aaaaaaaa", "bb", "10aaaaaaaa", "cc", "dd"] ["9aaaaaaaa", "10aaaaaaaa"] ["bb", "[cc,dd]"]
runBatcherTest 8 ["9aaaaaaaa", "10aaaaaaaa", "bb", "cc", "dd"] ["9aaaaaaaa", "10aaaaaaaa"] ["[bb,cc]", "dd"] runBatcherTest 8 ["9aaaaaaaa", "10aaaaaaaa", "bb", "cc", "dd"] ["9aaaaaaaa", "10aaaaaaaa"] ["bb", "[cc,dd]"]
testImageFitsSingleBatch :: IO () testImageFitsSingleBatch :: IO ()
testImageFitsSingleBatch = do testImageFitsSingleBatch = do
@ -68,23 +94,27 @@ testImageFitsSingleBatch = do
-- 261_120 bytes (MAX_IMAGE_SIZE in UI), rounded up, example was 743 -- 261_120 bytes (MAX_IMAGE_SIZE in UI), rounded up, example was 743
let descrRoundedSize = 800 let descrRoundedSize = 800
let xMsgNewStr = L.replicate xMsgNewRoundedSize 1 let xMsgNewStr = LB.replicate xMsgNewRoundedSize 1
descrStr = L.replicate descrRoundedSize 2 descrStr = LB.replicate descrRoundedSize 2
msg s = SndMessage {msgId = 0, sharedMsgId = SharedMsgId "", msgBody = s}
batched = "[" <> xMsgNewStr <> "," <> descrStr <> "]"
runBatcherTest' maxChatMsgSize [xMsgNewStr, descrStr] [] ["[" <> xMsgNewStr <> "," <> descrStr <> "]"] runBatcherTest' maxChatMsgSize [msg xMsgNewStr, msg descrStr] [] [batched]
runBatcherTest :: Int64 -> [L.ByteString] -> [L.ByteString] -> [L.ByteString] -> SpecWith () runBatcherTest :: Int64 -> [SndMessage] -> [ChatError] -> [LB.ByteString] -> Spec
runBatcherTest batchLenLimit bStrs expectedLargeStrs expectedBatchedStrs = runBatcherTest maxLen msgs expectedErrors expectedBatches =
it it
( (show bStrs <> ", limit " <> show batchLenLimit <> ": should return ") ( (show (map (\SndMessage {msgBody} -> msgBody) msgs) <> ", limit " <> show maxLen <> ": should return ")
<> (show (length expectedLargeStrs) <> " large, ") <> (show (length expectedErrors) <> " large, ")
<> (show (length expectedBatchedStrs) <> " batches") <> (show (length expectedBatches) <> " batches")
) )
(runBatcherTest' batchLenLimit bStrs expectedLargeStrs expectedBatchedStrs) (runBatcherTest' maxLen msgs expectedErrors expectedBatches)
runBatcherTest' :: Int64 -> [L.ByteString] -> [L.ByteString] -> [L.ByteString] -> IO () runBatcherTest' :: Int64 -> [SndMessage] -> [ChatError] -> [LB.ByteString] -> IO ()
runBatcherTest' batchLenLimit bStrs expectedLargeStrs expectedBatchedStrs = do runBatcherTest' maxLen msgs expectedErrors expectedBatches = do
let (largeStrs, batches) = partitionBatches $ batchByteStringObjects batchLenLimit (fromList bStrs) let (errors, batches) = partitionEithers $ batchMessages maxLen msgs
batchedStrs = map (\(BSBatch batchBuilder _) -> BB.toLazyByteString batchBuilder) batches batchedStrs = map (\(MsgBatch builder _) -> toLazyByteString builder) batches
largeStrs `shouldBe` expectedLargeStrs testErrors errors `shouldBe` testErrors expectedErrors
batchedStrs `shouldBe` expectedBatchedStrs batchedStrs `shouldBe` expectedBatches
where
testErrors = map (\case ChatError (CEInternalError s) -> Just s; _ -> Nothing)