Regenerate genai tracked object descriptions (#13930)

* add genai to frigateconfig

* add regenerate button if genai is enabled

* add endpoint and new zmq pub/sub model

* move publisher to app

* don't override

* logging

* debug timeouts

* clean up

* clean up

* allow saving of empty description

* ensure descriptions can be empty

* update search detail when results change

* revalidate explore page on focus

* global mutate hook

* description websocket hook and dispatcher

* revalidation and mutation

* fix merge conflicts

* update tests

* fix merge conflicts

* fix response message

* fix response message

* fix fastapi

* fix test

* remove log

* json content

* fix content response

* more json content fixes

* another one
This commit is contained in:
Josh Hawkins 2024-09-24 09:14:51 -05:00 committed by GitHub
parent cffc431bf0
commit ecbf0410eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 274 additions and 19 deletions

View File

@ -928,20 +928,62 @@ def set_description(
ids=[event_id], ids=[event_id],
) )
response_message = (
f"Event {event_id} description is now blank"
if new_description is None or len(new_description) == 0
else f"Event {event_id} description set to {new_description}"
)
return JSONResponse( return JSONResponse(
content=( content=(
{ {
"success": True, "success": True,
"message": "Event " "message": response_message,
+ event_id
+ " description set to "
+ new_description,
} }
), ),
status_code=200, status_code=200,
) )
@router.put("/events/{event_id}/description/regenerate")
def regenerate_description(request: Request, event_id: str):
    """Request a fresh generative-AI description for an existing event.

    The route uses FastAPI's ``{event_id}`` placeholder so the id is taken
    from the URL path, consistent with the other event routes in this module
    and with the ``events/<id>/description/regenerate`` request issued by the
    web client.

    Responses:
        404: no event with ``event_id`` exists.
        400: semantic search or generative AI is disabled in the config.
        200: the event id was handed to the event metadata publisher. The
            regeneration itself is not performed inside this handler.
    """
    try:
        event: Event = Event.get(Event.id == event_id)
    except DoesNotExist:
        return JSONResponse(
            content=({"success": False, "message": "Event " + event_id + " not found"}),
            status_code=404,
        )

    config = request.app.frigate_config

    # Both features must be switched on before a regeneration is queued.
    if not (config.semantic_search.enabled and config.genai.enabled):
        return JSONResponse(
            content=(
                {
                    "success": False,
                    "message": "Semantic search and generative AI are not enabled",
                }
            ),
            status_code=400,
        )

    request.app.event_metadata_updater.publish(event.id)

    return JSONResponse(
        content=(
            {
                "success": True,
                "message": "Event "
                + event_id
                + " description regeneration has been requested.",
            }
        ),
        status_code=200,
    )
@router.delete("/events/{event_id}") @router.delete("/events/{event_id}")
def delete_event(request: Request, event_id: str): def delete_event(request: Request, event_id: str):
try: try:

View File

