Boyuan Chen
@shoyer Changing only that one line didn't work. Could you elaborate on where else to add remat?
To make viewing convenient, the whole loss function looks like this:
```
def loss_fn(network_fn):
    def batchify(fn):
        return jax.remat(lambda inputs: jnp.concatenate(
            [fn(inputs[i:i + batchify_size])
             for i in range(0, inputs.shape[0], batchify_size)], 0))
    z_vals...
```
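For reference, here is a minimal, self-contained sketch of one way to place remat: wrap the per-chunk function itself in `jax.remat` before concatenating, so each chunk's activations are recomputed during the backward pass instead of being stored. `render_chunk`, `chunk_size`, and the shapes are made-up placeholders, not the actual code above.
```
import jax
import jax.numpy as jnp

chunk_size = 1024

def render_chunk(params, chunk):
    # Stand-in for the real network evaluation on one chunk of inputs.
    return jnp.tanh(chunk @ params)

def batchify(fn):
    # Rematerialize each chunk: activations are recomputed in the backward
    # pass rather than kept in memory, trading compute for memory.
    chunked = jax.remat(fn)
    def apply(params, inputs):
        outs = [chunked(params, inputs[i:i + chunk_size])
                for i in range(0, inputs.shape[0], chunk_size)]
        return jnp.concatenate(outs, axis=0)
    return apply

def loss_fn(params, inputs, targets):
    preds = batchify(render_chunk)(params, inputs)
    return jnp.mean((preds - targets) ** 2)

# Gradients flow through each rematerialized chunk.
params = jnp.ones((8, 4))
inputs = jnp.ones((4096, 8))
targets = jnp.zeros((4096, 4))
grads = jax.grad(loss_fn)(params, inputs, targets)
```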
@cgarciae Thanks for your advice! The iterator is a blind spot for me, so I am not sure how fast it would perform. Do you think it will be as...
@cgarciae Another tricky part is that I don't put multiple images into one batch; instead, I need to divide each image into multiple batches. I guess I should make a...
@cgarciae Thanks for your advice! It actually improved the code's speed on CPU to a great extent! Below is a demo of how I used it: ``` target_data = tf.data.Dataset.from_tensor_slices(target_batched)...
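Since the demo above is cut off, here is a rough sketch of the kind of tf.data pipeline being described, with `target_batched` and `batchify_size` standing in for the real arrays; it is only illustrative, not the exact code from the comment.
```
import numpy as np
import tensorflow as tf

target_batched = np.random.rand(4096, 3).astype(np.float32)
batchify_size = 1024

target_data = (
    tf.data.Dataset.from_tensor_slices(target_batched)
    .batch(batchify_size)          # split one image's rays into chunks
    .prefetch(tf.data.AUTOTUNE)    # overlap host-side prep with compute
)

# Iterate as NumPy arrays so the chunks can be fed straight to a JAX step.
for chunk in target_data.as_numpy_iterator():
    print(chunk.shape)  # (1024, 3) per chunk
```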
> @BoyuanJackChen Same problem here, have you solved the stop word issue? thx!

Yes. If you are using the huggingface transformers library, then you can do the following: ``` from transformers...
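The snippet above is truncated, so here is one possible way to do it with transformers' `StoppingCriteria`; the model name, stop string, and prompt below are placeholders, not necessarily what the original comment used.
```
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          StoppingCriteria, StoppingCriteriaList)

class StopOnWords(StoppingCriteria):
    def __init__(self, stop_ids):
        self.stop_ids = stop_ids

    def __call__(self, input_ids, scores, **kwargs):
        # Stop once the most recent tokens match the stop sequence.
        if input_ids.shape[1] < len(self.stop_ids):
            return False
        return input_ids[0, -len(self.stop_ids):].tolist() == self.stop_ids

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

stop_ids = tokenizer.encode("\n\n", add_special_tokens=False)
criteria = StoppingCriteriaList([StopOnWords(stop_ids)])

inputs = tokenizer("def add(a, b):", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=64, stopping_criteria=criteria)
print(tokenizer.decode(out[0]))
```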
Thanks all! I have resolved the issue with the advice given.
Had the same error when running the text-to-image code provided in the README. I'm using CUDA 11.8, Python 3.10, Ubuntu 22.04, and an RTX 4090.