Restructure embeddings (#14266)

* Restructure embeddings

* Use ZMQ to proxy embeddings requests

* Handle serialization

* Formatting

* Remove unused
This commit is contained in:
Nicolas Mowen 2024-10-10 09:42:24 -06:00 committed by GitHub
parent a2ca18a714
commit 8ade85edec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 241 additions and 142 deletions

View File

@ -472,7 +472,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
status_code=404, status_code=404,
) )
thumb_result = context.embeddings.search_thumbnail(search_event) thumb_result = context.search_thumbnail(search_event)
thumb_ids = dict( thumb_ids = dict(
zip( zip(
[result[0] for result in thumb_result], [result[0] for result in thumb_result],
@ -487,7 +487,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
search_types = search_type.split(",") search_types = search_type.split(",")
if "thumbnail" in search_types: if "thumbnail" in search_types:
thumb_result = context.embeddings.search_thumbnail(query) thumb_result = context.search_thumbnail(query)
thumb_ids = dict( thumb_ids = dict(
zip( zip(
[result[0] for result in thumb_result], [result[0] for result in thumb_result],
@ -504,7 +504,7 @@ def events_search(request: Request, params: EventsSearchQueryParams = Depends())
) )
if "description" in search_types: if "description" in search_types:
desc_result = context.embeddings.search_description(query) desc_result = context.search_description(query)
desc_ids = dict( desc_ids = dict(
zip( zip(
[result[0] for result in desc_result], [result[0] for result in desc_result],
@ -944,9 +944,9 @@ def set_description(
# If semantic search is enabled, update the index # If semantic search is enabled, update the index
if request.app.frigate_config.semantic_search.enabled: if request.app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = request.app.embeddings context: EmbeddingsContext = request.app.embeddings
context.embeddings.upsert_description( context.update_description(
event_id=event_id, event_id,
description=new_description, new_description,
) )
response_message = ( response_message = (
@ -1033,8 +1033,8 @@ def delete_event(request: Request, event_id: str):
# If semantic search is enabled, update the index # If semantic search is enabled, update the index
if request.app.frigate_config.semantic_search.enabled: if request.app.frigate_config.semantic_search.enabled:
context: EmbeddingsContext = request.app.embeddings context: EmbeddingsContext = request.app.embeddings
context.embeddings.delete_thumbnail(id=[event_id]) context.db.delete_embeddings_thumbnail(id=[event_id])
context.embeddings.delete_description(id=[event_id]) context.db.delete_embeddings_description(id=[event_id])
return JSONResponse( return JSONResponse(
content=({"success": True, "message": "Event " + event_id + " deleted"}), content=({"success": True, "message": "Event " + event_id + " deleted"}),
status_code=200, status_code=200,

View File

@ -276,7 +276,7 @@ class FrigateApp:
def init_embeddings_client(self) -> None: def init_embeddings_client(self) -> None:
if self.config.semantic_search.enabled: if self.config.semantic_search.enabled:
# Create a client for other processes to use # Create a client for other processes to use
self.embeddings = EmbeddingsContext(self.config, self.db) self.embeddings = EmbeddingsContext(self.db)
def init_external_event_processor(self) -> None: def init_external_event_processor(self) -> None:
self.external_event_processor = ExternalEventProcessor(self.config) self.external_event_processor = ExternalEventProcessor(self.config)
@ -699,7 +699,7 @@ class FrigateApp:
# Save embeddings stats to disk # Save embeddings stats to disk
if self.embeddings: if self.embeddings:
self.embeddings.save_stats() self.embeddings.stop()
# Stop Communicators # Stop Communicators
self.inter_process_communicator.stop() self.inter_process_communicator.stop()

View File

@ -0,0 +1,62 @@
"""Facilitates communication between processes."""
from enum import Enum
from typing import Callable
import zmq
SOCKET_REP_REQ = "ipc:///tmp/cache/embeddings"
class EmbeddingsRequestEnum(Enum):
embed_description = "embed_description"
embed_thumbnail = "embed_thumbnail"
generate_search = "generate_search"
class EmbeddingsResponder:
def __init__(self) -> None:
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REP)
self.socket.bind(SOCKET_REP_REQ)
def check_for_request(self, process: Callable) -> None:
while True: # load all messages that are queued
has_message, _, _ = zmq.select([self.socket], [], [], 1)
if not has_message:
break
try:
(topic, value) = self.socket.recv_json(flags=zmq.NOBLOCK)
response = process(topic, value)
if response is not None:
self.socket.send_json(response)
else:
self.socket.send_json([])
except zmq.ZMQError:
break
def stop(self) -> None:
self.socket.close()
self.context.destroy()
class EmbeddingsRequestor:
"""Simplifies sending data to EmbeddingsResponder and getting a reply."""
def __init__(self) -> None:
self.context = zmq.Context()
self.socket = self.context.socket(zmq.REQ)
self.socket.connect(SOCKET_REP_REQ)
def send_data(self, topic: str, data: any) -> str:
"""Sends data and then waits for reply."""
self.socket.send_json((topic, data))
return self.socket.recv_json()
def stop(self) -> None:
self.socket.close()
self.context.destroy()

View File

@ -20,3 +20,11 @@ class SqliteVecQueueDatabase(SqliteQueueDatabase):
conn.enable_load_extension(True) conn.enable_load_extension(True)
conn.load_extension(self.sqlite_vec_path) conn.load_extension(self.sqlite_vec_path)
conn.enable_load_extension(False) conn.enable_load_extension(False)
def delete_embeddings_thumbnail(self, event_ids: list[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.execute_sql(f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids)
def delete_embeddings_description(self, event_ids: list[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.execute_sql(f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids)

View File

@ -7,14 +7,16 @@ import os
import signal import signal
import threading import threading
from types import FrameType from types import FrameType
from typing import Optional from typing import Optional, Union
from setproctitle import setproctitle from setproctitle import setproctitle
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsRequestor
from frigate.config import FrigateConfig from frigate.config import FrigateConfig
from frigate.const import CONFIG_DIR from frigate.const import CONFIG_DIR
from frigate.db.sqlitevecq import SqliteVecQueueDatabase from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event from frigate.models import Event
from frigate.util.builtin import serialize
from frigate.util.services import listen from frigate.util.services import listen
from .embeddings import Embeddings from .embeddings import Embeddings
@ -70,10 +72,11 @@ def manage_embeddings(config: FrigateConfig) -> None:
class EmbeddingsContext: class EmbeddingsContext:
def __init__(self, config: FrigateConfig, db: SqliteVecQueueDatabase): def __init__(self, db: SqliteVecQueueDatabase):
self.embeddings = Embeddings(config.semantic_search, db) self.db = db
self.thumb_stats = ZScoreNormalization() self.thumb_stats = ZScoreNormalization()
self.desc_stats = ZScoreNormalization() self.desc_stats = ZScoreNormalization()
self.requestor = EmbeddingsRequestor()
# load stats from disk # load stats from disk
try: try:
@ -84,7 +87,7 @@ class EmbeddingsContext:
except FileNotFoundError: except FileNotFoundError:
pass pass
def save_stats(self): def stop(self):
"""Write the stats to disk as JSON on exit.""" """Write the stats to disk as JSON on exit."""
contents = { contents = {
"thumb_stats": self.thumb_stats.to_dict(), "thumb_stats": self.thumb_stats.to_dict(),
@ -92,3 +95,100 @@ class EmbeddingsContext:
} }
with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f: with open(os.path.join(CONFIG_DIR, ".search_stats.json"), "w") as f:
json.dump(contents, f) json.dump(contents, f)
self.requestor.stop()
def search_thumbnail(
self, query: Union[Event, str], event_ids: list[str] = None
) -> list[tuple[str, float]]:
if query.__class__ == Event:
cursor = self.db.execute_sql(
"""
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
""",
[query.id],
)
row = cursor.fetchone() if cursor else None
if row:
query_embedding = row[0]
else:
# If no embedding found, generate it and return it
query_embedding = serialize(
self.requestor.send_data(
EmbeddingsRequestEnum.embed_thumbnail.value,
{"id": query.id, "thumbnail": query.thumbnail},
)
)
else:
query_embedding = serialize(
self.requestor.send_data(
EmbeddingsRequestEnum.generate_search.value, query
)
)
sql_query = """
SELECT
id,
distance
FROM vec_thumbnails
WHERE thumbnail_embedding MATCH ?
AND k = 100
"""
# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = [query_embedding] + event_ids if event_ids else [query_embedding]
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results
def search_description(
self, query_text: str, event_ids: list[str] = None
) -> list[tuple[str, float]]:
query_embedding = serialize(
self.requestor.send_data(
EmbeddingsRequestEnum.generate_search.value, query_text
)
)
# Prepare the base SQL query
sql_query = """
SELECT
id,
distance
FROM vec_descriptions
WHERE description_embedding MATCH ?
AND k = 100
"""
# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = [query_embedding] + event_ids if event_ids else [query_embedding]
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results
def update_description(self, event_id: str, description: str) -> None:
self.requestor.send_data(
EmbeddingsRequestEnum.embed_description.value,
{"id": event_id, "description": description},
)

View File

@ -3,11 +3,8 @@
import base64 import base64
import io import io
import logging import logging
import struct
import time import time
from typing import List, Tuple, Union
import numpy as np
from PIL import Image from PIL import Image
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
@ -17,6 +14,7 @@ from frigate.const import UPDATE_MODEL_STATE
from frigate.db.sqlitevecq import SqliteVecQueueDatabase from frigate.db.sqlitevecq import SqliteVecQueueDatabase
from frigate.models import Event from frigate.models import Event
from frigate.types import ModelStatusTypesEnum from frigate.types import ModelStatusTypesEnum
from frigate.util.builtin import serialize
from .functions.onnx import GenericONNXEmbedding from .functions.onnx import GenericONNXEmbedding
@ -54,30 +52,6 @@ def get_metadata(event: Event) -> dict:
) )
def serialize(vector: Union[List[float], np.ndarray, float]) -> bytes:
"""Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
if isinstance(vector, np.ndarray):
# Convert numpy array to list of floats
vector = vector.flatten().tolist()
elif isinstance(vector, (float, np.float32, np.float64)):
# Handle single float values
vector = [vector]
elif not isinstance(vector, list):
raise TypeError(
f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
)
try:
return struct.pack("%sf" % len(vector), *vector)
except struct.error as e:
raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")
def deserialize(bytes_data: bytes) -> List[float]:
"""Deserializes a compact "raw bytes" format into a list of floats"""
return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))
class Embeddings: class Embeddings:
"""SQLite-vec embeddings database.""" """SQLite-vec embeddings database."""
@ -190,106 +164,6 @@ class Embeddings:
return embedding return embedding
def delete_thumbnail(self, event_ids: List[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.db.execute_sql(
f"DELETE FROM vec_thumbnails WHERE id IN ({ids})", event_ids
)
def delete_description(self, event_ids: List[str]) -> None:
ids = ",".join(["?" for _ in event_ids])
self.db.execute_sql(
f"DELETE FROM vec_descriptions WHERE id IN ({ids})", event_ids
)
def search_thumbnail(
self, query: Union[Event, str], event_ids: List[str] = None
) -> List[Tuple[str, float]]:
if query.__class__ == Event:
cursor = self.db.execute_sql(
"""
SELECT thumbnail_embedding FROM vec_thumbnails WHERE id = ?
""",
[query.id],
)
row = cursor.fetchone() if cursor else None
if row:
query_embedding = deserialize(
row[0]
) # Deserialize the thumbnail embedding
else:
# If no embedding found, generate it and return it
thumbnail = base64.b64decode(query.thumbnail)
query_embedding = self.upsert_thumbnail(query.id, thumbnail)
else:
query_embedding = self.text_embedding([query])[0]
sql_query = """
SELECT
id,
distance
FROM vec_thumbnails
WHERE thumbnail_embedding MATCH ?
AND k = 100
"""
# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = (
[serialize(query_embedding)] + event_ids
if event_ids
else [serialize(query_embedding)]
)
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results
def search_description(
self, query_text: str, event_ids: List[str] = None
) -> List[Tuple[str, float]]:
query_embedding = self.text_embedding([query_text])[0]
# Prepare the base SQL query
sql_query = """
SELECT
id,
distance
FROM vec_descriptions
WHERE description_embedding MATCH ?
AND k = 100
"""
# Add the IN clause if event_ids is provided and not empty
# this is the only filter supported by sqlite-vec as of 0.1.3
# but it seems to be broken in this version
if event_ids:
sql_query += " AND id IN ({})".format(",".join("?" * len(event_ids)))
# order by distance DESC is not implemented in this version of sqlite-vec
# when it's implemented, we can use cosine similarity
sql_query += " ORDER BY distance"
parameters = (
[serialize(query_embedding)] + event_ids
if event_ids
else [serialize(query_embedding)]
)
results = self.db.execute_sql(sql_query, parameters).fetchall()
return results
def reindex(self) -> None: def reindex(self) -> None:
logger.info("Indexing event embeddings...") logger.info("Indexing event embeddings...")

View File

@ -12,6 +12,7 @@ import numpy as np
from peewee import DoesNotExist from peewee import DoesNotExist
from playhouse.sqliteq import SqliteQueueDatabase from playhouse.sqliteq import SqliteQueueDatabase
from frigate.comms.embeddings_updater import EmbeddingsRequestEnum, EmbeddingsResponder
from frigate.comms.event_metadata_updater import ( from frigate.comms.event_metadata_updater import (
EventMetadataSubscriber, EventMetadataSubscriber,
EventMetadataTypeEnum, EventMetadataTypeEnum,
@ -23,6 +24,7 @@ from frigate.const import CLIPS_DIR, UPDATE_EVENT_DESCRIPTION
from frigate.events.types import EventTypeEnum from frigate.events.types import EventTypeEnum
from frigate.genai import get_genai_client from frigate.genai import get_genai_client
from frigate.models import Event from frigate.models import Event
from frigate.util.builtin import serialize
from frigate.util.image import SharedMemoryFrameManager, calculate_region from frigate.util.image import SharedMemoryFrameManager, calculate_region
from .embeddings import Embeddings from .embeddings import Embeddings
@ -48,6 +50,7 @@ class EmbeddingMaintainer(threading.Thread):
self.event_metadata_subscriber = EventMetadataSubscriber( self.event_metadata_subscriber = EventMetadataSubscriber(
EventMetadataTypeEnum.regenerate_description EventMetadataTypeEnum.regenerate_description
) )
self.embeddings_responder = EmbeddingsResponder()
self.frame_manager = SharedMemoryFrameManager() self.frame_manager = SharedMemoryFrameManager()
# create communication for updating event descriptions # create communication for updating event descriptions
self.requestor = InterProcessRequestor() self.requestor = InterProcessRequestor()
@ -58,6 +61,7 @@ class EmbeddingMaintainer(threading.Thread):
def run(self) -> None: def run(self) -> None:
"""Maintain a SQLite-vec database for semantic search.""" """Maintain a SQLite-vec database for semantic search."""
while not self.stop_event.is_set(): while not self.stop_event.is_set():
self._process_requests()
self._process_updates() self._process_updates()
self._process_finalized() self._process_finalized()
self._process_event_metadata() self._process_event_metadata()
@ -65,9 +69,30 @@ class EmbeddingMaintainer(threading.Thread):
self.event_subscriber.stop() self.event_subscriber.stop()
self.event_end_subscriber.stop() self.event_end_subscriber.stop()
self.event_metadata_subscriber.stop() self.event_metadata_subscriber.stop()
self.embeddings_responder.stop()
self.requestor.stop() self.requestor.stop()
logger.info("Exiting embeddings maintenance...") logger.info("Exiting embeddings maintenance...")
def _process_requests(self) -> None:
"""Process embeddings requests"""
def handle_request(topic: str, data: str) -> str:
if topic == EmbeddingsRequestEnum.embed_description.value:
return serialize(
self.embeddings.upsert_description(data["id"], data["description"]),
pack=False,
)
elif topic == EmbeddingsRequestEnum.embed_thumbnail.value:
thumbnail = base64.b64decode(data["thumbnail"])
return serialize(
self.embeddings.upsert_thumbnail(data["id"], thumbnail),
pack=False,
)
elif topic == EmbeddingsRequestEnum.generate_search.value:
return serialize(self.embeddings.text_embedding([data])[0], pack=False)
self.embeddings_responder.check_for_request(handle_request)
def _process_updates(self) -> None: def _process_updates(self) -> None:
"""Process event updates""" """Process event updates"""
update = self.event_subscriber.check_for_update() update = self.event_subscriber.check_for_update()

View File

@ -8,10 +8,11 @@ import multiprocessing as mp
import queue import queue
import re import re
import shlex import shlex
import struct
import urllib.parse import urllib.parse
from collections.abc import Mapping from collections.abc import Mapping
from pathlib import Path from pathlib import Path
from typing import Any, Optional, Tuple from typing import Any, Optional, Tuple, Union
import numpy as np import numpy as np
import pytz import pytz
@ -342,3 +343,32 @@ def generate_color_palette(n):
colors.append(interpolate(color1, color2, factor)) colors.append(interpolate(color1, color2, factor))
return colors return colors
def serialize(
vector: Union[list[float], np.ndarray, float], pack: bool = True
) -> bytes:
"""Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
if isinstance(vector, np.ndarray):
# Convert numpy array to list of floats
vector = vector.flatten().tolist()
elif isinstance(vector, (float, np.float32, np.float64)):
# Handle single float values
vector = [vector]
elif not isinstance(vector, list):
raise TypeError(
f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
)
try:
if pack:
return struct.pack("%sf" % len(vector), *vector)
else:
return vector
except struct.error as e:
raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")
def deserialize(bytes_data: bytes) -> list[float]:
"""Deserializes a compact "raw bytes" format into a list of floats"""
return list(struct.unpack("%sf" % (len(bytes_data) // 4), bytes_data))