@ -13,6 +13,9 @@ from starlette_context.plugins import Plugin
from frigate.api import app as main_app from frigate.api import app as main_app
from frigate.api import auth, event, export, media, notification, preview, review from frigate.api import auth, event, export, media, notification, preview, review
from frigate.api.auth import get_jwt_secret, limiter from frigate.api.auth import get_jwt_secret, limiter
from frigate.comms.event_metadata_updater import (
EventMetadataPublisher,
)
from frigate.config import FrigateConfig from frigate.config import FrigateConfig
from frigate.embeddings import EmbeddingsContext from frigate.embeddings import EmbeddingsContext
from frigate.events.external import ExternalEventProcessor from frigate.events.external import ExternalEventProcessor
@ -47,6 +50,7 @@ def create_fastapi_app(
onvif: OnvifController, onvif: OnvifController,
external_processor: ExternalEventProcessor, external_processor: ExternalEventProcessor,
stats_emitter: StatsEmitter, stats_emitter: StatsEmitter,
event_metadata_updater: EventMetadataPublisher,
): ):
logger.info("Starting FastAPI app") logger.info("Starting FastAPI app")
app = FastAPI( app = FastAPI(
@ -102,6 +106,7 @@ def create_fastapi_app(
app.camera_error_image = None app.camera_error_image = None
app.onvif = onvif app.onvif = onvif
app.stats_emitter = stats_emitter app.stats_emitter = stats_emitter
app.event_metadata_updater = event_metadata_updater
app.external_processor = external_processor app.external_processor = external_processor
app.jwt_token = get_jwt_secret() if frigate_config.auth.enabled else None app.jwt_token = get_jwt_secret() if frigate_config.auth.enabled else None

View File

@ -18,6 +18,10 @@ from frigate.api.auth import hash_password
from frigate.api.fastapi_app import create_fastapi_app from frigate.api.fastapi_app import create_fastapi_app
from frigate.comms.config_updater import ConfigPublisher from frigate.comms.config_updater import ConfigPublisher
from frigate.comms.dispatcher import Communicator, Dispatcher from frigate.comms.dispatcher import Communicator, Dispatcher
from frigate.comms.event_metadata_updater import (
EventMetadataPublisher,
EventMetadataTypeEnum,
)
from frigate.comms.inter_process import InterProcessCommunicator from frigate.comms.inter_process import InterProcessCommunicator
from frigate.comms.mqtt import MqttClient from frigate.comms.mqtt import MqttClient
from frigate.comms.webpush import WebPushClient from frigate.comms.webpush import WebPushClient
@ -332,6 +336,9 @@ class FrigateApp:
def init_inter_process_communicator(self) -> None: def init_inter_process_communicator(self) -> None:
self.inter_process_communicator = InterProcessCommunicator() self.inter_process_communicator = InterProcessCommunicator()
self.inter_config_updater = ConfigPublisher() self.inter_config_updater = ConfigPublisher()
self.event_metadata_updater = EventMetadataPublisher(
EventMetadataTypeEnum.regenerate_description
)
self.inter_zmq_proxy = ZmqProxy() self.inter_zmq_proxy = ZmqProxy()
def init_onvif(self) -> None: def init_onvif(self) -> None:
@ -656,6 +663,7 @@ class FrigateApp:
self.onvif_controller, self.onvif_controller,
self.external_event_processor, self.external_event_processor,
self.stats_emitter, self.stats_emitter,
self.event_metadata_updater,
), ),
host="127.0.0.1", host="127.0.0.1",
port=5001, port=5001,
@ -743,6 +751,7 @@ class FrigateApp:
# Stop Communicators # Stop Communicators
self.inter_process_communicator.stop() self.inter_process_communicator.stop()
self.inter_config_updater.stop() self.inter_config_updater.stop()
self.event_metadata_updater.stop()
self.inter_zmq_proxy.stop() self.inter_zmq_proxy.stop()
while len(self.detection_shms) > 0: while len(self.detection_shms) > 0:

View File

@ -140,6 +140,10 @@ class Dispatcher:
event: Event = Event.get(Event.id == payload["id"]) event: Event = Event.get(Event.id == payload["id"])
event.data["description"] = payload["description"] event.data["description"] = payload["description"]
event.save() event.save()
self.publish(
"event_update",
json.dumps({"id": event.id, "description": event.data["description"]}),
)
elif topic == "onConnect": elif topic == "onConnect":
camera_status = self.camera_activity.copy() camera_status = self.camera_activity.copy()

View File

@ -0,0 +1,44 @@
"""Facilitates communication between processes."""
import logging
from enum import Enum
from typing import Optional
from .zmq_proxy import Publisher, Subscriber
logger = logging.getLogger(__name__)
class EventMetadataTypeEnum(str, Enum):
    """Topic names used by the event metadata publisher and subscriber."""

    # Empty topic; presumably selects every event metadata topic -- TODO confirm
    all = ""
    # Topic carrying requests to regenerate an event's description
    regenerate_description = "regenerate_description"
class EventMetadataPublisher(Publisher):
    """Publisher for event metadata messages under ``event_metadata/``."""

    topic_base = "event_metadata/"

    def __init__(self, topic: EventMetadataTypeEnum) -> None:
        # The base class receives the enum member's plain string value.
        super().__init__(topic.value)
class EventMetadataSubscriber(Subscriber):
    """Subscriber for event metadata messages under ``event_metadata/``."""

    topic_base = "event_metadata/"

    def __init__(self, topic: EventMetadataTypeEnum) -> None:
        # The base class receives the enum member's plain string value.
        super().__init__(topic.value)

    def check_for_update(
        self, timeout: float = None
    ) -> Optional[tuple[EventMetadataTypeEnum, any]]:
        """Delegate to the base class, forwarding ``timeout`` unchanged."""
        return super().check_for_update(timeout)

    def _return_object(self, topic: str, payload: any) -> any:
        """Pair ``payload`` with the enum member named by the topic suffix.

        Returns ``(None, None)`` when ``payload`` is ``None``.
        """
        if payload is None:
            return (None, None)

        # Drop the leading ``topic_base`` characters; the remainder is looked
        # up by member *name*, so it must match a member name exactly.
        suffix = topic[len(self.topic_base) :]
        return (EventMetadataTypeEnum[suffix], payload)

View File

@ -12,6 +12,10 @@ import numpy as np
from peewee import DoesNotExist from peewee import DoesNotExist
from PIL import Image from PIL import Image
from frigate.comms.event_metadata_updater import (
EventMetadataSubscriber,
EventMetadataTypeEnum,
)
from frigate.comms.events_updater import EventEndSubscriber, EventUpdateSubscriber from frigate.comms.events_updater import EventEndSubscriber, EventUpdateSubscriber
from frigate.comms.inter_process import InterProcessRequestor from frigate.comms.inter_process import InterProcessRequestor
from frigate.config import FrigateConfig from frigate.config import FrigateConfig
@ -40,6 +44,9 @@ class EmbeddingMaintainer(threading.Thread):
self.embeddings = Embeddings() self.embeddings = Embeddings()
self.event_subscriber = EventUpdateSubscriber() self.event_subscriber = EventUpdateSubscriber()
self.event_end_subscriber = EventEndSubscriber() self.event_end_subscriber = EventEndSubscriber()
self.event_metadata_subscriber = EventMetadataSubscriber(
EventMetadataTypeEnum.regenerate_description
)
self.frame_manager = SharedMemoryFrameManager() self.frame_manager = SharedMemoryFrameManager()
# create communication for updating event descriptions # create communication for updating event descriptions
self.requestor = InterProcessRequestor() self.requestor = InterProcessRequestor()
@ -52,9 +59,11 @@ class EmbeddingMaintainer(threading.Thread):
while not self.stop_event.is_set(): while not self.stop_event.is_set():
self._process_updates() self._process_updates()
self._process_finalized() self._process_finalized()
self._process_event_metadata()
self.event_subscriber.stop() self.event_subscriber.stop()
self.event_end_subscriber.stop() self.event_end_subscriber.stop()
self.event_metadata_subscriber.stop()
self.requestor.stop() self.requestor.stop()
logger.info("Exiting embeddings maintenance...") logger.info("Exiting embeddings maintenance...")
@ -140,6 +149,16 @@ class EmbeddingMaintainer(threading.Thread):
if event_id in self.tracked_events: if event_id in self.tracked_events:
del self.tracked_events[event_id] del self.tracked_events[event_id]
def _process_event_metadata(self):
# Check for regenerate description requests
(topic, event_id) = self.event_metadata_subscriber.check_for_update()
if topic is None:
return
if event_id:
self.handle_regenerate_description(event_id)
def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]: def _create_thumbnail(self, yuv_frame, box, height=500) -> Optional[bytes]:
"""Return jpg thumbnail of a region of the frame.""" """Return jpg thumbnail of a region of the frame."""
frame = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420) frame = cv2.cvtColor(yuv_frame, cv2.COLOR_YUV2BGR_I420)
@ -200,3 +219,20 @@ class EmbeddingMaintainer(threading.Thread):
len(thumbnails), len(thumbnails),
description, description,
) )
def handle_regenerate_description(self, event_id: str) -> None:
    """Regenerate the description of a stored event from its saved thumbnail.

    Logs an error and returns without doing anything when the event cannot
    be found, when its camera has no config entry, or when GenAI is not
    enabled for that camera (or no GenAI client is available).
    """
    try:
        event: Event = Event.get(Event.id == event_id)
    except DoesNotExist:
        logger.error(f"Event {event_id} not found for description regeneration")
        return

    try:
        camera_config = self.config.cameras[event.camera]
    except KeyError:
        # A stored event may reference a camera that is no longer configured;
        # bail out instead of letting the lookup error escape to the caller.
        # NOTE(review): assumes `cameras` is a mapping keyed by camera name.
        logger.error(
            f"Camera {event.camera} not found in config for description regeneration"
        )
        return

    if not camera_config.genai.enabled or self.genai_client is None:
        logger.error(f"GenAI not enabled for camera {event.camera}")
        return

    metadata = get_metadata(event)
    # The thumbnail is stored base64-encoded on the event record.
    thumbnail = base64.b64decode(event.thumbnail)

    self._embed_description(event, [thumbnail], metadata)

