mirror of
https://github.com/blakeblackshear/frigate.git
synced 2024-11-25 10:20:32 -06:00
Use JinaAI models for embeddings (#14252)
* add generic onnx model class and use jina ai clip models for all embeddings * fix merge confligt * add generic onnx model class and use jina ai clip models for all embeddings * fix merge confligt * preferred providers * fix paths * disable download progress bar * remove logging of path * drop and recreate tables on reindex * use cache paths * fix model name * use trust remote code per transformers docs * ensure tokenizer and feature extractor are correctly loaded * revert * manually download and cache feature extractor config * remove unneeded * remove old clip and minilm code * docs update
This commit is contained in:
parent
dbeaf43b8f
commit
d4925622f9
@ -5,7 +5,7 @@ title: Using Semantic Search
|
||||
|
||||
Semantic Search in Frigate allows you to find tracked objects within your review items using either the image itself, a user-defined text description, or an automatically generated one. This feature works by creating _embeddings_ — numerical vector representations — for both the images and text descriptions of your tracked objects. By comparing these embeddings, Frigate assesses their similarities to deliver relevant search results.
|
||||
|
||||
Frigate has support for two models to create embeddings, both of which run locally: [OpenAI CLIP](https://openai.com/research/clip) and [all-MiniLM-L6-v2](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2). Embeddings are then saved to Frigate's database.
|
||||
Frigate has support for [Jina AI's CLIP model](https://huggingface.co/jinaai/jina-clip-v1) to create embeddings, which runs locally. Embeddings are then saved to Frigate's database.
|
||||
|
||||
Semantic Search is accessed via the _Explore_ view in the Frigate UI.
|
||||
|
||||
@ -27,13 +27,11 @@ If you are enabling the Search feature for the first time, be advised that Friga
|
||||
|
||||
:::
|
||||
|
||||
### OpenAI CLIP
|
||||
### Jina AI CLIP
|
||||
|
||||
This model is able to embed both images and text into the same vector space, which allows `image -> image` and `text -> image` similarity searches. Frigate uses this model on tracked objects to encode the thumbnail image and store it in the database. When searching for tracked objects via text in the search box, Frigate will perform a `text -> image` similarity search against this embedding. When clicking "Find Similar" in the tracked object detail pane, Frigate will perform an `image -> image` similarity search to retrieve the closest matching thumbnails.
|
||||
The vision model is able to embed both images and text into the same vector space, which allows `image -> image` and `text -> image` similarity searches. Frigate uses this model on tracked objects to encode the thumbnail image and store it in the database. When searching for tracked objects via text in the search box, Frigate will perform a `text -> image` similarity search against this embedding. When clicking "Find Similar" in the tracked object detail pane, Frigate will perform an `image -> image` similarity search to retrieve the closest matching thumbnails.
|
||||
|
||||
### all-MiniLM-L6-v2
|
||||
|
||||
This is a sentence embedding model that has been fine tuned on over 1 billion sentence pairs. This model is used to embed tracked object descriptions and perform searches against them. Descriptions can be created, viewed, and modified on the Search page when clicking on the gray tracked object chip at the top left of each review item. See [the Generative AI docs](/configuration/genai.md) for more information on how to automatically generate tracked object descriptions.
|
||||
The text model is used to embed tracked object descriptions and perform searches against them. Descriptions can be created, viewed, and modified on the Search page when clicking on the gray tracked object chip at the top left of each review item. See [the Generative AI docs](/configuration/genai.md) for more information on how to automatically generate tracked object descriptions.
|
||||
|
||||
## Usage
|
||||
|
||||
|
@ -73,7 +73,7 @@ class EmbeddingsContext:
|
||||
def __init__(self, db: SqliteVecQueueDatabase):
|
||||
self.embeddings = Embeddings(db)
|
||||
self.thumb_stats = ZScoreNormalization()
|
||||
self.desc_stats = ZScoreNormalization(scale_factor=3, bias=-2.5)
|
||||
self.desc_stats = ZScoreNormalization()
|
||||
|
||||
# load stats from disk
|
||||
try:
|
||||
|
@ -7,6 +7,7 @@ import struct
|
||||
import time
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
|
||||
@ -16,8 +17,7 @@ from frigate.db.sqlitevecq import SqliteVecQueueDatabase
|
||||
from frigate.models import Event
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
|
||||
from .functions.clip import ClipEmbedding
|
||||
from .functions.minilm_l6_v2 import MiniLMEmbedding
|
||||
from .functions.onnx import GenericONNXEmbedding
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@ -53,9 +53,23 @@ def get_metadata(event: Event) -> dict:
|
||||
)
|
||||
|
||||
|
||||
def serialize(vector: List[float]) -> bytes:
|
||||
"""Serializes a list of floats into a compact "raw bytes" format"""
|
||||
def serialize(vector: Union[List[float], np.ndarray, float]) -> bytes:
|
||||
"""Serializes a list of floats, numpy array, or single float into a compact "raw bytes" format"""
|
||||
if isinstance(vector, np.ndarray):
|
||||
# Convert numpy array to list of floats
|
||||
vector = vector.flatten().tolist()
|
||||
elif isinstance(vector, (float, np.float32, np.float64)):
|
||||
# Handle single float values
|
||||
vector = [vector]
|
||||
elif not isinstance(vector, list):
|
||||
raise TypeError(
|
||||
f"Input must be a list of floats, a numpy array, or a single float. Got {type(vector)}"
|
||||
)
|
||||
|
||||
try:
|
||||
return struct.pack("%sf" % len(vector), *vector)
|
||||
except struct.error as e:
|
||||
raise ValueError(f"Failed to pack vector: {e}. Vector: {vector}")
|
||||
|
||||
|
||||
def deserialize(bytes_data: bytes) -> List[float]:
|
||||
@ -74,10 +88,10 @@ class Embeddings:
|
||||
self._create_tables()
|
||||
|
||||
models = [
|
||||
"sentence-transformers/all-MiniLM-L6-v2-model.onnx",
|
||||
"sentence-transformers/all-MiniLM-L6-v2-tokenizer",
|
||||
"clip-clip_image_model_vitb32.onnx",
|
||||
"clip-clip_text_model_vitb32.onnx",
|
||||
"jinaai/jina-clip-v1-text_model_fp16.onnx",
|
||||
"jinaai/jina-clip-v1-tokenizer",
|
||||
"jinaai/jina-clip-v1-vision_model_fp16.onnx",
|
||||
"jinaai/jina-clip-v1-preprocessor_config.json",
|
||||
]
|
||||
|
||||
for model in models:
|
||||
@ -89,10 +103,33 @@ class Embeddings:
|
||||
},
|
||||
)
|
||||
|
||||
self.clip_embedding = ClipEmbedding(
|
||||
preferred_providers=["CPUExecutionProvider"]
|
||||
def jina_text_embedding_function(outputs):
|
||||
return outputs[0]
|
||||
|
||||
def jina_vision_embedding_function(outputs):
|
||||
return outputs[0]
|
||||
|
||||
self.text_embedding = GenericONNXEmbedding(
|
||||
model_name="jinaai/jina-clip-v1",
|
||||
model_file="text_model_fp16.onnx",
|
||||
tokenizer_file="tokenizer",
|
||||
download_urls={
|
||||
"text_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/text_model_fp16.onnx",
|
||||
},
|
||||
embedding_function=jina_text_embedding_function,
|
||||
model_type="text",
|
||||
preferred_providers=["CPUExecutionProvider"],
|
||||
)
|
||||
self.minilm_embedding = MiniLMEmbedding(
|
||||
|
||||
self.vision_embedding = GenericONNXEmbedding(
|
||||
model_name="jinaai/jina-clip-v1",
|
||||
model_file="vision_model_fp16.onnx",
|
||||
download_urls={
|
||||
"vision_model_fp16.onnx": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/onnx/vision_model_fp16.onnx",
|
||||
"preprocessor_config.json": "https://huggingface.co/jinaai/jina-clip-v1/resolve/main/preprocessor_config.json",
|
||||
},
|
||||
embedding_function=jina_vision_embedding_function,
|
||||
model_type="vision",
|
||||
preferred_providers=["CPUExecutionProvider"],
|
||||
)
|
||||
|
||||
@ -101,7 +138,7 @@ class Embeddings:
|
||||
self.db.execute_sql("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS vec_thumbnails USING vec0(
|
||||
id TEXT PRIMARY KEY,
|
||||
thumbnail_embedding FLOAT[512]
|
||||
thumbnail_embedding FLOAT[768]
|
||||
);
|
||||
""")
|
||||
|
||||
@ -109,15 +146,22 @@ class Embeddings:
|
||||
self.db.execute_sql("""
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS vec_descriptions USING vec0(
|
||||
id TEXT PRIMARY KEY,
|
||||
description_embedding FLOAT[384]
|
||||
description_embedding FLOAT[768]
|
||||
);
|
||||
""")
|
||||
|
||||
def _drop_tables(self):
|
||||
self.db.execute_sql("""
|
||||
DROP TABLE vec_descriptions;
|
||||
""")
|
||||
self.db.execute_sql("""
|
||||
DROP TABLE vec_thumbnails;
|
||||
""")
|
||||
|
||||
def upsert_thumbnail(self, event_id: str, thumbnail: bytes):
|
||||
# Convert thumbnail bytes to PIL Image
|
||||
image = Image.open(io.BytesIO(thumbnail)).convert("RGB")
|
||||
# Generate embedding using CLIP
|
||||
embedding = self.clip_embedding([image])[0]
|
||||
embedding = self.vision_embedding([image])[0]
|
||||
|
||||
self.db.execute_sql(
|
||||
"""
|
||||
@ -130,8 +174,7 @@ class Embeddings:
|
||||
return embedding
|
||||
|
||||
def upsert_description(self, event_id: str, description: str):
|
||||
# Generate embedding using MiniLM
|
||||
embedding = self.minilm_embedding([description])[0]
|
||||
embedding = self.text_embedding([description])[0]
|
||||
|
||||
self.db.execute_sql(
|
||||
"""
|
||||
@ -177,7 +220,7 @@ class Embeddings:
|
||||
thumbnail = base64.b64decode(query.thumbnail)
|
||||
query_embedding = self.upsert_thumbnail(query.id, thumbnail)
|
||||
else:
|
||||
query_embedding = self.clip_embedding([query])[0]
|
||||
query_embedding = self.text_embedding([query])[0]
|
||||
|
||||
sql_query = """
|
||||
SELECT
|
||||
@ -211,7 +254,7 @@ class Embeddings:
|
||||
def search_description(
|
||||
self, query_text: str, event_ids: List[str] = None
|
||||
) -> List[Tuple[str, float]]:
|
||||
query_embedding = self.minilm_embedding([query_text])[0]
|
||||
query_embedding = self.text_embedding([query_text])[0]
|
||||
|
||||
# Prepare the base SQL query
|
||||
sql_query = """
|
||||
@ -246,6 +289,9 @@ class Embeddings:
|
||||
def reindex(self) -> None:
|
||||
logger.info("Indexing event embeddings...")
|
||||
|
||||
self._drop_tables()
|
||||
self._create_tables()
|
||||
|
||||
st = time.time()
|
||||
totals = {
|
||||
"thumb": 0,
|
||||
|
@ -1,166 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from onnx_clip import OnnxClip, Preprocessor, Tokenizer
|
||||
from PIL import Image
|
||||
|
||||
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.downloader import ModelDownloader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Clip(OnnxClip):
|
||||
"""Override load models to use pre-downloaded models from cache directory."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "ViT-B/32",
|
||||
batch_size: Optional[int] = None,
|
||||
providers: List[str] = ["CPUExecutionProvider"],
|
||||
):
|
||||
"""
|
||||
Instantiates the model and required encoding classes.
|
||||
|
||||
Args:
|
||||
model: The model to utilize. Currently ViT-B/32 and RN50 are
|
||||
allowed.
|
||||
batch_size: If set, splits the lists in `get_image_embeddings`
|
||||
and `get_text_embeddings` into batches of this size before
|
||||
passing them to the model. The embeddings are then concatenated
|
||||
back together before being returned. This is necessary when
|
||||
passing large amounts of data (perhaps ~100 or more).
|
||||
"""
|
||||
allowed_models = ["ViT-B/32", "RN50"]
|
||||
if model not in allowed_models:
|
||||
raise ValueError(f"`model` must be in {allowed_models}. Got {model}.")
|
||||
if model == "ViT-B/32":
|
||||
self.embedding_size = 512
|
||||
elif model == "RN50":
|
||||
self.embedding_size = 1024
|
||||
self.image_model, self.text_model = self._load_models(model, providers)
|
||||
self._tokenizer = Tokenizer()
|
||||
self._preprocessor = Preprocessor()
|
||||
self._batch_size = batch_size
|
||||
|
||||
@staticmethod
|
||||
def _load_models(
|
||||
model: str,
|
||||
providers: List[str],
|
||||
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
|
||||
"""
|
||||
Load models from cache directory.
|
||||
"""
|
||||
if model == "ViT-B/32":
|
||||
IMAGE_MODEL_FILE = "clip_image_model_vitb32.onnx"
|
||||
TEXT_MODEL_FILE = "clip_text_model_vitb32.onnx"
|
||||
elif model == "RN50":
|
||||
IMAGE_MODEL_FILE = "clip_image_model_rn50.onnx"
|
||||
TEXT_MODEL_FILE = "clip_text_model_rn50.onnx"
|
||||
else:
|
||||
raise ValueError(f"Unexpected model {model}. No `.onnx` file found.")
|
||||
|
||||
models = []
|
||||
for model_file in [IMAGE_MODEL_FILE, TEXT_MODEL_FILE]:
|
||||
path = os.path.join(MODEL_CACHE_DIR, "clip", model_file)
|
||||
models.append(Clip._load_model(path, providers))
|
||||
|
||||
return models[0], models[1]
|
||||
|
||||
@staticmethod
|
||||
def _load_model(path: str, providers: List[str]):
|
||||
if os.path.exists(path):
|
||||
return ort.InferenceSession(path, providers=providers)
|
||||
else:
|
||||
logger.warning(f"CLIP model file {path} not found.")
|
||||
return None
|
||||
|
||||
|
||||
class ClipEmbedding:
|
||||
"""Embedding function for CLIP model."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "ViT-B/32",
|
||||
silent: bool = False,
|
||||
preferred_providers: List[str] = ["CPUExecutionProvider"],
|
||||
):
|
||||
self.model_name = model
|
||||
self.silent = silent
|
||||
self.preferred_providers = preferred_providers
|
||||
self.model_files = self._get_model_files()
|
||||
self.model = None
|
||||
|
||||
self.downloader = ModelDownloader(
|
||||
model_name="clip",
|
||||
download_path=os.path.join(MODEL_CACHE_DIR, "clip"),
|
||||
file_names=self.model_files,
|
||||
download_func=self._download_model,
|
||||
silent=self.silent,
|
||||
)
|
||||
self.downloader.ensure_model_files()
|
||||
|
||||
def _get_model_files(self):
|
||||
if self.model_name == "ViT-B/32":
|
||||
return ["clip_image_model_vitb32.onnx", "clip_text_model_vitb32.onnx"]
|
||||
elif self.model_name == "RN50":
|
||||
return ["clip_image_model_rn50.onnx", "clip_text_model_rn50.onnx"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected model {self.model_name}. No `.onnx` file found."
|
||||
)
|
||||
|
||||
def _download_model(self, path: str):
|
||||
s3_url = (
|
||||
f"https://lakera-clip.s3.eu-west-1.amazonaws.com/{os.path.basename(path)}"
|
||||
)
|
||||
try:
|
||||
ModelDownloader.download_from_url(s3_url, path, self.silent)
|
||||
self.downloader.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.model_name}-{os.path.basename(path)}",
|
||||
"state": ModelStatusTypesEnum.downloaded,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
self.downloader.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.model_name}-{os.path.basename(path)}",
|
||||
"state": ModelStatusTypesEnum.error,
|
||||
},
|
||||
)
|
||||
|
||||
def _load_model(self):
|
||||
if self.model is None:
|
||||
self.downloader.wait_for_download()
|
||||
self.model = Clip(self.model_name, providers=self.preferred_providers)
|
||||
|
||||
def __call__(self, input: Union[List[str], List[Image.Image]]) -> List[np.ndarray]:
|
||||
self._load_model()
|
||||
if (
|
||||
self.model is None
|
||||
or self.model.image_model is None
|
||||
or self.model.text_model is None
|
||||
):
|
||||
logger.info(
|
||||
"CLIP model is not fully loaded. Please wait for the download to complete."
|
||||
)
|
||||
return []
|
||||
|
||||
embeddings = []
|
||||
for item in input:
|
||||
if isinstance(item, Image.Image):
|
||||
result = self.model.get_image_embeddings([item])
|
||||
embeddings.append(result[0])
|
||||
elif isinstance(item, str):
|
||||
result = self.model.get_text_embeddings([item])
|
||||
embeddings.append(result[0])
|
||||
else:
|
||||
raise ValueError(f"Unsupported input type: {type(item)}")
|
||||
return embeddings
|
@ -1,107 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
# importing this without pytorch or others causes a warning
|
||||
# https://github.com/huggingface/transformers/issues/27214
|
||||
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.downloader import ModelDownloader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MiniLMEmbedding:
|
||||
"""Embedding function for ONNX MiniLM-L6 model."""
|
||||
|
||||
DOWNLOAD_PATH = f"{MODEL_CACHE_DIR}/all-MiniLM-L6-v2"
|
||||
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
|
||||
IMAGE_MODEL_FILE = "model.onnx"
|
||||
TOKENIZER_FILE = "tokenizer"
|
||||
|
||||
def __init__(self, preferred_providers=["CPUExecutionProvider"]):
|
||||
self.preferred_providers = preferred_providers
|
||||
self.tokenizer = None
|
||||
self.session = None
|
||||
|
||||
self.downloader = ModelDownloader(
|
||||
model_name=self.MODEL_NAME,
|
||||
download_path=self.DOWNLOAD_PATH,
|
||||
file_names=[self.IMAGE_MODEL_FILE, self.TOKENIZER_FILE],
|
||||
download_func=self._download_model,
|
||||
)
|
||||
self.downloader.ensure_model_files()
|
||||
|
||||
def _download_model(self, path: str):
|
||||
try:
|
||||
if os.path.basename(path) == self.IMAGE_MODEL_FILE:
|
||||
s3_url = f"https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2/resolve/main/onnx/{self.IMAGE_MODEL_FILE}"
|
||||
ModelDownloader.download_from_url(s3_url, path)
|
||||
elif os.path.basename(path) == self.TOKENIZER_FILE:
|
||||
logger.info("Downloading MiniLM tokenizer")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.MODEL_NAME, clean_up_tokenization_spaces=True
|
||||
)
|
||||
tokenizer.save_pretrained(path)
|
||||
|
||||
self.downloader.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.MODEL_NAME}-{os.path.basename(path)}",
|
||||
"state": ModelStatusTypesEnum.downloaded,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
self.downloader.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.MODEL_NAME}-{os.path.basename(path)}",
|
||||
"state": ModelStatusTypesEnum.error,
|
||||
},
|
||||
)
|
||||
|
||||
def _load_model_and_tokenizer(self):
|
||||
if self.tokenizer is None or self.session is None:
|
||||
self.downloader.wait_for_download()
|
||||
self.tokenizer = self._load_tokenizer()
|
||||
self.session = self._load_model(
|
||||
os.path.join(self.DOWNLOAD_PATH, self.IMAGE_MODEL_FILE),
|
||||
self.preferred_providers,
|
||||
)
|
||||
|
||||
def _load_tokenizer(self):
|
||||
tokenizer_path = os.path.join(self.DOWNLOAD_PATH, self.TOKENIZER_FILE)
|
||||
return AutoTokenizer.from_pretrained(
|
||||
tokenizer_path, clean_up_tokenization_spaces=True
|
||||
)
|
||||
|
||||
def _load_model(self, path: str, providers: List[str]):
|
||||
if os.path.exists(path):
|
||||
return ort.InferenceSession(path, providers=providers)
|
||||
else:
|
||||
logger.warning(f"MiniLM model file {path} not found.")
|
||||
return None
|
||||
|
||||
def __call__(self, texts: List[str]) -> List[np.ndarray]:
|
||||
self._load_model_and_tokenizer()
|
||||
|
||||
if self.session is None or self.tokenizer is None:
|
||||
logger.error("MiniLM model or tokenizer is not loaded.")
|
||||
return []
|
||||
|
||||
inputs = self.tokenizer(
|
||||
texts, padding=True, truncation=True, return_tensors="np"
|
||||
)
|
||||
input_names = [input.name for input in self.session.get_inputs()]
|
||||
onnx_inputs = {name: inputs[name] for name in input_names if name in inputs}
|
||||
|
||||
outputs = self.session.run(None, onnx_inputs)
|
||||
embeddings = outputs[0].mean(axis=1)
|
||||
|
||||
return [embedding for embedding in embeddings]
|
174
frigate/embeddings/functions/onnx.py
Normal file
174
frigate/embeddings/functions/onnx.py
Normal file
@ -0,0 +1,174 @@
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from io import BytesIO
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
# importing this without pytorch or others causes a warning
|
||||
# https://github.com/huggingface/transformers/issues/27214
|
||||
# suppressed by setting env TRANSFORMERS_NO_ADVISORY_WARNINGS=1
|
||||
from transformers import AutoFeatureExtractor, AutoTokenizer
|
||||
from transformers.utils.logging import disable_progress_bar
|
||||
|
||||
from frigate.const import MODEL_CACHE_DIR, UPDATE_MODEL_STATE
|
||||
from frigate.types import ModelStatusTypesEnum
|
||||
from frigate.util.downloader import ModelDownloader
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
category=FutureWarning,
|
||||
message="The class CLIPFeatureExtractor is deprecated",
|
||||
)
|
||||
|
||||
# disables the progress bar for downloading tokenizers and feature extractors
|
||||
disable_progress_bar()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenericONNXEmbedding:
|
||||
"""Generic embedding function for ONNX models (text and vision)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
model_file: str,
|
||||
download_urls: Dict[str, str],
|
||||
embedding_function: Callable[[List[np.ndarray]], np.ndarray],
|
||||
model_type: str,
|
||||
preferred_providers: List[str] = ["CPUExecutionProvider"],
|
||||
tokenizer_file: Optional[str] = None,
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.model_file = model_file
|
||||
self.tokenizer_file = tokenizer_file
|
||||
self.download_urls = download_urls
|
||||
self.embedding_function = embedding_function
|
||||
self.model_type = model_type # 'text' or 'vision'
|
||||
self.preferred_providers = preferred_providers
|
||||
|
||||
self.download_path = os.path.join(MODEL_CACHE_DIR, self.model_name)
|
||||
self.tokenizer = None
|
||||
self.feature_extractor = None
|
||||
self.session = None
|
||||
|
||||
self.downloader = ModelDownloader(
|
||||
model_name=self.model_name,
|
||||
download_path=self.download_path,
|
||||
file_names=list(self.download_urls.keys())
|
||||
+ ([self.tokenizer_file] if self.tokenizer_file else []),
|
||||
download_func=self._download_model,
|
||||
)
|
||||
self.downloader.ensure_model_files()
|
||||
|
||||
def _download_model(self, path: str):
|
||||
try:
|
||||
file_name = os.path.basename(path)
|
||||
if file_name in self.download_urls:
|
||||
ModelDownloader.download_from_url(self.download_urls[file_name], path)
|
||||
elif file_name == self.tokenizer_file and self.model_type == "text":
|
||||
if not os.path.exists(path + "/" + self.model_name):
|
||||
logger.info(f"Downloading {self.model_name} tokenizer")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
self.model_name,
|
||||
trust_remote_code=True,
|
||||
cache_dir=f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer",
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
tokenizer.save_pretrained(path)
|
||||
|
||||
self.downloader.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.model_name}-{file_name}",
|
||||
"state": ModelStatusTypesEnum.downloaded,
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
self.downloader.requestor.send_data(
|
||||
UPDATE_MODEL_STATE,
|
||||
{
|
||||
"model": f"{self.model_name}-{file_name}",
|
||||
"state": ModelStatusTypesEnum.error,
|
||||
},
|
||||
)
|
||||
|
||||
def _load_model_and_tokenizer(self):
|
||||
if self.session is None:
|
||||
self.downloader.wait_for_download()
|
||||
if self.model_type == "text":
|
||||
self.tokenizer = self._load_tokenizer()
|
||||
else:
|
||||
self.feature_extractor = self._load_feature_extractor()
|
||||
self.session = self._load_model(
|
||||
os.path.join(self.download_path, self.model_file),
|
||||
self.preferred_providers,
|
||||
)
|
||||
|
||||
def _load_tokenizer(self):
|
||||
tokenizer_path = os.path.join(f"{MODEL_CACHE_DIR}/{self.model_name}/tokenizer")
|
||||
return AutoTokenizer.from_pretrained(
|
||||
self.model_name,
|
||||
cache_dir=tokenizer_path,
|
||||
trust_remote_code=True,
|
||||
clean_up_tokenization_spaces=True,
|
||||
)
|
||||
|
||||
def _load_feature_extractor(self):
|
||||
return AutoFeatureExtractor.from_pretrained(
|
||||
f"{MODEL_CACHE_DIR}/{self.model_name}",
|
||||
)
|
||||
|
||||
def _load_model(self, path: str, providers: List[str]):
|
||||
if os.path.exists(path):
|
||||
return ort.InferenceSession(path, providers=providers)
|
||||
else:
|
||||
logger.warning(f"{self.model_name} model file {path} not found.")
|
||||
return None
|
||||
|
||||
def _process_image(self, image):
|
||||
if isinstance(image, str):
|
||||
if image.startswith("http"):
|
||||
response = requests.get(image)
|
||||
image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
|
||||
return image
|
||||
|
||||
def __call__(
|
||||
self, inputs: Union[List[str], List[Image.Image], List[str]]
|
||||
) -> List[np.ndarray]:
|
||||
self._load_model_and_tokenizer()
|
||||
|
||||
if self.session is None or (
|
||||
self.tokenizer is None and self.feature_extractor is None
|
||||
):
|
||||
logger.error(
|
||||
f"{self.model_name} model or tokenizer/feature extractor is not loaded."
|
||||
)
|
||||
return []
|
||||
|
||||
if self.model_type == "text":
|
||||
processed_inputs = self.tokenizer(
|
||||
inputs, padding=True, truncation=True, return_tensors="np"
|
||||
)
|
||||
else:
|
||||
processed_images = [self._process_image(img) for img in inputs]
|
||||
processed_inputs = self.feature_extractor(
|
||||
images=processed_images, return_tensors="np"
|
||||
)
|
||||
|
||||
input_names = [input.name for input in self.session.get_inputs()]
|
||||
onnx_inputs = {
|
||||
name: processed_inputs[name]
|
||||
for name in input_names
|
||||
if name in processed_inputs
|
||||
}
|
||||
|
||||
outputs = self.session.run(None, onnx_inputs)
|
||||
embeddings = self.embedding_function(outputs)
|
||||
|
||||
return [embedding for embedding in embeddings]
|
@ -184,31 +184,31 @@ export default function Explore() {
|
||||
|
||||
// model states
|
||||
|
||||
const { payload: minilmModelState } = useModelState(
|
||||
"sentence-transformers/all-MiniLM-L6-v2-model.onnx",
|
||||
const { payload: textModelState } = useModelState(
|
||||
"jinaai/jina-clip-v1-text_model_fp16.onnx",
|
||||
);
|
||||
const { payload: minilmTokenizerState } = useModelState(
|
||||
"sentence-transformers/all-MiniLM-L6-v2-tokenizer",
|
||||
const { payload: textTokenizerState } = useModelState(
|
||||
"jinaai/jina-clip-v1-tokenizer",
|
||||
);
|
||||
const { payload: clipImageModelState } = useModelState(
|
||||
"clip-clip_image_model_vitb32.onnx",
|
||||
const { payload: visionModelState } = useModelState(
|
||||
"jinaai/jina-clip-v1-vision_model_fp16.onnx",
|
||||
);
|
||||
const { payload: clipTextModelState } = useModelState(
|
||||
"clip-clip_text_model_vitb32.onnx",
|
||||
const { payload: visionFeatureExtractorState } = useModelState(
|
||||
"jinaai/jina-clip-v1-preprocessor_config.json",
|
||||
);
|
||||
|
||||
const allModelsLoaded = useMemo(() => {
|
||||
return (
|
||||
minilmModelState === "downloaded" &&
|
||||
minilmTokenizerState === "downloaded" &&
|
||||
clipImageModelState === "downloaded" &&
|
||||
clipTextModelState === "downloaded"
|
||||
textModelState === "downloaded" &&
|
||||
textTokenizerState === "downloaded" &&
|
||||
visionModelState === "downloaded" &&
|
||||
visionFeatureExtractorState === "downloaded"
|
||||
);
|
||||
}, [
|
||||
minilmModelState,
|
||||
minilmTokenizerState,
|
||||
clipImageModelState,
|
||||
clipTextModelState,
|
||||
textModelState,
|
||||
textTokenizerState,
|
||||
visionModelState,
|
||||
visionFeatureExtractorState,
|
||||
]);
|
||||
|
||||
const renderModelStateIcon = (modelState: ModelState) => {
|
||||
@ -226,10 +226,10 @@ export default function Explore() {
|
||||
|
||||
if (
|
||||
config?.semantic_search.enabled &&
|
||||
(!minilmModelState ||
|
||||
!minilmTokenizerState ||
|
||||
!clipImageModelState ||
|
||||
!clipTextModelState)
|
||||
(!textModelState ||
|
||||
!textTokenizerState ||
|
||||
!visionModelState ||
|
||||
!visionFeatureExtractorState)
|
||||
) {
|
||||
return (
|
||||
<ActivityIndicator className="absolute left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2" />
|
||||
@ -252,25 +252,26 @@ export default function Explore() {
|
||||
</div>
|
||||
<div className="flex w-96 flex-col gap-2 py-5">
|
||||
<div className="flex flex-row items-center justify-center gap-2">
|
||||
{renderModelStateIcon(clipImageModelState)}
|
||||
CLIP image model
|
||||
{renderModelStateIcon(visionModelState)}
|
||||
Vision model
|
||||
</div>
|
||||
<div className="flex flex-row items-center justify-center gap-2">
|
||||
{renderModelStateIcon(clipTextModelState)}
|
||||
CLIP text model
|
||||
{renderModelStateIcon(visionFeatureExtractorState)}
|
||||
Vision model feature extractor
|
||||
</div>
|
||||
<div className="flex flex-row items-center justify-center gap-2">
|
||||
{renderModelStateIcon(minilmModelState)}
|
||||
MiniLM sentence model
|
||||
{renderModelStateIcon(textModelState)}
|
||||
Text model
|
||||
</div>
|
||||
<div className="flex flex-row items-center justify-center gap-2">
|
||||
{renderModelStateIcon(minilmTokenizerState)}
|
||||
MiniLM tokenizer
|
||||
{renderModelStateIcon(textTokenizerState)}
|
||||
Text tokenizer
|
||||
</div>
|
||||
</div>
|
||||
{(minilmModelState === "error" ||
|
||||
clipImageModelState === "error" ||
|
||||
clipTextModelState === "error") && (
|
||||
{(textModelState === "error" ||
|
||||
textTokenizerState === "error" ||
|
||||
visionModelState === "error" ||
|
||||
visionFeatureExtractorState === "error") && (
|
||||
<div className="my-3 max-w-96 text-center text-danger">
|
||||
An error has occurred. Check Frigate logs.
|
||||
</div>
|
||||
|
Loading…
Reference in New Issue
Block a user