Serialisation of state for online detectors
This PR implements the functionality to save and load state for online detectors. At a given time step, the save_state method can be called to create a "checkpoint". This can later be loaded via the load_state method. At any time, the reset_state method can be used to reset the state back to the t=0 timestep.
Note: As a POC, functionality has only been added for LSDDDrift(..., backend='pytorch') for now.
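To make the checkpoint workflow above concrete, here is a minimal sketch using a toy detector. ToyOnlineDetector, its attributes and the JSON file format are assumptions made purely for illustration; they are not the alibi-detect implementation, only an analogue of the save_state / load_state / reset_state trio.

```python
import json
import tempfile
from pathlib import Path


class ToyOnlineDetector:
    """Hypothetical stand-in for an online detector; not the alibi-detect code."""

    def __init__(self, x_ref, window_size):
        self.x_ref = list(x_ref)
        self.window_size = window_size
        self.reset_state()

    def reset_state(self):
        # Return to the t=0 timestep: the test window is seeded from the reference data.
        self.t = 0
        self.test_window = self.x_ref[-self.window_size:]

    def predict(self, x_t):
        # Update the time-dependent attributes, then return a simple test statistic.
        self.t += 1
        self.test_window = self.test_window[1:] + [x_t]
        ref_mean = sum(self.x_ref) / len(self.x_ref)
        return sum(self.test_window) / self.window_size - ref_mean

    def save_state(self, path):
        # Only the attributes that change between predict calls are written.
        Path(path).write_text(json.dumps({"t": self.t, "test_window": self.test_window}))

    def load_state(self, path):
        state = json.loads(Path(path).read_text())
        self.t, self.test_window = state["t"], state["test_window"]


stream = [0.5, 1.5, 2.5, 3.5, 4.5, 5.5]
dd = ToyOnlineDetector(x_ref=[0.0, 1.0, 2.0, 3.0], window_size=2)

with tempfile.TemporaryDirectory() as tmp:
    checkpoint = Path(tmp) / "checkpoint.json"
    stats_first = []
    for t, x_t in enumerate(stream):
        stats_first.append(dd.predict(x_t))
        if t == 2:
            dd.save_state(checkpoint)  # checkpoint taken once dd.t == 3

    dd.load_state(checkpoint)  # rewind to the checkpoint
    stats_replay = [dd.predict(x_t) for x_t in stream[3:]]

dd.reset_state()  # back to t=0
```

Replaying the stream from the checkpoint reproduces the original statistics here because the toy detector has no random initialisation; the considerations further down discuss why that may not hold for LSDD.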
TODO's:
- [ ] Implement for remaining detectors and backends.
- [ ] Integrate with the save_detector and load_detector functions, to allow state to be saved and loaded when the detector itself is serialized/unserialized.
- [ ] Tests.
- [ ] Docs.
Outstanding considerations (specific to LSDD for now but maybe more widely applicable)
There might be an open question to resolve regarding what we define "state" to be. This PR currently considers it to be only the attributes that are updated in _update_state (self.t, self.test_window and self.k_xtc). In other words, "state" is defined as any attribute that is dependent on time (updated when a new instance x_t is given via score or predict).
However, there is already a notion of "state" introduced when we initialise a detector (or reinitialise it via the reset method). Here, in addition to the attributes already mentioned, we set self.ref_inds, self.c2s, and self.init_test_inds. This leads to considerations:
- Will there be confusion between the reset and reset_state methods, and do we need to change the docstrings or names?
- There is randomness involved in the initialisation of LSDDDrift (in _configure_ref_subset). It is likely that if the detector is instantiated later on, and load_state is used to restart from a checkpoint, predictions will still be different compared to those that were observed after save_state was called with the original detector. This would only be avoided if random seeds were set both times. With this in mind, do we want to change our definition of "state" to include self.ref_inds, self.c2s, and self.init_test_inds?
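The second consideration can be illustrated with a small sketch. The function configure_ref_subset below and the dictionary checkpoints are hypothetical stand-ins (using Python's random module rather than torch); they only show why a checkpoint holding just the time-dependent attributes cannot restore a randomly drawn split.

```python
import random


def configure_ref_subset(n, window_size, rng):
    """Hypothetical analogue of a random reference/test split."""
    perm = list(range(n))
    rng.shuffle(perm)
    return perm[:-window_size], perm[-window_size:]  # ref_inds, init_test_inds


# Narrow definition: a checkpoint only holds the time-dependent attribute t.
original_split = configure_ref_subset(20, 5, random.Random(0))
narrow_checkpoint = {"t": 7}

# A detector instantiated later with a different seed draws its own split, so
# restoring t alone does not restore the data the statistic is computed on.
later_split = configure_ref_subset(20, 5, random.Random(1))
split_differs = later_split != original_split

# Wider definition: the checkpoint also records the split itself.
wide_checkpoint = {
    "t": 7,
    "ref_inds": original_split[0],
    "init_test_inds": original_split[1],
}
restored_split = (wide_checkpoint["ref_inds"], wide_checkpoint["init_test_inds"])
```

Under the wider definition the restored split matches the original regardless of how the new detector was seeded, at the cost of a larger checkpoint.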
Example
See colab notebook
- Stateful here relates to whatever state (attributes) changes between prediction calls, not what state is set in the init of the detector, which would make all detectors stateful of course. To avoid possible confusion, one suggestion would be to name the methods reset_detector and reset_state?
- I believe that in general randomness should be handled outside of the detector/library, similar to e.g. PyTorch models. Isn't that randomness already eliminated though when just loading a saved detector? What am I missing here?
Nice idea with the name change, and yes I agree, "state" does not refer to any attributes set in init, as that would be "config" (with our definitions).
My only concern is that users might expect a detector to give the same predictions as the original when loaded from a "checkpoint" via save/load_state. That is currently not the case (unless seeds are set manually when the original and re-loaded detectors are instantiated). save/load_detector do not help with this.
Maybe the answer is just to make it clear in the docstrings though... as in any case statistically the detector's behaviour should be the same after the checkpoint even if the exact predictions are not the same?
Codecov Report
Merging #604 (828e486) into master (f0b57b4) will increase coverage by 0.17%. The diff coverage is 95.65%.
Additional details and impacted files
@@ Coverage Diff @@
## master #604 +/- ##
==========================================
+ Coverage 80.15% 80.32% +0.17%
==========================================
Files 133 137 +4
Lines 9177 9292 +115
==========================================
+ Hits 7356 7464 +108
- Misses 1821 1828 +7
| Flag | Coverage Δ | |
|---|---|---|
| macos-latest-3.10 | 76.87% <95.65%> (+0.21%) | :arrow_up: |
| ubuntu-latest-3.10 | 80.21% <95.65%> (+0.17%) | :arrow_up: |
| ubuntu-latest-3.7 | 80.11% <95.65%> (+0.17%) | :arrow_up: |
| ubuntu-latest-3.8 | 80.16% <95.65%> (+0.17%) | :arrow_up: |
| ubuntu-latest-3.9 | 80.16% <95.65%> (+0.17%) | :arrow_up: |
| windows-latest-3.9 | 76.80% <95.65%> (+0.21%) | :arrow_up: |
Flags with carried forward coverage won't be shown. Click here to find out more.
| Impacted Files | Coverage Δ | |
|---|---|---|
| alibi_detect/cd/base_online.py | 88.23% <82.60%> (-2.96%) | :arrow_down: |
| alibi_detect/cd/lsdd_online.py | 93.61% <85.71%> (+0.75%) | :arrow_up: |
| alibi_detect/cd/mmd_online.py | 94.33% <85.71%> (+0.58%) | :arrow_up: |
| alibi_detect/utils/state/state.py | 97.50% <97.50%> (ø) | |
| alibi_detect/base.py | 85.45% <100.00%> (+0.98%) | :arrow_up: |
| alibi_detect/cd/cvm_online.py | 75.63% <100.00%> (+1.52%) | :arrow_up: |
| alibi_detect/cd/fet_online.py | 88.97% <100.00%> (+0.17%) | :arrow_up: |
| alibi_detect/cd/pytorch/lsdd_online.py | 95.78% <100.00%> (+0.18%) | :arrow_up: |
| alibi_detect/cd/pytorch/mmd_online.py | 100.00% <100.00%> (ø) | |
| alibi_detect/cd/tensorflow/lsdd_online.py | 95.60% <100.00%> (+0.25%) | :arrow_up: |
| ... and 7 more | | |
~Will be conflicts until #618 is merged.~
Edit: Resolved.
Regarding the codecov report, the 4.32% decrease is not accurate. This is based on an old master commit where we were still counting tests in coverage.
@arnaudvl @ojcobb (and @jklaise/@mauicv) could do with your thoughts on this.
In the latest implementation, I have removed the new reset_state method. State can still be reset with the existing reset method. reset calls _initialise internally, which involves a random operation (for the LSDD detector at least). Because of this, we have a lack of determinism even if random seeds are manually set (explanation below).
Question: Shall we keep a reset_state method separate to reset, or is this just confusing for users? @arnaudvl previously suggested we could rename reset to reset_detector to avoid confusion... but it might be worth discussing what we think the use case of each method actually is anyway. If we don't care about determinism in the case of Example 1, we can remove reset_state I think...
As Example 2 shows, determinism in the case of saving/loading of a detector is not affected by this decision anyway...
Difference between reset and reset_state
(all examples are for LSDDDriftOnline)
reset
reset is an existing method, which calls _initialise:
def reset(self) -> None:
    "Resets the detector but does not reconfigure thresholds."
    self._initialise()
_initialise typically sets some attributes to zero, and calls _configure_ref_subset (this method contains random ops!):
def _initialise(self) -> None:
    self.t = 0  # corresponds to a test set of ref data
    self.test_stats = np.array([])  # type: ignore[var-annotated]
    self.drift_preds = np.array([])  # type: ignore[var-annotated]
    self._configure_ref_subset()
reset_state
reset_state was/is a new method that specifically only resets the core "stateful" attributes (those updated by _update_state):
def reset_state(self):
    """
    Reset the detector's state.
    """
    self.t = 0
    self.test_window = self.x_ref_eff[self.init_test_inds]
    self.k_xtc = self.kernel(self.test_window, self.kernel_centers)
This requires _initialise to have been run (so that self.test_window has been set), however it doesn't re-run _initialise, therefore doesn't involve random ops.
Examples
Example 1: Determinism when resetting
When resetting an instantiated detector and repeating predictions, test stats will be repeatable if reset_state is used. Example from test_lsdd_online_pt.py:
# Run for 50 time steps
test_stats_1 = []
for t, x_t in enumerate(x):
    preds = dd.predict(x_t)
    test_stats_1.append(preds['data']['test_stat'])
    if t == 20:
        dd.save_state(tmp_path)
# Clear state and repeat, check that same test_stats both times
dd.reset_state()
test_stats_2 = []
for t, x_t in enumerate(x):
    preds = dd.predict(x_t)
    test_stats_2.append(preds['data']['test_stat'])
np.testing.assert_array_equal(test_stats_1, test_stats_2)  # passes!
This fails if reset is used, even if torch.manual_seed() is run before instantiating the detector and before reset. Setting seeds externally does not help here because the number of torch.random operations run prior to reaching _initialise is different in both cases (fresh instantiation also involves random operations in _configure_kernel_centers and _configure_thresholds).
Example 2: Determinism when saving/loading
When saving and loading a stateful detector via save_detector(..., save_state=True) and load_detector, predictions following save_detector and load_detector will be consistent as long as seeds are set manually. Example from test_saving.py:
with fixed_seed(seed):
    dd = detector(X_ref, ert=100, window_size=10, backend='pytorch')
# Run for 50 time-steps
test_stats = []
for t, x_t in enumerate(X_h0[:50]):
    test_stats.append(dd.predict(x_t)['data']['test_stat'])
    if t == 20:
        # Save detector (with state)
        save_detector(dd, tmp_path, save_state=True)
# Check state/ dir created
state_path = dd.state_path if detector == CVMDriftOnline else dd._detector.state_path
assert state_path == tmp_path.joinpath('state')
assert state_path.is_dir()
# Load
with fixed_seed(seed):
    dd_new = load_detector(tmp_path)
# Check attributes and compare predictions at t=21
assert dd_new.t == 21
np.testing.assert_array_equal(dd_new.predict(X_h0[21])['data']['test_stat'], test_stats[21])
This use case does not depend on the design of reset/reset_state etc.
Additional side-note: this issue with setting random seeds not giving deterministic behaviour for a given operation (in this case the torch.randperm called by reset -> _initialise -> _configure_ref_subset) is something we've run into a few times before. The issue is that even if a given random state is set externally (i.e. torch.manual_seed), the random state will be different by the time we get to the torch.randperm if a different number of random operations are called before we get there, since each random op advances the random state. This is made even more difficult for something like LSDDDriftOnline, since we have random ops in a while loop, so we cannot know in advance how many random ops will be called.
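The effect described above is not specific to torch. As a rough illustration, the sketch below uses Python's built-in random module as a stand-in for torch.manual_seed and torch.randperm; draw_split and n_prior_ops are hypothetical names introduced only for this example.

```python
import random


def draw_split(n_prior_ops, seed=0):
    random.seed(seed)             # seed set "externally", as a user would
    for _ in range(n_prior_ops):  # random ops that happen to run before the split
        random.random()
    perm = list(range(50))
    random.shuffle(perm)          # stands in for torch.randperm
    return perm


# Same seed and same number of prior random ops: identical permutation.
same_history = draw_split(3) == draw_split(3)

# Same seed but one extra prior random op: the generator has advanced,
# so the permutation differs even though the seed was "set".
different_history = draw_split(3) == draw_split(4)
```

Here same_history is True and different_history is False, which mirrors why seeding before instantiation and before reset does not make the two code paths agree.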
The only solution I can think of for this is a scikit-learn style approach, where we accept random_state as a kwarg, and then use self.random_state in random ops we want to be deterministic. We would have to be careful with ops like the torch.randperm in _configure_ref_subset, since we do need this to possess a degree of randomness over each iteration in the while loop...
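A minimal sketch of what the scikit-learn style approach could look like, assuming a private generator owned by the object. SeededSplitter is a hypothetical helper and uses Python's random.Random rather than a torch generator; it is not a proposal for the actual detector signature.

```python
import random


class SeededSplitter:
    """Hypothetical helper; `random_state` mirrors the scikit-learn naming."""

    def __init__(self, n, window_size, random_state=None):
        self.n = n
        self.window_size = window_size
        # A private generator is unaffected by other random ops in the process.
        self.rng = random.Random(random_state)

    def split(self):
        perm = list(range(self.n))
        self.rng.shuffle(perm)
        return perm[:-self.window_size], perm[-self.window_size:]


random.seed(123)
a = SeededSplitter(50, 10, random_state=0)
first_a, second_a = a.split(), a.split()

random.random()  # unrelated global random op; does not touch the private generator
b = SeededSplitter(50, 10, random_state=0)
first_b = b.split()
```

Two objects built with the same random_state give the same first split regardless of global random activity, while successive calls on one object still differ, which is the property the while loop needs.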
@arnaudvl @ojcobb a possible alternative strategy to make reset deterministic is to rework the _initialise methods to ensure they are deterministic. For example, for LSDDDriftOnline, this would involve avoiding the while loop to find a new self.init_test_inds in _configure_ref_subset if init_test_inds already exists:
def _configure_ref_subset(self):
    """
    Configure reference subset. If already configured, the stateful attributes `test_window` and `k_xtc` are
    reset without re-configuring a new reference subset.
    """
    etw_size = 2 * self.window_size - 1  # etw = extended test window
    nkc_size = self.n - self.n_kernel_centers  # nkc = non-kernel-centers
    rw_size = nkc_size - etw_size  # rw = ref-window
    # Check if already configured, we will re-initialise stateful attributes w/o searching for new ref split if so
    configure_ref = self.init_test_inds is None
    if configure_ref:
        # Make split and ensure it doesn't cause an initial detection
        lsdd_init = None
        while lsdd_init is None or lsdd_init >= self.get_threshold(0):
            # Make split
            perm = torch.randperm(nkc_size)
            self.ref_inds, self.init_test_inds = perm[:rw_size], perm[-self.window_size:]
            self.test_window = self.x_ref_eff[self.init_test_inds]
            # Compute initial lsdd to check for initial detection
            self.c2s = self.k_xc[self.ref_inds].mean(0)  # (below Eqn 21)
            self.k_xtc = self.kernel(self.test_window, self.kernel_centers)
            h_init = self.c2s - self.k_xtc.mean(0)  # (Eqn 21)
            lsdd_init = h_init[None, :] @ self.H_lam_inv @ h_init[:, None]  # (Eqn 11)
    else:
        # Reset stateful attributes using existing split
        self.test_window = self.x_ref_eff[self.init_test_inds]
        self.k_xtc = self.kernel(self.test_window, self.kernel_centers)
This seems like a reasonable compromise to me? However, the additional duplication/complexity is unnecessary if we truly don't care about repeatable predictions post-reset?
Offline vs online state
Based on offline discussions, our current view is that we have a distinction between offline state (stateful attributes such as self.thresholds computed at instantiation) and online state (stateful attributes updated by _update_state when score/predict are called). In the future we wish to save/load offline state within save_detector/load_detector, to avoid repeating expensive instantiation procedures such as configuring thresholds. This PR only deals with online state.
The previous issue was that, for the multivariate detectors (LSDD and MMD), online state was initialised in a non-deterministic way within configure_ref_subset, which was called in _initialise, itself called in reset. This meant that, unlike univariate detectors, the reset method of multivariate detectors was not deterministic. Also, the reset method didn't just reset online state, but also reset offline state such as self.test_window etc.
The latest commits (09c4305 onwards) move the setting of online state into _initialise_state, and configure_ref_subset is taken out of reset. There is now a clear distinction; load_state, save_state, _initialise_state and reset deal with online state. Resetting offline state can be achieved by re-instantiating a fresh detector with a different random seed.
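The separation described above can be sketched with a toy class. Everything here (ToyDetector, the threshold rule, the use of random.sample) is an assumption for illustration only; the point is just that reset touches online state and leaves whatever was computed at instantiation alone.

```python
import random


class ToyDetector:
    """Hypothetical sketch of the offline/online split; not the library code."""

    def __init__(self, x_ref, window_size, seed=None):
        self.x_ref = list(x_ref)
        self.window_size = window_size
        # Offline state: computed once at instantiation (may be random or expensive).
        rng = random.Random(seed)
        self.init_test_inds = rng.sample(range(len(self.x_ref)), window_size)
        self.threshold = max(self.x_ref)
        self._initialise_state()

    def _initialise_state(self):
        # Online state: derived deterministically from the offline state.
        self.t = 0
        self.test_window = [self.x_ref[i] for i in self.init_test_inds]

    def reset(self):
        # Only online state is reset; the random split is left untouched.
        self._initialise_state()

    def predict(self, x_t):
        self.t += 1
        self.test_window = self.test_window[1:] + [x_t]
        return sum(self.test_window) / self.window_size


stream = [0.3, 0.9, 0.1, 0.7]
dd = ToyDetector(x_ref=[0.1 * i for i in range(20)], window_size=4)
inds_before = list(dd.init_test_inds)
run_1 = [dd.predict(x) for x in stream]
dd.reset()
run_2 = [dd.predict(x) for x in stream]
```

Because reset involves no random ops, run_1 and run_2 are identical without any seed being set; a different split requires constructing a new ToyDetector.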
@ojcobb could I get your thoughts on these changes, please?
Another question...
Thinking ahead to when we handle offline and online state, do we need two reset methods? The latest commits change reset so that it resets online state only. Do we instead want to reserve reset for resetting offline state, and rename the current reset to reset_state (since the other methods for handling online state are all named _state)? I am wondering if this is actually necessary since offline state can be reset by simply reinstantiating the detector (with the caveat that this depends on our definition of offline state, i.e. all attributes, or just results of configure_thresholds etc).
I will take a closer look at the code next week but the overall approach you've described sounds sensible to me. I agree that there's no need for a method to reset offline state when this would just correspond to reinitialising the detector.
Cheers! and yes no rush.
Beyond what we've discussed around the preemptive renaming of reset to reset_state it all looks good to me! Good work on the additional docstrings too!
Offline, @ojcobb asked why we don't have a separate reset and reset_state method, which is something mentioned in the notion doc.
This is because the second reset method will be added in the follow-up PR for handling "offline state". Since the other online state methods are all named ..._state, the reset method for "online state" (reset in this PR) is named reset_state in the follow-up PR, and a new reset method is added to reset the "offline state". We need to decide if we want to make this change now or leave it to the follow-up. Making the change now would mean renaming reset to reset_state (in examples too), and perhaps forwarding reset to reset_state with a deprecation warning that its purpose will soon change...
Alternatively, we could leave reset as is, and add a new reset_all (or reinitialize?) or something similar for the offline state.
@jklaise @mauicv
Thanks @ojcobb !
Perhaps a silly question about the general functionality? Is the intention that the user is saving state to the detector save directory or elsewhere?
from alibi_detect.cd import LSDDDriftOnline
from alibi_detect.saving import save_detector, load_detector
dd = load_detector(filepath)
dd.predict(x)
dd.save_state('check_pt')
If the user then later runs:
dd = load_detector(filepath)
That will not load the saved state, right?
The 3.7 build failure is occurring before we even reach the state handling functionality. It appears to be a not-seen-before issue with the MMDDriftOnlineTF detector itself. See https://github.com/SeldonIO/alibi-detect/issues/706.
Not silly at all. I've gone back and forth over this functionality a fair bit. The current behaviour supports two distinct use cases:
- "Runtime checkpointing": Users can give whatever filepath they like to save_state and load_state to save/load checkpoints within their runtime. For example:
    cd = CVMDriftOnline(x_ref, ert, window_sizes)  # Instantiate detector at t=0
    cd.predict(x_1)  # t=1
    cd.save_state('checkpoint_t1')  # Save state at t=1
    cd.predict(x_2)  # t=2
    cd.predict(x_3)  # t=3
    cd.predict(x_4)  # t=4
    cd.save_state('checkpoint_t4')  # Save state at t=4
    # Go back to t=1 for whatever reason...
    cd.load_state('checkpoint_t1')
- Serialising detectors: Calling save_detector(cd, filepath, save_state=True) will save the state in filepath/state/, along with the serialised detector itself. load_detector(filepath) will then load state if a state/ dir exists inside filepath.
I guess these two use cases could be mixed by manually doing cd.save_state(filepath + '/state') to add state to an already serialised detector, but that isn't really an intended use case...
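The directory convention for the second use case can be sketched as follows. save_toy_detector and load_toy_detector are hypothetical analogues written for this example, and the config.json / state.json file names are assumptions; only the "state lives in a state/ subdirectory and is loaded if present" behaviour is taken from the description above.

```python
import json
import tempfile
from pathlib import Path


def save_toy_detector(config, state, filepath, save_state=False):
    """Hypothetical analogue of save_detector: config at top level, state in state/."""
    filepath = Path(filepath)
    filepath.mkdir(parents=True, exist_ok=True)
    (filepath / "config.json").write_text(json.dumps(config))
    if save_state:
        state_dir = filepath / "state"
        state_dir.mkdir(exist_ok=True)
        (state_dir / "state.json").write_text(json.dumps(state))


def load_toy_detector(filepath):
    """Hypothetical analogue of load_detector: state is loaded only if state/ exists."""
    filepath = Path(filepath)
    config = json.loads((filepath / "config.json").read_text())
    state_file = filepath / "state" / "state.json"
    # Fall back to a fresh t=0 state when no checkpoint was saved.
    state = json.loads(state_file.read_text()) if state_file.exists() else {"t": 0}
    return config, state


with tempfile.TemporaryDirectory() as tmp:
    save_toy_detector({"ert": 100}, {"t": 21}, Path(tmp) / "with_state", save_state=True)
    save_toy_detector({"ert": 100}, {"t": 21}, Path(tmp) / "no_state")
    _, state_a = load_toy_detector(Path(tmp) / "with_state")
    _, state_b = load_toy_detector(Path(tmp) / "no_state")
```

In this sketch a checkpoint written elsewhere (the runtime checkpointing case) is simply invisible to the loader, which matches the answer to the question about 'check_pt'.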
Is there a need to have a method to reset "offline" state at all? Should we just tell people to re-initialize the detector? Or would this be a convenience to avoid expensive re-initialization in case the constructor runs a lengthy computation?
Tagging @ojcobb for this discussion too. I'm not strongly opinionated either way. I can see a reset (offline state) method being useful in the case where a detector has been instantiated with lots of arguments, i.e. to save the user needing to pack all their args/kwargs into a dict so they can easily reinstantiate the detector later.
Or would this be a convenience to avoid expensive re-initialization in case the constructor runs a lengthy computation?
I don't think the reset (offline state) method would be for this. Our definition of "offline state" is essentially expensive-to-compute attributes that we want to save in order to avoid expensive computation at load time. In the 2nd PR these expensive operations are encapsulated in _initialise_offline_state. The offline reset method would pretty much just call _initialise_offline_state i.e. it would run the lengthy computations!
@ascillitoe in that case I don't see a particular reason to introduce another "reset" method, especially since having 2 would very likely lead to confusion!
Regarding possible renaming of reset to reset_state, I don't have strong opinions and can see pros for either convention (e.g. reset identifying an online detector that has some state - this means we can't repurpose reset for anything else in the future; reset_state - being more consistent with the other public methods load_state and save_state and being explicit that resetting resets some state - also leaves the door open to using the name reset for something else (or the same thing!) in the future).
e.g. reset identifying an online detector that has some state - this means we can't repurpose reset for anything else in the future
This is a really important point I think. Since we are using _state to identify "online state" functionality, I'm now thinking it makes a lot of sense to rename reset to reset_state. The final thing to think about with this would be how to deprecate reset. For example, point reset to reset_state for now and give a deprecation warning saying its functionality might change in future? Or simply just rename and list as a breaking change?
Maybe I'm missing something but I thought reset didn't exist at all until this PR? Agree that using reset_state likely makes most sense.
It did exist:
# Before this PR
def reset(self) -> None:
    "Resets the detector but does not reconfigure thresholds."
    self._initialise()

# After this PR
def reset(self) -> None:
    """
    Resets the detector to its initial state (`t=0`). This does not include reconfiguring thresholds.
    """
    self._initialise_state()
But its behavior has now changed. Previously, self._initialise called in reset did more than just initialising online state, so reset essentially reset all "online state" as well as some "offline state" (essentially everything except thresholds). This meant there was some potentially undesirable behavior with reset not being deterministic (see https://github.com/SeldonIO/alibi-detect/pull/604#issuecomment-1352896731). With the new behavior, reset only resets "online state" (takes us back to t=0).
Ah ok. Shall we introduce reset_state and deprecate the currently existing reset then?
EDIT: Alternatively, if reset was already used in a very similar way (but with extra things and some undesirable behaviour), it's probably also fine to stick with that, but we would not be able to repurpose that name in the future (without another deprecation cycle pushed further into the future).
I've done this in aa6b0d3, but kept reset as a link to reset_state with a DeprecationWarning. Can remove if you don't think this is necessary...
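For readers unfamiliar with the pattern, a forwarding method with a DeprecationWarning can look roughly like this. The ToyDetector class and the warning text are hypothetical and are not copied from aa6b0d3.

```python
import warnings


class ToyDetector:
    """Hypothetical class used only to illustrate the forwarding pattern."""

    def __init__(self):
        self.t = 0

    def reset_state(self):
        self.t = 0

    def reset(self):
        # Old name kept as a thin alias so existing user code keeps working.
        warnings.warn(
            "`reset` is deprecated; use `reset_state` instead.",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller, not this method
        )
        self.reset_state()


dd = ToyDetector()
dd.t = 5
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # DeprecationWarning is often hidden by default
    dd.reset()
```

Keeping the alias costs a few lines and leaves room to give reset a different meaning after a deprecation cycle.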
@ascillitoe Is this file move ok? I thought conftest.py should always be inside a folder named tests? alibi_detect/saving/tests/conftest.py → alibi_detect/conftest.py Doesn't this make conftest a public module of alibi-detect which we wouldn't want?
Mmn good point, thinking again it doesn't seem ideal to have it outside of tests/. I moved it so that we didn't have to duplicate the seed fixture in multiple places. I couldn't think of a better way to have a global conftest that is shared across all tests. Do you know of a way?
Re it becoming a public module I suspect you're mostly right. We do __all__ = ["ad", "cd", "models", "od", "utils", "saving"] in alibi_detect.__init__ so it won't be exposed by dir(alibi_detect) at least. However, seed could technically be imported by:
from alibi_detect.saving.conftest import seed
Weirdly though, with our alibi-detect v0.10.4, the following also works! (think I need to undo the file move, but maybe we also need to double check this separately)
from alibi_detect.saving.tests.conftest import seed
seed(0)
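As a side check on why such imports work: in standard Python, __all__ only governs `from package import *`, so it does not block an explicit import of a submodule. The sketch below builds a throwaway package on disk (toy_pkg is a hypothetical name, as are its contents) to show this in isolation from alibi-detect.

```python
import importlib
import sys
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    # Build a tiny package whose __all__ does not mention conftest.
    pkg = Path(tmp) / "toy_pkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("__all__ = ['utils']\n")
    (pkg / "utils.py").write_text("")
    (pkg / "conftest.py").write_text("def seed(n):\n    return n\n")

    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        toy_pkg = importlib.import_module("toy_pkg")
        # Not an attribute yet: the submodule has simply not been imported.
        listed_before = "conftest" in dir(toy_pkg)
        # An explicit import succeeds regardless of __all__ ...
        conftest = importlib.import_module("toy_pkg.conftest")
        value = conftest.seed(0)
        # ... and binds the submodule as an attribute of the package.
        listed_after = "conftest" in dir(toy_pkg)
    finally:
        sys.path.remove(tmp)
```

So any module shipped inside the installed package tree is importable by its dotted path; keeping it out of the public surface means keeping it out of the distributed package or accepting that it is reachable.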
d190589 removes save_state from save_detector. The new logic is to always save state if self.t > 0. If this is not desired, .reset() can be called prior to using save_detector.
LGTM! Should we add some documentation somewhere that save will save the state by default, and if that's not desired one should call reset_state first? Or is it going to be too confusing for now?
Thanks! Will add this documentation now, just realised I added it in https://github.com/SeldonIO/alibi-detect/pull/628 instead of here. Doh!