lightweight_mmm

TypeError: where() got some positional-only arguments passed as keyword arguments: 'condition, x, y'

Open datainsight1 opened this issue 1 year ago • 15 comments

TypeError Traceback (most recent call last) Cell In[9], line 4 2 number_warmup=100 3 number_samples=100 ----> 4 mmm.fit( 5 media=media_data_train, 6 media_prior=costs, 7 target=target_train, 8 extra_features=extra_features_train, 9 number_warmup=number_warmup, 10 number_samples=number_samples, 11 seed=SEED)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/lightweight_mmm.py:363, in LightweightMMM.fit(self, media, media_prior, target, extra_features, degrees_seasonality, seasonality_frequency, weekday_seasonality, media_names, number_warmup, number_samples, number_chains, target_accept_prob, init_strategy, custom_priors, seed) 353 kernel = numpyro.infer.NUTS( 354 model=self._model_function, 355 target_accept_prob=target_accept_prob, 356 init_strategy=init_strategy) 358 mcmc = numpyro.infer.MCMC( 359 sampler=kernel, 360 num_warmup=number_warmup, 361 num_samples=number_samples, 362 num_chains=number_chains) --> 363 mcmc.run( 364 rng_key=jax.random.PRNGKey(seed), 365 media_data=jnp.array(media), 366 extra_features=extra_features, 367 target_data=jnp.array(target), 368 media_prior=jnp.array(media_prior), 369 degrees_seasonality=degrees_seasonality, 370 frequency=seasonality_frequency, 371 transform_function=self._model_transform_function, 372 weekday_seasonality=weekday_seasonality, 373 custom_priors=custom_priors) 375 self.custom_priors = custom_priors 376 if media_names is not None:

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/mcmc.py:638, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs) 636 else: 637 if self.chain_method == "sequential": --> 638 states, last_state = _laxmap(partial_map_fn, map_args) 639 elif self.chain_method == "parallel": 640 states, last_state = pmap(partial_map_fn)(map_args)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/mcmc.py:166, in _laxmap(f, xs) 164 for i in range(n): 165 x = jit(_get_value_from_index)(xs, i) --> 166 ys.append(f(x)) 168 return tree_map(lambda *args: jnp.stack(args), *ys)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/mcmc.py:416, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields) 414 # Check if _sample_fn is None, then we need to initialize the sampler. 415 if init_state is None or (getattr(self.sampler, "_sample_fn", None) is None): --> 416 new_init_state = self.sampler.init( 417 rng_key, 418 self.num_warmup, 419 init_params, 420 model_args=args, 421 model_kwargs=kwargs, 422 ) 423 init_state = new_init_state if init_state is None else init_state 424 sample_fn, postprocess_fn = self._get_cached_fns()

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/hmc.py:713, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs) 708 # vectorized 709 else: 710 rng_key, rng_key_init_model = jnp.swapaxes( 711 vmap(random.split)(rng_key), 0, 1 712 ) --> 713 init_params = self._init_state( 714 rng_key_init_model, model_args, model_kwargs, init_params 715 ) 716 if self._potential_fn and init_params is None: 717 raise ValueError( 718 "Valid value of init_params must be provided with" " potential_fn." 719 )

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/hmc.py:657, in HMC._init_state(self, rng_key, model_args, model_kwargs, init_params) 650 def _init_state(self, rng_key, model_args, model_kwargs, init_params): 651 if self._model is not None: 652 ( 653 new_init_params, 654 potential_fn, 655 postprocess_fn, 656 model_trace, --> 657 ) = initialize_model( 658 rng_key, 659 self._model, 660 dynamic_args=True, 661 init_strategy=self._init_strategy, 662 model_args=model_args, 663 model_kwargs=model_kwargs, 664 forward_mode_differentiation=self._forward_mode_differentiation, 665 ) 666 if init_params is None: 667 init_params = new_init_params

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/util.py:656, in initialize_model(rng_key, model, init_strategy, dynamic_args, model_args, model_kwargs, forward_mode_differentiation, validate_grad) 646 model_kwargs = {} if model_kwargs is None else model_kwargs 647 substituted_model = substitute( 648 seed(model, rng_key if is_prng_key(rng_key) else rng_key[0]), 649 substitute_fn=init_strategy, 650 ) 651 ( 652 inv_transforms, 653 replay_model, 654 has_enumerate_support, 655 model_trace, --> 656 ) = _get_model_transforms(substituted_model, model_args, model_kwargs) 657 # substitute param sites from model_trace to model so 658 # we don't need to generate again parameters of numpyro.module 659 model = substitute( 660 model, 661 data={ (...) 665 }, 666 )

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/infer/util.py:450, in _get_model_transforms(model, model_args, model_kwargs) 448 def _get_model_transforms(model, model_args=(), model_kwargs=None): 449 model_kwargs = {} if model_kwargs is None else model_kwargs --> 450 model_trace = trace(model).get_trace(*model_args, **model_kwargs) 451 inv_transforms = {} 452 # model code may need to be replayed in the presence of deterministic sites

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/handlers.py:171, in trace.get_trace(self, *args, **kwargs) 163 def get_trace(self, *args, **kwargs): 164 """ 165 Run the wrapped callable and return the recorded trace. 166 (...) 169 :return: OrderedDict containing the execution trace. 170 """ --> 171 self(*args, **kwargs) 172 return self.trace

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.call(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.call(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/numpyro/primitives.py:105, in Messenger.call(self, *args, **kwargs) 103 return self 104 with self: --> 105 return self.fn(*args, **kwargs)

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/models.py:385, in media_mix_model(media_data, target_data, media_prior, degrees_seasonality, frequency, transform_function, custom_priors, transform_kwargs, weekday_seasonality, extra_features) 380 elif transform_function == "carryover" and not transform_kwargs: 381 transform_kwargs = {"number_lags": 13 * 7} 383 media_transformed = numpyro.deterministic( 384 name="media_transformed", --> 385 value=transform_function(media_data, 386 custom_priors=custom_priors, 387 **transform_kwargs if transform_kwargs else {})) 388 seasonality = media_transforms.calculate_seasonality( 389 number_periods=data_size, 390 degrees=degrees_seasonality, 391 frequency=frequency, 392 gamma_seasonality=gamma_seasonality) 393 # For national model's case

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/models.py:280, in transform_carryover(media_data, custom_priors, number_lags) 278 if media_data.ndim == 3: 279 exponent = jnp.expand_dims(exponent, axis=-1) --> 280 return media_transforms.apply_exponent_safe(data=carryover, exponent=exponent)

[... skipping hidden 11 frame]

File ~/anaconda3/envs/pytorch_p310/lib/python3.10/site-packages/lightweight_mmm/media_transforms.py:189, in apply_exponent_safe(data, exponent) 172 @jax.jit 173 def apply_exponent_safe( 174 data: jnp.ndarray, 175 exponent: jnp.ndarray, 176 ) -> jnp.ndarray: 177 """Applies an exponent to given data in a gradient safe way. 178 179 More info on the double jnp.where can be found: (...) 187 The result of the exponent operation with the inputs provided. 188 """ --> 189 exponent_safe = jnp.where(condition=(data == 0), x=1, y=data) ** exponent 190 return jnp.where(condition=(data == 0), x=0, y=exponent_safe)

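TypeError: where() got some positional-only arguments passed as keyword arguments: 'condition, x, y'

The failing call is `jnp.where(condition=(data == 0), x=1, y=data)` in `media_transforms.apply_exponent_safe`. The installed JAX version treats the `condition`, `x`, and `y` parameters of `jnp.where` as positional-only, so passing them by keyword raises this error. The workarounds discussed below pin older jax/jaxlib versions.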

datainsight1 avatar May 10 '24 21:05 datainsight1

Hi, I had the same issue, have you resolved it?

ShirleyChai730 avatar May 14 '24 16:05 ShirleyChai730

Hi @ShirleyChai730: I haven't been able to resolve the above issue yet.

datainsight1 avatar May 14 '24 16:05 datainsight1

Hi @ShirleyChai730, this is probably due to an update of the jax library. I explicitly installed jax and jaxlib version 0.4.20, and it works fine on my local machine.

Munger245 avatar May 16 '24 01:05 Munger245

Thank you @Munger245 . It works.

datainsight1 avatar May 16 '24 18:05 datainsight1

> Hi @ShirleyChai730, this is probably due to an update of the jax library. I explicitly installed jax and jaxlib version 0.4.20, and it works fine on my local machine.

Thanks for pointing this out. I tried 0.4.20 and it still didn't work, but the older version 0.4.19 works.

ShirleyChai730 avatar May 16 '24 21:05 ShirleyChai730

@ShirleyChai730 I am also getting this error on a Mac M2. Which version of lightweight_mmm worked on your machine? Could you please share your requirements file here, along with your Python version?

rahulmisal27 avatar Jun 06 '24 15:06 rahulmisal27

@rahulmisal27 : I am using the latest version of lightweight mmm and it works.

datainsight1 avatar Jun 11 '24 17:06 datainsight1

I tried installing jax and jaxlib 0.4.20 and still get the same error. How did you fix it, @datainsight1?

bristobal avatar Jun 24 '24 21:06 bristobal

In a fresh Python 3.10 environment I needed to fix these versions to get things working:

jax==0.4.20 jaxlib==0.4.20 scipy==1.12.0

jamesvrt avatar Jun 27 '24 21:06 jamesvrt

Hi there! I'm running into the same error in a Python 3.11 environment. Has anyone figured out which version of jax is appropriate for this environment?

ezjsiwu avatar Jul 02 '24 19:07 ezjsiwu

Hi, I have same issue.

[7/13/24 edit] Thanks to @jamesvrt, it worked in my environment (a pipenv virtual environment with Python 3.10)!

8-u8 avatar Jul 11 '24 06:07 8-u8

I had the same error message, and installing jax and jaxlib version 0.4.20 did not work for me. I have since fixed it, and I'll list the steps I took below in case anyone has the same issue. First, I created a Python virtual environment using Anaconda with Python 3.10.14, as that's the latest version known to work according to lightweight_mmm/setup.py. Second, I checked the lightweight_mmm/requirements/requirements.txt file for the package versions listed there, which say that jax and jaxlib must be version 0.3.18 or higher. Apparently, that version does not even exist, so I used 0.4.18 instead. The final error I was facing was with the version of numpyro, so I again used the version listed in the requirements.txt file and installed version 0.9.2. The code that does all this is: `%pip install jax==0.4.18 jaxlib==0.4.18 numpyro==0.9.2`. Finally, I am using the latest version of lightweight_mmm, 0.1.9. You can check the versions of your packages by running `%pip show jax jaxlib numpyro lightweight_mmm`.

rora00 avatar Jul 20 '24 13:07 rora00

I encountered the same problem. My python version is 3.11.5. Finally, I followed the instructions of the two issues and installed the following versions:

seaborn==0.11.1
scipy==1.12.0
numpy==1.26.0
pyarrow==14.0.0
jax==0.4.18
jaxlib==0.4.18
numpyro==0.11.0
lightweight-mmm==0.1.9

This is useful for me!

lsypro avatar Aug 29 '24 11:08 lsypro

I am using Python 3.10. In my case, I also had to update the numpyro library to make it work. Updated packages are listed below:

scipy==1.12.0 jax==0.4.19 jaxlib==0.4.19 numpyro==0.13.2 lightweight-mmm==0.1.9

AdeuAndreu avatar Oct 23 '24 11:10 AdeuAndreu

> I encountered the same problem. My python version is 3.11.5. Finally, I followed the instructions of the two issues and installed the following versions:
>
> seaborn==0.11.1
> scipy==1.12.0
> numpy==1.26.0
> pyarrow==14.0.0
> jax==0.4.18
> jaxlib==0.4.18
> numpyro==0.11.0
> lightweight-mmm==0.1.9
>
> This is useful for me!

Thank you! This setup worked for me with Python 3.11.7.

anudanda avatar Oct 28 '24 02:10 anudanda