Fix missing args in `tfe.stats.RunningVariance.from_shape`
Both the `event_ndims` and `name` arguments of the class method `tfp.experimental.stats.RunningVariance.from_shape` were missing (they are present on the parent `RunningCovariance.from_shape`), causing `TypeError: from_example() got an unexpected keyword argument 'event_ndims'`.
N.B. This fix sorts out the same `TypeError` when calling `RunningVariance.from_example()`.
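Below is a minimal sketch of the kind of change described above. It uses hypothetical stand-in classes rather than TFP itself. The constructor, the string dtype, and the use of `len(example)` as the shape are simplifications. The subclass's `from_shape` accepts the `event_ndims` and `name` keywords that the inherited `from_example` forwards, and it keeps `event_ndims` pinned to 0. Rejecting any other value is an assumption of this sketch, not necessarily what the PR's diff does.

```python
# Hypothetical stand-ins for tfp.experimental.stats.RunningCovariance and
# RunningVariance: they model only the classmethod signatures, not the statistics.


class RunningCovariance:
    def __init__(self, shape, dtype, event_ndims, name):
        self.shape = shape
        self.dtype = dtype
        self.event_ndims = event_ndims
        self.name = name

    @classmethod
    def from_shape(cls, shape=(), dtype="float32", event_ndims=None,
                   name="RunningCovariance"):
        return cls(shape, dtype, event_ndims, name)

    @classmethod
    def from_example(cls, example, event_ndims=None, name="RunningCovariance"):
        # `cls` is the class this was called on, so a subclass's from_shape
        # receives event_ndims and name as keyword arguments.
        return cls.from_shape(shape=(len(example),), dtype="float32",
                              event_ndims=event_ndims, name=name)


class RunningVariance(RunningCovariance):
    @classmethod
    def from_shape(cls, shape=(), dtype="float32", event_ndims=None,
                   name="RunningVariance"):
        # Accept the same keywords as the parent so the inherited from_example
        # works; event_ndims stays pinned to 0 (rejecting other values is an
        # assumption of this sketch).
        if event_ndims not in (None, 0):
            raise ValueError("RunningVariance only supports event_ndims=0")
        return super().from_shape(shape, dtype, event_ndims=0, name=name)


stats = RunningVariance.from_example([1.0, 1.0, 1.0, 1.0])
print(type(stats).__name__, stats.shape, stats.event_ndims)  # RunningVariance (4,) 0
```

Because the inherited `from_example` passes its own default `name` through, the stand-in ends up named `"RunningCovariance"` here. A real fix may prefer to override `from_example` as well.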
How did you get the original error? `RunningVariance` is expected to have `event_ndims=0`, because it treats each element as an independent (co)variance estimation problem.
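The following plain-Python illustration makes the reviewer's point concrete. It is not TFP code, and the sample values and divide-by-n (population) estimates are arbitrary choices for the example. With `event_ndims=0`, each element gets its own scalar variance. Treating the whole vector as one event (`event_ndims=1`) instead yields a full covariance matrix whose diagonal holds those same variances.

```python
# Plain-Python illustration (not TFP): four samples of a length-3 vector.
samples = [
    [1.0, 2.0, 0.0],
    [3.0, 2.0, 1.0],
    [5.0, 4.0, 2.0],
    [7.0, 4.0, 3.0],
]
n = len(samples)
dim = len(samples[0])
means = [sum(s[j] for s in samples) / n for j in range(dim)]

# event_ndims=0: each element is its own scalar problem -> one variance per element.
variances = [sum((s[j] - means[j]) ** 2 for s in samples) / n for j in range(dim)]

# event_ndims=1: the whole vector is one event -> a dim x dim covariance matrix.
covariance = [
    [sum((s[i] - means[i]) * (s[j] - means[j]) for s in samples) / n
     for j in range(dim)]
    for i in range(dim)
]

print(variances)                           # [5.0, 1.0, 1.25]
print(len(covariance), len(covariance[0]))  # 3 3
```

The off-diagonal entries (e.g. `covariance[0][1]`) are exactly what an elementwise variance estimator does not track.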
@SiegeLordEx The original problem arose with:

```python
import numpy as np
import tensorflow_probability as tfp

tfp.experimental.stats.RunningVariance.from_example(np.array([1., 1., 1., 1.]))
```
The fundamental issue is with the Python MRO, since `RunningVariance` is derived from `RunningCovariance`:

- `RunningVariance.from_example` resolves to the inherited `RunningCovariance.from_example`;
- `RunningCovariance.from_example` then calls `RunningVariance.from_shape`.
The issue arises because, inside `RunningCovariance.from_example`, `cls` refers to the class the method was called on (here the derived class `RunningVariance`), not, as I think the author assumed, the base class. `RunningCovariance.from_example` therefore assumes that `RunningVariance.from_shape` accepts the same arguments as `RunningCovariance.from_shape`.
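The call chain can be reproduced without TFP using hypothetical stand-in classes that mirror the pre-fix signatures. The tuple return value and the string dtype are simplifications for this sketch. Calling `from_example` on the subclass routes through the parent's method with `cls` bound to the subclass, so the narrower `from_shape` receives keywords it does not accept.

```python
# Hypothetical stand-ins mirroring the pre-fix signatures; not TFP itself.


class RunningCovariance:
    @classmethod
    def from_shape(cls, shape=(), dtype="float32", event_ndims=None,
                   name="RunningCovariance"):
        # Return a plain tuple so the call chain is easy to inspect.
        return (cls.__name__, shape, event_ndims)

    @classmethod
    def from_example(cls, example, event_ndims=None, name="RunningCovariance"):
        # Inherited by RunningVariance: here `cls` is the class the method
        # was called on, not necessarily RunningCovariance.
        print("from_example cls:", cls.__name__)
        return cls.from_shape(shape=(len(example),), dtype="float32",
                              event_ndims=event_ndims, name=name)


class RunningVariance(RunningCovariance):
    @classmethod
    def from_shape(cls, shape=(), dtype="float32"):
        # Narrower signature than the parent: no event_ndims or name.
        return super().from_shape(shape, dtype, event_ndims=0)


print(RunningCovariance.from_example([1.0, 1.0]))  # ('RunningCovariance', (2,), None)

caught = None
try:
    RunningVariance.from_example([1.0, 1.0, 1.0, 1.0])
except TypeError as err:
    caught = str(err)  # the message names the unexpected 'event_ndims' keyword
print("TypeError:", caught)
```

Calling `RunningVariance.from_shape` directly still works. Only the path through the inherited `from_example` breaks.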
It took me a while to brush up on Python's MRO semantics, but here's the MRE:
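```python
class Foo:
    @classmethod
    def bar(cls, baz, quz):
        print("Foo.bar cls:", cls)
        # `cls` is the class bar was called on, so this may dispatch to a subclass.
        cls.quux(baz, quz)

    @classmethod
    def quux(cls, corge, grault):
        print("Foo.quux cls:", cls)


class Garply(Foo):
    @classmethod
    def bar(cls, baz, quz):
        print("Garply.bar cls:", cls)
        super().bar(baz, quz)

    @classmethod
    def quux(cls, corge, grault):
        print("Garply.quux cls:", cls)
        super().quux(corge, grault)


# Run as a script, every line reports <class '__main__.Garply'>,
# including the ones printed from inside Foo's methods.
Garply.bar(1, 2)
```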
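If Garply overrides one of these classmethods with a signature incompatible with Foo's (e.g. if `Garply.quux` dropped the `grault` argument that `Foo.bar` passes), the call chain fails with a `TypeError`, because `Foo.bar` dispatches through `cls` to the subclass's method.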