View File

@ -121,6 +121,7 @@ class TestHttp(unittest.TestCase):
None, None,
None, None,
None, None,
None,
) )
id = "123456.random" id = "123456.random"
id2 = "7890.random" id2 = "7890.random"
@ -157,6 +158,7 @@ class TestHttp(unittest.TestCase):
None, None,
None, None,
None, None,
None,
) )
id = "123456.random" id = "123456.random"
@ -178,6 +180,7 @@ class TestHttp(unittest.TestCase):
None, None,
None, None,
None, None,
None,
) )
id = "123456.random" id = "123456.random"
bad_id = "654321.other" bad_id = "654321.other"
@ -198,6 +201,7 @@ class TestHttp(unittest.TestCase):
None, None,
None, None,
None, None,
None,
) )
id = "123456.random" id = "123456.random"
@ -220,6 +224,7 @@ class TestHttp(unittest.TestCase):
None, None,
None, None,
None, None,
None,
) )
id = "123456.random" id = "123456.random"
@ -246,6 +251,7 @@ class TestHttp(unittest.TestCase):
None, None,
None, None,
None, None,
None,
) )
morning_id = "123456.random" morning_id = "123456.random"
evening_id = "654321.random" evening_id = "654321.random"
@ -284,6 +290,7 @@ class TestHttp(unittest.TestCase):
None, None,
None, None,
None, None,
None,
) )
id = "123456.random" id = "123456.random"
sub_label = "sub" sub_label = "sub"
@ -319,6 +326,7 @@ class TestHttp(unittest.TestCase):
None, None,
None, None,
None, None,
None,
) )
id = "123456.random" id = "123456.random"
sub_label = "sub" sub_label = "sub"
@ -343,6 +351,7 @@ class TestHttp(unittest.TestCase):
None, None,
None, None,
None, None,
None,
) )
with TestClient(app) as client: with TestClient(app) as client:
@ -360,6 +369,7 @@ class TestHttp(unittest.TestCase):
None, None,
None, None,
None, None,
None,
) )
id = "123456.random" id = "123456.random"
@ -383,6 +393,7 @@ class TestHttp(unittest.TestCase):
None, None,
None, None,
stats, stats,
None,
) )
with TestClient(app) as client: with TestClient(app) as client:

