stats_manager.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. # -*- coding: utf-8 -*-
  2. #
  3. # PySceneDetect: Python-Based Video Scene Detector
  4. # -------------------------------------------------------------------
  5. # [ Site: https://scenedetect.com ]
  6. # [ Docs: https://scenedetect.com/docs/ ]
  7. # [ Github: https://github.com/Breakthrough/PySceneDetect/ ]
  8. #
  9. # Copyright (C) 2014-2024 Brandon Castellano <http://www.bcastell.com>.
  10. # PySceneDetect is licensed under the BSD 3-Clause License; see the
  11. # included LICENSE file, or visit one of the above pages for details.
  12. #
  13. """``scenedetect.stats_manager`` Module
  14. This module contains the :class:`StatsManager` class, which provides a key-value store for each
  15. :class:`SceneDetector <scenedetect.scene_detector.SceneDetector>` to write the metrics calculated
  16. for each frame. The :class:`StatsManager` must be registered to a
  17. :class:`SceneManager <scenedetect.scene_manager.SceneManager>` upon construction.
  18. The entire :class:`StatsManager` can be :meth:`saved to <StatsManager.save_to_csv>` a
  19. human-readable CSV file, allowing for precise determination of the ideal threshold (or other
  20. detection parameters) for the given input.
  21. """
  22. import csv
  23. from logging import getLogger
  24. import typing as ty
  25. # TODO: Replace below imports with `ty.` prefix.
  26. from typing import Any, Dict, Iterable, List, Optional, Set, TextIO, Union
  27. import os.path
  28. from scenedetect.frame_timecode import FrameTimecode
  29. logger = getLogger('pyscenedetect')
  30. ##
  31. ## StatsManager CSV File Column Names (Header Row)
  32. ##
  33. COLUMN_NAME_FRAME_NUMBER = "Frame Number"
  34. """Name of column containing frame numbers in the statsfile CSV."""
  35. COLUMN_NAME_TIMECODE = "Timecode"
  36. """Name of column containing timecodes in the statsfile CSV."""
  37. ##
  38. ## StatsManager Exceptions
  39. ##
  40. class FrameMetricRegistered(Exception):
  41. """[DEPRECATED - DO NOT USE] No longer used."""
  42. pass
  43. class FrameMetricNotRegistered(Exception):
  44. """[DEPRECATED - DO NOT USE] No longer used."""
  45. pass
  46. class StatsFileCorrupt(Exception):
  47. """Raised when frame metrics/stats could not be loaded from a provided CSV file."""
  48. def __init__(self,
  49. message: str = "Could not load frame metric data data from passed CSV file."):
  50. super().__init__(message)
  51. ##
  52. ## StatsManager Class Implementation
  53. ##
  54. # TODO(v1.0): Relax restriction on metric types only being float or int when loading from disk
  55. # is fully deprecated.
  56. class StatsManager:
  57. """Provides a key-value store for frame metrics/calculations which can be used
  58. for two-pass detection algorithms, as well as saving stats to a CSV file.
  59. Analyzing a statistics CSV file is also very useful for finding the optimal
  60. algorithm parameters for certain detection methods. Additionally, the data
  61. may be plotted by a graphing module (e.g. matplotlib) by obtaining the
  62. metric of interest for a series of frames by iteratively calling get_metrics(),
  63. after having called the detect_scenes(...) method on the SceneManager object
  64. which owns the given StatsManager instance.
  65. Only metrics consisting of `float` or `int` should be used currently.
  66. """
  67. def __init__(self, base_timecode: FrameTimecode = None):
  68. """Initialize a new StatsManager.
  69. Arguments:
  70. base_timecode: Timecode associated with this object. Must not be None (default value
  71. will be removed in a future release).
  72. """
  73. # Frame metrics is a dict of frame (int): metric_dict (Dict[str, float])
  74. # of each frame metric key and the value it represents (usually float).
  75. self._frame_metrics: Dict[FrameTimecode, Dict[str, float]] = dict()
  76. self._metric_keys: Set[str] = set()
  77. self._metrics_updated: bool = False # Flag indicating if metrics require saving.
  78. self._base_timecode: Optional[FrameTimecode] = base_timecode # Used for timing calculations.
  79. @property
  80. def metric_keys(self) -> ty.Iterable[str]:
  81. return self._metric_keys
  82. def register_metrics(self, metric_keys: Iterable[str]) -> None:
  83. """Register a list of metric keys that will be used by the detector."""
  84. self._metric_keys = self._metric_keys.union(set(metric_keys))
  85. # TODO(v1.0): Change frame_number to a FrameTimecode now that it is just a hash and will
  86. # be required for VFR support. This API is also really difficult to use, this type should just
  87. # function like a dictionary.
  88. def get_metrics(self, frame_number: int, metric_keys: Iterable[str]) -> List[Any]:
  89. """Return the requested statistics/metrics for a given frame.
  90. Arguments:
  91. frame_number (int): Frame number to retrieve metrics for.
  92. metric_keys (List[str]): A list of metric keys to look up.
  93. Returns:
  94. A list containing the requested frame metrics for the given frame number
  95. in the same order as the input list of metric keys. If a metric could
  96. not be found, None is returned for that particular metric.
  97. """
  98. return [self._get_metric(frame_number, metric_key) for metric_key in metric_keys]
  99. def set_metrics(self, frame_number: int, metric_kv_dict: Dict[str, Any]) -> None:
  100. """ Set Metrics: Sets the provided statistics/metrics for a given frame.
  101. Arguments:
  102. frame_number: Frame number to retrieve metrics for.
  103. metric_kv_dict: A dict mapping metric keys to the
  104. respective integer/floating-point metric values to set.
  105. """
  106. for metric_key in metric_kv_dict:
  107. self._set_metric(frame_number, metric_key, metric_kv_dict[metric_key])
  108. def metrics_exist(self, frame_number: int, metric_keys: Iterable[str]) -> bool:
  109. """ Metrics Exist: Checks if the given metrics/stats exist for the given frame.
  110. Returns:
  111. bool: True if the given metric keys exist for the frame, False otherwise.
  112. """
  113. return all([self._metric_exists(frame_number, metric_key) for metric_key in metric_keys])
  114. def is_save_required(self) -> bool:
  115. """ Is Save Required: Checks if the stats have been updated since loading.
  116. Returns:
  117. bool: True if there are frame metrics/statistics not yet written to disk,
  118. False otherwise.
  119. """
  120. return self._metrics_updated
  121. def save_to_csv(self,
  122. csv_file: Union[str, bytes, TextIO],
  123. base_timecode: Optional[FrameTimecode] = None,
  124. force_save=True) -> None:
  125. """ Save To CSV: Saves all frame metrics stored in the StatsManager to a CSV file.
  126. Arguments:
  127. csv_file: A file handle opened in write mode (e.g. open('...', 'w')) or a path as str.
  128. base_timecode: [DEPRECATED] DO NOT USE. For backwards compatibility.
  129. force_save: If True, writes metrics out even if an update is not required.
  130. Raises:
  131. OSError: If `path` cannot be opened or a write failure occurs.
  132. """
  133. # TODO(v0.7): Replace with DeprecationWarning that `base_timecode` will be removed in v0.8.
  134. if base_timecode is not None:
  135. logger.error('base_timecode is deprecated and has no effect.')
  136. if not (force_save or self.is_save_required()):
  137. logger.info("No metrics to write.")
  138. return
  139. # If we get a path instead of an open file handle, recursively call ourselves
  140. # again but with file handle instead of path.
  141. if isinstance(csv_file, (str, bytes)):
  142. with open(csv_file, 'w') as file:
  143. self.save_to_csv(csv_file=file, force_save=force_save)
  144. return
  145. csv_writer = csv.writer(csv_file, lineterminator='\n')
  146. metric_keys = sorted(list(self._metric_keys))
  147. csv_writer.writerow([COLUMN_NAME_FRAME_NUMBER, COLUMN_NAME_TIMECODE] + metric_keys)
  148. frame_keys = sorted(self._frame_metrics.keys())
  149. logger.info("Writing %d frames to CSV...", len(frame_keys))
  150. for frame_key in frame_keys:
  151. frame_timecode = self._base_timecode + frame_key
  152. csv_writer.writerow(
  153. [frame_timecode.get_frames() +
  154. 1, frame_timecode.get_timecode()] +
  155. [str(metric) for metric in self.get_metrics(frame_key, metric_keys)])
  156. @staticmethod
  157. def valid_header(row: List[str]) -> bool:
  158. """Check that the given CSV row is a valid header for a statsfile.
  159. Arguments:
  160. row: A row decoded from the CSV reader.
  161. Returns:
  162. True if `row` is a valid statsfile header, False otherwise.
  163. """
  164. if not row or not len(row) >= 2:
  165. return False
  166. if row[0] != COLUMN_NAME_FRAME_NUMBER or row[1] != COLUMN_NAME_TIMECODE:
  167. return False
  168. return True
  169. # TODO(v1.0): Create a replacement for a calculation cache that functions like load_from_csv
  170. # did, but is better integrated with detectors for cached calculations instead of statistics.
  171. def load_from_csv(self, csv_file: Union[str, bytes, TextIO]) -> Optional[int]:
  172. """[DEPRECATED] DO NOT USE
  173. Load all metrics stored in a CSV file into the StatsManager instance. Will be removed in a
  174. future release after becoming a no-op.
  175. Arguments:
  176. csv_file: A file handle opened in read mode (e.g. open('...', 'r')) or a path as str.
  177. Returns:
  178. int or None: Number of frames/rows read from the CSV file, or None if the
  179. input file was blank or could not be found.
  180. Raises:
  181. StatsFileCorrupt: Stats file is corrupt and can't be loaded, or wrong file
  182. was specified.
  183. """
  184. # TODO: Make this an error, then make load_from_csv() a no-op, and finally, remove it.
  185. logger.warning("load_from_csv() is deprecated and will be removed in a future release.")
  186. # If we get a path instead of an open file handle, check that it exists, and if so,
  187. # recursively call ourselves again but with file set instead of path.
  188. if isinstance(csv_file, (str, bytes)):
  189. if os.path.exists(csv_file):
  190. with open(csv_file, 'r') as file:
  191. return self.load_from_csv(csv_file=file)
  192. # Path doesn't exist.
  193. return None
  194. # If we get here, file is a valid file handle in read-only text mode.
  195. csv_reader = csv.reader(csv_file, lineterminator='\n')
  196. num_cols = None
  197. num_metrics = None
  198. num_frames = None
  199. # First Row: Frame Num, Timecode, [metrics...]
  200. try:
  201. row = next(csv_reader)
  202. # Backwards compatibility for previous versions of statsfile
  203. # which included an additional header row.
  204. if not self.valid_header(row):
  205. row = next(csv_reader)
  206. except StopIteration:
  207. # If the file is blank or we couldn't decode anything, assume the file was empty.
  208. return None
  209. if not self.valid_header(row):
  210. raise StatsFileCorrupt()
  211. num_cols = len(row)
  212. num_metrics = num_cols - 2
  213. if not num_metrics > 0:
  214. raise StatsFileCorrupt('No metrics defined in CSV file.')
  215. loaded_metrics = list(row[2:])
  216. num_frames = 0
  217. for row in csv_reader:
  218. metric_dict = {}
  219. if not len(row) == num_cols:
  220. raise StatsFileCorrupt('Wrong number of columns detected in stats file row.')
  221. frame_number = int(row[0])
  222. # Switch from 1-based to 0-based frame numbers.
  223. if frame_number > 0:
  224. frame_number -= 1
  225. self.set_metrics(frame_number, metric_dict)
  226. for i, metric in enumerate(row[2:]):
  227. if metric and metric != 'None':
  228. try:
  229. self._set_metric(frame_number, loaded_metrics[i], float(metric))
  230. except ValueError:
  231. raise StatsFileCorrupt('Corrupted value in stats file: %s' %
  232. metric) from ValueError
  233. num_frames += 1
  234. self._metric_keys = self._metric_keys.union(set(loaded_metrics))
  235. logger.info('Loaded %d metrics for %d frames.', num_metrics, num_frames)
  236. self._metrics_updated = False
  237. return num_frames
  238. # TODO: Get rid of these functions and simplify the implementation of this class.
  239. def _get_metric(self, frame_number: int, metric_key: str) -> Optional[Any]:
  240. if self._metric_exists(frame_number, metric_key):
  241. return self._frame_metrics[frame_number][metric_key]
  242. return None
  243. def _set_metric(self, frame_number: int, metric_key: str, metric_value: Any) -> None:
  244. self._metrics_updated = True
  245. if not frame_number in self._frame_metrics:
  246. self._frame_metrics[frame_number] = dict()
  247. self._frame_metrics[frame_number][metric_key] = metric_value
  248. def _metric_exists(self, frame_number: int, metric_key: str) -> bool:
  249. return (frame_number in self._frame_metrics
  250. and metric_key in self._frame_metrics[frame_number])