Mirror of https://github.com/blakeblackshear/frigate.git (synced 2025-02-25 18:55:25 -06:00)
Semantic Search API (#12105)
* initial event search api implementation
* fix lint
* fix tests
* move chromadb imports and pysqlite hotswap to fix tests
* remove unused import
* switch default limit to 50
* fix events accidentally pulling inside chroma results loop
parent 36cbffcc5e
commit 9e825811f2
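As a rough usage sketch of the search endpoint this commit adds (see frigate/api/event.py below): the base URL, the /api prefix, and the query text are assumptions about a typical deployment rather than anything specified in this commit, and python-requests must be installed.

import requests

# Illustrative only: host/port and query string are assumed, not taken from this commit.
FRIGATE_URL = "http://frigate.local:5000/api"

resp = requests.get(
    f"{FRIGATE_URL}/events/search",
    params={
        "query": "person carrying a package",  # free-text search
        "search_type": "text",                 # or "thumbnail" with an event id as the query
        "limit": 50,                           # default per the diff below
        "cameras": "all",
        "labels": "all",
        "zones": "all",
    },
    timeout=10,
)
resp.raise_for_status()
for event in resp.json():
    # Each result carries the normalized distance and which index matched.
    print(event["id"], event["search_distance"], event["search_source"])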
frigate/__main__.py
@@ -1,12 +1,9 @@
 import faulthandler
-import sys
 import threading
 
 from flask import cli
 
-# Hotswap the sqlite3 module for Chroma compatibility
-__import__("pysqlite3")
-sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
+from frigate.app import FrigateApp
 
 faulthandler.enable()
 
@@ -15,8 +12,6 @@ threading.current_thread().name = "frigate"
 cli.show_server_banner = lambda *x: None
 
 if __name__ == "__main__":
-    from frigate.app import FrigateApp
-
     frigate_app = FrigateApp()
 
     frigate_app.start()
frigate/api/app.py
@@ -23,6 +23,7 @@ from frigate.api.preview import PreviewBp
 from frigate.api.review import ReviewBp
 from frigate.config import FrigateConfig
 from frigate.const import CONFIG_DIR
+from frigate.embeddings import EmbeddingsContext
 from frigate.events.external import ExternalEventProcessor
 from frigate.models import Event, Timeline
 from frigate.plus import PlusApi
@@ -52,6 +53,7 @@ bp.register_blueprint(AuthBp)
 def create_app(
     frigate_config,
     database: SqliteQueueDatabase,
+    embeddings: EmbeddingsContext,
     detected_frames_processor,
     storage_maintainer: StorageMaintainer,
     onvif: OnvifController,
@@ -79,6 +81,7 @@ def create_app(
             database.close()
 
     app.frigate_config = frigate_config
+    app.embeddings = embeddings
    app.detected_frames_processor = detected_frames_processor
    app.storage_maintainer = storage_maintainer
    app.onvif = onvif
frigate/api/event.py
@@ -1,5 +1,7 @@
 """Event apis."""
 
+import base64
+import io
 import logging
 import os
 from datetime import datetime
@@ -8,6 +10,7 @@ from pathlib import Path
 from urllib.parse import unquote
 
 import cv2
+import numpy as np
 from flask import (
     Blueprint,
     current_app,
@@ -15,13 +18,16 @@ from flask import (
     make_response,
     request,
 )
-from peewee import DoesNotExist, fn, operator
+from peewee import JOIN, DoesNotExist, fn, operator
+from PIL import Image
 from playhouse.shortcuts import model_to_dict
 
 from frigate.const import (
     CLIPS_DIR,
 )
-from frigate.models import Event, Timeline
+from frigate.embeddings import EmbeddingsContext
+from frigate.embeddings.embeddings import get_metadata
+from frigate.models import Event, ReviewSegment, Timeline
 from frigate.object_processing import TrackedObject
 from frigate.util.builtin import get_tz_modifiers
 
@@ -245,6 +251,189 @@ def events():
     return jsonify(list(events))
 
 
+@EventBp.route("/events/search")
+def events_search():
+    query = request.args.get("query", type=str)
+    search_type = request.args.get("search_type", "text", type=str)
+    include_thumbnails = request.args.get("include_thumbnails", default=1, type=int)
+    limit = request.args.get("limit", 50, type=int)
+
+    # Filters
+    cameras = request.args.get("cameras", "all", type=str)
+    labels = request.args.get("labels", "all", type=str)
+    zones = request.args.get("zones", "all", type=str)
+    after = request.args.get("after", type=float)
+    before = request.args.get("before", type=float)
+
+    if not query:
+        return make_response(
+            jsonify(
+                {
+                    "success": False,
+                    "message": "A search query must be supplied",
+                }
+            ),
+            400,
+        )
+
+    if not current_app.frigate_config.semantic_search.enabled:
+        return make_response(
+            jsonify(
+                {
+                    "success": False,
+                    "message": "Semantic search is not enabled",
+                }
+            ),
+            400,
+        )
+
+    context: EmbeddingsContext = current_app.embeddings
+
+    selected_columns = [
+        Event.id,
+        Event.camera,
+        Event.label,
+        Event.sub_label,
+        Event.zones,
+        Event.start_time,
+        Event.end_time,
+        Event.data,
+        ReviewSegment.thumb_path,
+    ]
+
+    if include_thumbnails:
+        selected_columns.append(Event.thumbnail)
+
+    # Build the where clause for the embeddings query
+    embeddings_filters = []
+
+    if cameras != "all":
+        camera_list = cameras.split(",")
+        embeddings_filters.append({"camera": {"$in": camera_list}})
+
+    if labels != "all":
+        label_list = labels.split(",")
+        embeddings_filters.append({"label": {"$in": label_list}})
+
+    if zones != "all":
+        filtered_zones = zones.split(",")
+        zone_filters = [{f"zones_{zone}": {"$eq": True}} for zone in filtered_zones]
+        if len(zone_filters) > 1:
+            embeddings_filters.append({"$or": zone_filters})
+        else:
+            embeddings_filters.append(zone_filters[0])
+
+    if after:
+        embeddings_filters.append({"start_time": {"$gt": after}})
+
+    if before:
+        embeddings_filters.append({"start_time": {"$lt": before}})
+
+    where = None
+    if len(embeddings_filters) > 1:
+        where = {"$and": embeddings_filters}
+    elif len(embeddings_filters) == 1:
+        where = embeddings_filters[0]
+
+    thumb_ids = {}
+    desc_ids = {}
+
+    if search_type == "thumbnail":
+        # Grab the ids of events that match the thumbnail image embeddings
+        try:
+            search_event: Event = Event.get(Event.id == query)
+        except DoesNotExist:
+            return make_response(
+                jsonify(
+                    {
+                        "success": False,
+                        "message": "Event not found",
+                    }
+                ),
+                404,
+            )
+        thumbnail = base64.b64decode(search_event.thumbnail)
+        img = np.array(Image.open(io.BytesIO(thumbnail)).convert("RGB"))
+        thumb_result = context.embeddings.thumbnail.query(
+            query_images=[img],
+            n_results=limit,
+            where=where,
+        )
+        thumb_ids = dict(zip(thumb_result["ids"][0], thumb_result["distances"][0]))
+    else:
+        thumb_result = context.embeddings.thumbnail.query(
+            query_texts=[query],
+            n_results=limit,
+            where=where,
+        )
+        # Do a rudimentary normalization of the difference in distances returned by CLIP and MiniLM.
+        thumb_ids = dict(
+            zip(
+                thumb_result["ids"][0],
+                context.thumb_stats.normalize(thumb_result["distances"][0]),
+            )
+        )
+        desc_result = context.embeddings.description.query(
+            query_texts=[query],
+            n_results=limit,
+            where=where,
+        )
+        desc_ids = dict(
+            zip(
+                desc_result["ids"][0],
+                context.desc_stats.normalize(desc_result["distances"][0]),
+            )
+        )
+
+    results = {}
+    for event_id in thumb_ids.keys() | desc_ids:
+        min_distance = min(
+            i
+            for i in (thumb_ids.get(event_id), desc_ids.get(event_id))
+            if i is not None
+        )
+        results[event_id] = {
+            "distance": min_distance,
+            "source": "thumbnail"
+            if min_distance == thumb_ids.get(event_id)
+            else "description",
+        }
+
+    if not results:
+        return jsonify([])
+
+    # Get the event data
+    events = (
+        Event.select(*selected_columns)
+        .join(
+            ReviewSegment,
+            JOIN.LEFT_OUTER,
+            on=(fn.json_extract(ReviewSegment.data, "$.detections").contains(Event.id)),
+        )
+        .where(Event.id << list(results.keys()))
+        .dicts()
+        .iterator()
+    )
+    events = list(events)
+
+    events = [
+        {k: v for k, v in event.items() if k != "data"}
+        | {
+            k: v
+            for k, v in event["data"].items()
+            if k in ["type", "score", "top_score", "description"]
+        }
+        | {
+            "search_distance": results[event["id"]]["distance"],
+            "search_source": results[event["id"]]["source"],
+        }
+        for event in events
+    ]
+    events = sorted(events, key=lambda x: x["search_distance"])[:limit]
+
+    return jsonify(events)
+
+
 @EventBp.route("/events/summary")
 def events_summary():
     tz_name = request.args.get("timezone", default="utc", type=str)
@@ -604,6 +793,52 @@ def set_sub_label(id):
     )
 
 
+@EventBp.route("/events/<id>/description", methods=("POST",))
+def set_description(id):
+    try:
+        event: Event = Event.get(Event.id == id)
+    except DoesNotExist:
+        return make_response(
+            jsonify({"success": False, "message": "Event " + id + " not found"}), 404
+        )
+
+    json: dict[str, any] = request.get_json(silent=True) or {}
+    new_description = json.get("description")
+
+    if new_description is None or len(new_description) == 0:
+        return make_response(
+            jsonify(
+                {
+                    "success": False,
+                    "message": "description cannot be empty",
+                }
+            ),
+            400,
+        )
+
+    event.data["description"] = new_description
+    event.save()
+
+    # If semantic search is enabled, update the index
+    if current_app.frigate_config.semantic_search.enabled:
+        context: EmbeddingsContext = current_app.embeddings
+        context.embeddings.description.upsert(
+            documents=[new_description],
+            metadatas=[get_metadata(event)],
+            ids=[id],
+        )
+
+    return make_response(
+        jsonify(
+            {
+                "success": True,
+                "message": "Event " + id + " description set to " + new_description,
+            }
+        ),
+        200,
+    )
+
+
 @EventBp.route("/events/<id>", methods=("DELETE",))
 def delete_event(id):
     try:
@@ -625,6 +860,11 @@ def delete_event(id):
 
     event.delete_instance()
     Timeline.delete().where(Timeline.source_id == id).execute()
+    # If semantic search is enabled, update the index
+    if current_app.frigate_config.semantic_search.enabled:
+        context: EmbeddingsContext = current_app.embeddings
+        context.embeddings.thumbnail.delete(ids=[id])
+        context.embeddings.description.delete(ids=[id])
     return make_response(
         jsonify({"success": True, "message": "Event " + id + " deleted"}), 200
     )
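As a hedged usage sketch of the description endpoint added above: the base URL and the event id below are placeholders, not values from this commit.

import requests

# Illustrative only: "1718000000.123456-abcdef" is a made-up event id.
resp = requests.post(
    "http://frigate.local:5000/api/events/1718000000.123456-abcdef/description",
    json={"description": "delivery driver leaving a package at the door"},
    timeout=10,
)
print(resp.status_code, resp.json())
# Per the diff above, an empty description returns 400 and an unknown event returns 404.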
frigate/app.py
@@ -37,8 +37,7 @@ from frigate.const import (
     MODEL_CACHE_DIR,
     RECORD_DIR,
 )
-from frigate.embeddings import manage_embeddings
-from frigate.embeddings.embeddings import Embeddings
+from frigate.embeddings import EmbeddingsContext, manage_embeddings
 from frigate.events.audio import listen_to_audio
 from frigate.events.cleanup import EventCleanup
 from frigate.events.external import ExternalEventProcessor
@@ -322,7 +321,7 @@ class FrigateApp:
 
     def init_embeddings_manager(self) -> None:
         # Create a client for other processes to use
-        self.embeddings = Embeddings()
+        self.embeddings = EmbeddingsContext()
         embedding_process = mp.Process(
             target=manage_embeddings,
             name="embeddings_manager",
@@ -384,6 +383,7 @@ class FrigateApp:
         self.flask_app = create_app(
             self.config,
             self.db,
+            self.embeddings,
             self.detected_frames_processor,
             self.storage_maintainer,
             self.onvif_controller,
@@ -811,6 +811,9 @@ class FrigateApp:
         self.frigate_watchdog.join()
         self.db.stop()
 
+        # Save embeddings stats to disk
+        self.embeddings.save_stats()
+
         # Stop Communicators
         self.inter_process_communicator.stop()
         self.inter_config_updater.stop()
frigate/embeddings/__init__.py
@@ -1,9 +1,9 @@
 """ChromaDB embeddings database."""
 
+import json
 import logging
 import multiprocessing as mp
 import signal
-import sys
 import threading
 from types import FrameType
 from typing import Optional
@@ -12,9 +12,14 @@ from playhouse.sqliteq import SqliteQueueDatabase
 from setproctitle import setproctitle
 
 from frigate.config import FrigateConfig
+from frigate.const import CONFIG_DIR
 from frigate.models import Event
 from frigate.util.services import listen
 
+from .embeddings import Embeddings
+from .maintainer import EmbeddingMaintainer
+from .util import ZScoreNormalization
+
 logger = logging.getLogger(__name__)
 
 
@@ -48,12 +53,6 @@ def manage_embeddings(config: FrigateConfig) -> None:
     models = [Event]
     db.bind(models)
 
-    # Hotswap the sqlite3 module for Chroma compatibility
-    __import__("pysqlite3")
-    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
-    from .embeddings import Embeddings
-    from .maintainer import EmbeddingMaintainer
-
     embeddings = Embeddings()
 
     # Check if we need to re-index events
@@ -65,3 +64,28 @@ def manage_embeddings(config: FrigateConfig) -> None:
         stop_event,
     )
     maintainer.start()
+
+
+class EmbeddingsContext:
+    def __init__(self):
+        self.embeddings = Embeddings()
+        self.thumb_stats = ZScoreNormalization()
+        self.desc_stats = ZScoreNormalization()
+
+        # load stats from disk
+        try:
+            with open(f"{CONFIG_DIR}/.search_stats.json", "r") as f:
+                data = json.loads(f.read())
+                self.thumb_stats.from_dict(data["thumb_stats"])
+                self.desc_stats.from_dict(data["desc_stats"])
+        except FileNotFoundError:
+            pass
+
+    def save_stats(self):
+        """Write the stats to disk as JSON on exit."""
+        contents = {
+            "thumb_stats": self.thumb_stats.to_dict(),
+            "desc_stats": self.desc_stats.to_dict(),
+        }
+        with open(f"{CONFIG_DIR}/.search_stats.json", "w") as f:
+            f.write(json.dumps(contents))
frigate/embeddings/embeddings.py
@@ -3,19 +3,32 @@
 import base64
 import io
 import logging
+import sys
 import time
 
 import numpy as np
-from chromadb import Collection
-from chromadb import HttpClient as ChromaClient
-from chromadb.config import Settings
 from PIL import Image
 from playhouse.shortcuts import model_to_dict
 
 from frigate.models import Event
 
-from .functions.clip import ClipEmbedding
-from .functions.minilm_l6_v2 import MiniLMEmbedding
+# Hotswap the sqlite3 module for Chroma compatibility
+try:
+    from chromadb import Collection
+    from chromadb import HttpClient as ChromaClient
+    from chromadb.config import Settings
+
+    from .functions.clip import ClipEmbedding
+    from .functions.minilm_l6_v2 import MiniLMEmbedding
+except RuntimeError:
+    __import__("pysqlite3")
+    sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")
+    from chromadb import Collection
+    from chromadb import HttpClient as ChromaClient
+    from chromadb.config import Settings
+
+    from .functions.clip import ClipEmbedding
+    from .functions.minilm_l6_v2 import MiniLMEmbedding
 
 logger = logging.getLogger(__name__)
 
frigate/embeddings/util.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+"""Z-score normalization for search distance."""
+
+import math
+
+
+class ZScoreNormalization:
+    """Running Z-score normalization for search distance."""
+
+    def __init__(self):
+        self.n = 0
+        self.mean = 0
+        self.m2 = 0
+
+    @property
+    def variance(self):
+        return self.m2 / (self.n - 1) if self.n > 1 else 0.0
+
+    @property
+    def stddev(self):
+        return math.sqrt(self.variance)
+
+    def normalize(self, distances: list[float]):
+        self._update(distances)
+        if self.stddev == 0:
+            return distances
+        return [(x - self.mean) / self.stddev for x in distances]
+
+    def _update(self, distances: list[float]):
+        for x in distances:
+            self.n += 1
+            delta = x - self.mean
+            self.mean += delta / self.n
+            delta2 = x - self.mean
+            self.m2 += delta * delta2
+
+    def to_dict(self):
+        return {
+            "n": self.n,
+            "mean": self.mean,
+            "m2": self.m2,
+        }
+
+    def from_dict(self, data: dict):
+        self.n = data["n"]
+        self.mean = data["mean"]
+        self.m2 = data["m2"]
+        return self
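To make the running statistics concrete, a minimal usage sketch of the ZScoreNormalization class added above; the sample distances are invented, and importing it through the frigate package also pulls in the Chroma dependencies via frigate/embeddings/__init__.py.

from frigate.embeddings.util import ZScoreNormalization

stats = ZScoreNormalization()

# Each call first folds the batch into the running mean/variance (Welford's
# algorithm in _update), then returns z-scores against the updated stats.
print(stats.normalize([0.8, 1.0, 1.2]))  # roughly [-1.0, 0.0, 1.0]

# Later batches are scaled against everything seen so far, which is how CLIP
# thumbnail distances and MiniLM description distances become comparable.
print(stats.normalize([0.5, 1.5]))

# to_dict()/from_dict() round-trip the state; EmbeddingsContext.save_stats()
# persists exactly this shape to .search_stats.json in the config directory.
restored = ZScoreNormalization().from_dict(stats.to_dict())
print(restored.n, restored.mean, restored.stddev)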
frigate/test/test_http.py
@@ -120,6 +120,7 @@ class TestHttp(unittest.TestCase):
             None,
             None,
             None,
+            None,
             PlusApi(),
             None,
         )
@@ -156,6 +157,7 @@ class TestHttp(unittest.TestCase):
             None,
             None,
             None,
+            None,
             PlusApi(),
             None,
         )
@@ -177,6 +179,7 @@ class TestHttp(unittest.TestCase):
             None,
             None,
             None,
+            None,
             PlusApi(),
             None,
         )
@@ -197,6 +200,7 @@ class TestHttp(unittest.TestCase):
             None,
             None,
             None,
+            None,
             PlusApi(),
             None,
         )
@@ -219,6 +223,7 @@ class TestHttp(unittest.TestCase):
             None,
             None,
             None,
+            None,
             PlusApi(),
             None,
         )
@@ -245,6 +250,7 @@ class TestHttp(unittest.TestCase):
             None,
             None,
             None,
+            None,
             PlusApi(),
             None,
         )
@@ -283,6 +289,7 @@ class TestHttp(unittest.TestCase):
             None,
             None,
             None,
+            None,
             PlusApi(),
             None,
         )
@@ -318,6 +325,7 @@ class TestHttp(unittest.TestCase):
             None,
             None,
             None,
+            None,
             PlusApi(),
             None,
         )
@@ -343,6 +351,7 @@ class TestHttp(unittest.TestCase):
             None,
             None,
             None,
+            None,
             PlusApi(),
             None,
         )
@@ -360,6 +369,7 @@ class TestHttp(unittest.TestCase):
             None,
             None,
             None,
+            None,
             PlusApi(),
             None,
         )
@@ -381,6 +391,7 @@ class TestHttp(unittest.TestCase):
             None,
             None,
             None,
+            None,
             PlusApi(),
             stats,
         )