View File

@ -321,3 +321,10 @@ export function useImproveContrast(camera: string): {
); );
return { payload: payload as ToggleableSetting, send }; return { payload: payload as ToggleableSetting, send };
} }
/**
 * Watches the "event_update" websocket topic and returns its JSON-decoded
 * payload, passed through `useDeepMemo`.
 *
 * When no usable payload is available (anything other than a non-empty
 * string) the hook yields `null` instead of throwing during render.
 */
export function useEventUpdate(): { payload: string } {
  const {
    value: { payload },
  } = useWs("event_update", "");

  // JSON.parse throws a SyntaxError on "" and on undefined, so substitute the
  // JSON literal `null` whenever there is nothing to decode.
  const raw =
    typeof payload === "string" && payload.length > 0 ? payload : "null";

  return useDeepMemo(JSON.parse(raw));
}

View File

@ -45,6 +45,8 @@ import {
import { ReviewSegment } from "@/types/review"; import { ReviewSegment } from "@/types/review";
import { useNavigate } from "react-router-dom"; import { useNavigate } from "react-router-dom";
import Chip from "@/components/indicators/Chip"; import Chip from "@/components/indicators/Chip";
import { capitalizeFirstLetter } from "@/utils/stringUtil";
import useGlobalMutation from "@/hooks/use-global-mutate";
const SEARCH_TABS = [ const SEARCH_TABS = [
"details", "details",
@ -232,6 +234,10 @@ function ObjectDetailsTab({
}: ObjectDetailsTabProps) { }: ObjectDetailsTabProps) {
const apiHost = useApiHost(); const apiHost = useApiHost();
// mutation / revalidation
const mutate = useGlobalMutation();
// data // data
const [desc, setDesc] = useState(search?.data.description); const [desc, setDesc] = useState(search?.data.description);
@ -282,6 +288,13 @@ function ObjectDetailsTab({
position: "top-center", position: "top-center",
}); });
} }
mutate(
(key) =>
typeof key === "string" &&
(key.includes("events") ||
key.includes("events/search") ||
key.includes("explore")),
);
}) })
.catch(() => { .catch(() => {
toast.error("Failed to update the description", { toast.error("Failed to update the description", {
@ -289,7 +302,35 @@ function ObjectDetailsTab({
}); });
setDesc(search.data.description); setDesc(search.data.description);
}); });
}, [desc, search]); }, [desc, search, mutate]);
// Ask the backend to regenerate the description of the selected object and
// report the outcome with a toast. Does nothing when no object is selected.
const regenerateDescription = useCallback(() => {
  if (!search) {
    return;
  }

  // Provider name shown in both toasts; falls back to a generic label when
  // the config has not loaded or names no provider.
  const providerLabel = capitalizeFirstLetter(
    config?.genai.provider ?? "Generative AI",
  );

  axios
    .put(`events/${search.id}/description/regenerate`)
    .then((resp) => {
      if (resp.status === 200) {
        toast.success(
          `A new description has been requested from ${providerLabel}. Depending on the speed of your provider, the new description may take some time to regenerate.`,
          {
            position: "top-center",
            duration: 7000,
          },
        );
      }
    })
    .catch(() => {
      toast.error(`Failed to call ${providerLabel} for a new description`, {
        position: "top-center",
      });
    });
}, [search, config]);
return ( return (
<div className="flex flex-col gap-5"> <div className="flex flex-col gap-5">
@ -355,7 +396,10 @@ function ObjectDetailsTab({
value={desc} value={desc}
onChange={(e) => setDesc(e.target.value)} onChange={(e) => setDesc(e.target.value)}
/> />
<div className="flex w-full flex-row justify-end"> <div className="flex w-full flex-row justify-end gap-2">
{config?.genai.enabled && (
<Button onClick={regenerateDescription}>Regenerate</Button>
)}
<Button variant="select" onClick={updateDescription}> <Button variant="select" onClick={updateDescription}>
Save Save
</Button> </Button>

View File

@ -0,0 +1,16 @@
// https://github.com/vercel/swr/issues/1670#issuecomment-1844114401
import { useCallback } from "react";
import { cache, mutate } from "swr/_internal";
/**
 * Returns a stable `mutate`-like function that also accepts a key predicate.
 *
 * Given a predicate, every key currently in the SWR cache that satisfies it
 * is mutated with the remaining arguments. Given a plain key, the call is
 * forwarded to `mutate` unchanged.
 *
 * The promise(s) produced by `mutate` are returned so that callers who await
 * the result actually wait for the mutation instead of resolving at once.
 */
const useGlobalMutation = () => {
  return useCallback((swrKey: string | ((key: string) => boolean), ...args) => {
    if (typeof swrKey === "function") {
      const matchingKeys = Array.from(cache.keys()).filter(swrKey);
      return Promise.all(matchingKeys.map((key) => mutate(key, ...args)));
    }

    return mutate(swrKey, ...args);
  }, []) as typeof mutate;
};
export default useGlobalMutation;

View File

@ -1,3 +1,4 @@
import { useEventUpdate } from "@/api/ws";
import { useApiFilterArgs } from "@/hooks/use-api-filter"; import { useApiFilterArgs } from "@/hooks/use-api-filter";
import { SearchFilter, SearchQuery, SearchResult } from "@/types/search"; import { SearchFilter, SearchQuery, SearchResult } from "@/types/search";
import SearchView from "@/views/search/SearchView"; import SearchView from "@/views/search/SearchView";
@ -123,19 +124,19 @@ export default function Explore() {
return [url, { ...params, limit: API_LIMIT }]; return [url, { ...params, limit: API_LIMIT }];
}; };
const { data, size, setSize, isValidating } = useSWRInfinite<SearchResult[]>( const { data, size, setSize, isValidating, mutate } = useSWRInfinite<
getKey, SearchResult[]
{ >(getKey, {
revalidateFirstPage: true, revalidateFirstPage: true,
revalidateAll: false, revalidateOnFocus: true,
onLoadingSlow: () => { revalidateAll: false,
if (!similaritySearch) { onLoadingSlow: () => {
setIsSlowLoading(true); if (!similaritySearch) {
} setIsSlowLoading(true);
}, }
loadingTimeout: 10000,
}, },
); loadingTimeout: 10000,
});
const searchResults = useMemo( const searchResults = useMemo(
() => (data ? ([] as SearchResult[]).concat(...data) : []), () => (data ? ([] as SearchResult[]).concat(...data) : []),
@ -164,6 +165,16 @@ export default function Explore() {
} }
}, [isReachingEnd, isLoadingMore, setSize, size, searchResults, searchQuery]); }, [isReachingEnd, isLoadingMore, setSize, size, searchResults, searchQuery]);
// mutation and revalidation

