greenhouse/utils/callbacks.py
Glenn Jocher 7eb23e3c1d
YOLOv5 v6.0 compatibility update (#1857)
* Initial commit

* Initial commit

* Cleanup

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix precommit errors

* Remove TF builds from CI

* export last.pt

* Created using Colaboratory

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2021-11-14 22:22:59 +01:00

77 lines
2.2 KiB
Python

# YOLOv3 🚀 by Ultralytics, GPL-3.0 license
"""
Callback utils
"""
class Callbacks:
    """
    Handles all registered callbacks for Hooks

    Every instance owns its own hook registry, so actions registered on one
    Callbacks object are never visible to, or fired by, another instance.
    """

    def __init__(self):
        # Define the available callbacks.
        # The registry is built per instance on purpose: a class-level dict would be a single
        # mutable object shared (and mutated) by every Callbacks instance in the process.
        self._callbacks = {
            'on_pretrain_routine_start': [],
            'on_pretrain_routine_end': [],

            'on_train_start': [],
            'on_train_epoch_start': [],
            'on_train_batch_start': [],
            'optimizer_step': [],
            'on_before_zero_grad': [],
            'on_train_batch_end': [],
            'on_train_epoch_end': [],

            'on_val_start': [],
            'on_val_batch_start': [],
            'on_val_image_end': [],
            'on_val_batch_end': [],
            'on_val_end': [],

            'on_fit_epoch_end': [],  # fit = train + val
            'on_model_save': [],
            'on_train_end': [],

            'teardown': [],
        }

    def register_action(self, hook, name='', callback=None):
        """
        Register a new action to a callback hook

        Args:
            hook        The callback hook name to register the action to
            name        The name of the action for later reference
            callback    The callback to fire

        Raises:
            AssertionError  If hook is not a known hook name or callback is not callable
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
        assert callable(callback), f"callback '{callback}' is not callable"
        self._callbacks[hook].append({'name': name, 'callback': callback})

    def get_registered_actions(self, hook=None):
        """
        Returns all the registered actions by callback hook

        Args:
            hook    The name of the hook to check, defaults to all

        Returns:
            The list of {'name', 'callback'} actions registered for hook, or the whole
            hook -> actions dict when hook is empty/None. These are the live internal
            containers, not copies.
        """
        if hook:
            return self._callbacks[hook]
        return self._callbacks

    def run(self, hook, *args, **kwargs):
        """
        Loop through the registered actions and fire all callbacks

        Args:
            hook    The name of the hook whose actions should be fired (required)
            args    Positional arguments forwarded unchanged to every callback
            kwargs  Keyword arguments forwarded unchanged to every callback

        Raises:
            AssertionError  If hook is not a known hook name
        """
        assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"

        # Actions fire in registration order; an exception raised by a callback propagates
        # to the caller and stops the remaining actions for this hook.
        for action in self._callbacks[hook]:
            action['callback'](*args, **kwargs)