vladyorsh
The non-causal version processes the non-coarsened inputs twice, once for the off-diagonal blocks and once for the on-diagonal blocks:
```
qkvs = [(q, k, v, mask)]
...
# Coarsening
qkvs = ...
```
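For context, here is a minimal JAX sketch of how such a list of levels might be built. It is not the code from the snippet. It assumes that coarsening averages adjacent position pairs (the scheme used by H-Transformer-1D-style hierarchical attention) and that a coarse position stays valid if either of its two children is valid. `coarsen`, `coarsen_mask` and `build_levels` are hypothetical names. In this layout the non-coarsened input stays at index 0, so code that uses it for the on-diagonal blocks and also loops over every level for the off-diagonal blocks touches level 0 twice.

```python
import jax.numpy as jnp


def coarsen(x):
    """Average adjacent position pairs: (batch, length, dim) -> (batch, length // 2, dim)."""
    b, n, d = x.shape
    return x.reshape(b, n // 2, 2, d).mean(axis=2)


def coarsen_mask(mask):
    """Assumption: a coarse position is valid if either of its two children is valid."""
    b, n = mask.shape
    return mask.reshape(b, n // 2, 2).any(axis=2)


def build_levels(q, k, v, mask, num_levels):
    """Hypothetical helper: level 0 is the raw input, each later level halves the length."""
    qkvs = [(q, k, v, mask)]
    for _ in range(num_levels - 1):
        q, k, v, mask = coarsen(q), coarsen(k), coarsen(v), coarsen_mask(mask)
        qkvs.append((q, k, v, mask))
    return qkvs


q = k = v = jnp.arange(2 * 8 * 4, dtype=jnp.float32).reshape(2, 8, 4)
mask = jnp.ones((2, 8), dtype=bool)
levels = build_levels(q, k, v, mask, num_levels=3)
print([lvl[0].shape for lvl in levels])  # [(2, 8, 4), (2, 4, 4), (2, 2, 4)]
```

Averaging keeps the coarse values on the same scale as the fine ones; summing is an equally plausible choice if the original code does that instead.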
As far as I understand, the "ours" BERT4Rec and ALBERT4Rec models tested on ML-1M use the implementations located at `recommenders/dnn_sequential_recommender/models`. While the evaluation configs pass `intermediate_size`...
I'm trying to checkpoint Flax's `TrainState` in a distributed setup where each node has access to multiple devices:
```
def save_checkpoint(args, state, step):
    state = unreplicate(state)  # flax.jax_utils.unreplicate
    ...
```
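One common pattern for this setup, sketched below under stated assumptions, is to take the first local replica (what `flax.jax_utils.unreplicate` does), copy it to host memory, and write the file from process 0 only, since every host holds an identical replica. `unreplicate_tree`, `save_checkpoint` and `load_checkpoint` here are hypothetical helpers, and pickle is used only to keep the sketch dependency-free. It assumes the state is a pytree of arrays with a leading device axis, as produced by `pmap`.

```python
import os
import pickle
import tempfile

import jax
import jax.numpy as jnp


def unreplicate_tree(tree):
    """Keep only the first local device's copy (same idea as flax.jax_utils.unreplicate)."""
    return jax.tree_util.tree_map(lambda x: x[0], tree)


def save_checkpoint(ckpt_dir, replicated_state, step):
    """Hypothetical helper: write a host-side copy of the state from process 0 only."""
    if jax.process_index() != 0:
        # Every host holds an identical replica, so the other processes skip the write.
        return None
    # Copy device arrays to host NumPy arrays so the pickle holds no device buffers.
    host_state = jax.device_get(unreplicate_tree(replicated_state))
    os.makedirs(ckpt_dir, exist_ok=True)
    path = os.path.join(ckpt_dir, f"checkpoint_{step}.pkl")
    with open(path, "wb") as f:
        pickle.dump(host_state, f)
    return path


def load_checkpoint(path):
    with open(path, "rb") as f:
        return pickle.load(f)


# Usage on a single host: fake a pmap-style replicated state with a leading device axis.
state = {"params": {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}, "step": jnp.array(7)}
n_dev = jax.local_device_count()
replicated = jax.tree_util.tree_map(lambda x: jnp.stack([x] * n_dev), state)
with tempfile.TemporaryDirectory() as ckpt_dir:
    path = save_checkpoint(ckpt_dir, replicated, step=7)
    restored = load_checkpoint(path)
```

A real `TrainState` also carries non-array fields such as `apply_fn` and `tx`, which may not pickle, so in practice one would save only the array parts (for example `params` and `opt_state`) or use Flax's own checkpointing utilities instead of pickle.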