// Latest value from the "event_update" websocket hook; it is used here only
// as a change trigger for the effect below.
const eventUpdate = useEventUpdate();

useEffect(() => {
  mutate();
  // mutate / revalidate when event description updates come in
  // (`mutate` is intentionally left out of the dependency list)
  // eslint-disable-next-line react-hooks/exhaustive-deps
}, [eventUpdate]);
return ( return (
<> <>
{isSlowLoading && !similaritySearch ? ( {isSlowLoading && !similaritySearch ? (

View File

@ -298,6 +298,16 @@ export interface FrigateConfig {
retry_interval: number; retry_interval: number;
}; };
genai: {
enabled: boolean;
provider: string;
base_url?: string;
api_key?: string;
model: string;
prompt: string;
object_prompts: { [key: string]: string };
};
go2rtc: { go2rtc: {
streams: string[]; streams: string[];
webrtc: { webrtc: {

View File

@ -37,7 +37,7 @@ export default function ExploreView({ onSelectSearch }: ExploreViewProps) {
}, },
], ],
{ {
revalidateOnFocus: false, revalidateOnFocus: true,
}, },
); );

View File

@ -23,6 +23,7 @@ import useKeyboardListener, {
import scrollIntoView from "scroll-into-view-if-needed"; import scrollIntoView from "scroll-into-view-if-needed";
import InputWithTags from "@/components/input/InputWithTags"; import InputWithTags from "@/components/input/InputWithTags";
import { ScrollArea, ScrollBar } from "@/components/ui/scroll-area"; import { ScrollArea, ScrollBar } from "@/components/ui/scroll-area";
import { isEqual } from "lodash";
type SearchViewProps = { type SearchViewProps = {
search: string; search: string;
@ -140,6 +141,21 @@ export default function SearchView({
setSelectedIndex(index); setSelectedIndex(index);
}, []); }, []);
// keep the open search detail in sync with the latest search results
useEffect(() => {
  if (!searchDetail || !searchResults) {
    return;
  }

  const latest = searchResults
    .flat()
    .find((candidate) => candidate.id === searchDetail.id);

  // Only replace the detail when the matching result actually differs, so an
  // unchanged result does not trigger another state update.
  if (latest && !isEqual(latest, searchDetail)) {
    setSearchDetail(latest);
  }
}, [searchResults, searchDetail]);
// confidence score - probably needs tweaking // confidence score - probably needs tweaking
const zScoreToConfidence = (score: number, source: string) => { const zScoreToConfidence = (score: number, source: string) => {