How do CUDA weights get transferred to PyTorch?

Open ZhenmingYu opened this issue 3 years ago • 2 comments

Dear @maljoras

Thanks for answering all my questions.

I was debugging my CUDA device when I noticed a weird behavior: my weights are not being updated.

count = 0
for epoch in range(100):
    pred = model(x)
    loss = mse_loss(pred, y)
    opt.zero_grad()  # clear gradients accumulated in the previous step
    loss.backward()
    opt.step()

    print('pred: {:.16f}'.format(pred[1][0]))
    weights = model.get_weights()[0][0][0]
    count = count + 1
    print('Epoch{}:weights: {:.20f}'.format(count, weights))

And I got:

Epoch0:weights: 0.00000000000000000000
pred: -0.0470588281750679
Epoch1:weights: 0.00000000000000000000
pred: 0.0941176563501358
Epoch2:weights: 0.00000000000000000000
pred: 0.0470588281750679
Epoch3:weights: 0.00000000000000000000
pred: 0.0470588281750679
Epoch4:weights: 0.00000000000000000000
pred: -0.0470588281750679
Epoch5:weights: 0.00000000000000000000
pred: -0.0470588281750679
.......
Epoch95:weights: 0.00000000000000000000
pred: 0.0000000000000000
Epoch96:weights: 0.00000000000000000000
pred: 0.0470588281750679
Epoch97:weights: 0.00000000000000000000
pred: 0.0470588281750679
Epoch98:weights: 0.00000000000000000000
pred: -0.0470588281750679
Epoch99:weights: 0.00000000000000000000
pred: -0.0941176563501358
Epoch100:weights: 0.00000000000000000000

So I went to my UpdateFunctor and tried to print out the weights there:

template <typename T> struct UpdateFunctorJARTv1b {

  __device__ __forceinline__ void operator()(
      T &apparent_weight,
      uint32_t n,
      uint32_t negative,
      float4 &par_4,
      float2 &par_2,
      // const float4 par_4,
      // const float2 par_2,
      T &persistent_weight,
      const T *global_pars,
      const int global_params_count,
      T noise_std_dw,
      curandState &local_state) {

    UNUSED(global_params_count); // fixed

    const T read_voltage        = global_pars[0];
    const T pulse_voltage_SET   = global_pars[1];
    const T pulse_voltage_RESET = global_pars[2];
    const T pulse_length        = global_pars[3];
......
    const T Ndiscmax_std        = global_pars[60];
    const T Ndiscmin_std        = global_pars[61];
    const T ldet_std            = global_pars[62];
    const T rdet_std            = global_pars[63];
    
    const T &weight_min_bound = par_4.x;                          // [0]
    T &device_specific_Ndiscmin_cuda = par_4.y; // [1]
    const T &weight_max_bound = par_4.z;                          // [2]
    T &device_specific_Ndiscmax_cuda = par_4.w; // [3]

    T &device_specific_ldet_cuda = par_2.x; // [0]
    T &device_specific_A_cuda = par_2.y; // [1]

    T &w = apparent_weight;
    T &Ndisc = persistent_weight;
    printf("w before update %.20f\n", apparent_weight);
    printf("Ndisc before update %.20e\n", persistent_weight);

    // n is larger than 0 in any case
    if (n == 1) {
      update_once(read_voltage, pulse_voltage_SET, pulse_voltage_RESET, pulse_length, base_time_step,
                                  ......
                                  Ndiscmax_std, Ndiscmin_std, ldet_std, rdet_std,
                                  local_state);
    } else {
      for (int i_updates = 0; i_updates < n; i_updates++) {
......

And I was surprised to see that my weights do get updated internally:

w before update 0.00000000000000000000
Ndisc before update 3.42099665000402799202e+25
w after update 0.22026008367538452148
Ndisc after update 5.76589326109617243575e+25
w before update 0.22026008367538452148
Ndisc before update 5.76589326109617243575e+25
w after update 0.31056439876556396484
Ndisc after update 7.28105624609006493326e+25
w before update 0.31056439876556396484
Ndisc before update 7.28105624609006493326e+25
w after update 0.35697650909423828125
Ndisc after update 8.27853072773657625971e+25
......
w before update 0.53407245874404907227
Ndisc before update 1.48834197251304102075e+26
w after update 0.53555470705032348633
Ndisc after update 1.49716643370930182751e+26
w before update 0.53555470705032348633
Ndisc before update 1.49716643370930182751e+26
w after update 0.53698539733886718750
Ndisc after update 1.50576796600343185433e+26

How could it behave this way? Is there an intermediate step that transfers the weights from CUDA back to PyTorch that I am missing? The CPU version of the exact same PyTorch setup works just fine, so as far as I understand the config parameters should be correct as well.
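For reference, the mental model I have of the readout path is just an explicit device-to-host copy, as in this generic CUDA sketch (my own illustration of the pattern, not the actual aihwkit code):

#include <cstdio>
#include <cuda_runtime.h>

// Generic sketch: the kernel updates the weights in device memory, and the
// host only sees the new values after an explicit device-to-host copy.
__global__ void update_weights(float *w, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    w[i] += 0.1f; // stand-in for update_once()
  }
}

int main() {
  const int n = 4;
  float host_w[n] = {0.0f, 0.0f, 0.0f, 0.0f};
  float *dev_w;
  cudaMalloc(&dev_w, n * sizeof(float));
  cudaMemcpy(dev_w, host_w, n * sizeof(float), cudaMemcpyHostToDevice);

  update_weights<<<1, n>>>(dev_w, n);

  // Without this copy the host buffer keeps its stale values, which is
  // exactly the symptom I see on the Python side.
  cudaMemcpy(host_w, dev_w, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("w[0] after update: %f\n", host_w[0]);
  cudaFree(dev_w);
  return 0;
}

If there is an equivalent synchronization step somewhere that my device has to trigger itself, that would be the step I am apparently missing.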

Thank you so much for your patience.

Zhenming

ZhenmingYu avatar Aug 12 '22 15:08 ZhenmingYu

Maybe the overloading of the persistent weights causes some trouble, but in principle, as long as you set the weights in update_once, they should be reflected in the Python code. Using the persistent-weight functionality for another purpose can be troublesome, though, since you need to carefully override all base methods that assume the persistent weight is something else. Maybe you forgot something there?
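For instance, the failure mode looks like this toy sketch (not the actual aihwkit class hierarchy, just the general pitfall):

// Toy C++ sketch (not the actual aihwkit classes): a base method that assumes
// persistent_weight_ mirrors the apparent weight will clobber whatever other
// state a subclass stores there, unless the subclass overrides it.
struct BaseDevice {
  float apparent_weight_ = 0.0f;
  float persistent_weight_ = 0.0f;

  // Base assumption: the persistent weight tracks the apparent weight.
  virtual void syncWeights() { persistent_weight_ = apparent_weight_; }
  virtual ~BaseDevice() = default;
};

struct CustomDevice : BaseDevice {
  // persistent_weight_ is re-purposed here to hold an internal state
  // variable. If syncWeights() were not overridden, the base version
  // would silently destroy that state on the next call.
  void syncWeights() override { /* keep the internal state intact */ }
};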

Also, having 60+ global parameters is maybe not a good idea; there is actually a limit on the number. Do you really need that many? They will all be transferred to the GPU and will slow down the calculation.
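To illustrate the transfer cost (a generic CUDA sketch, not how aihwkit actually passes its parameters): constants that never change during training can be packed into a struct and copied to the GPU once, instead of being shipped along with every update:

#include <cuda_runtime.h>

// Hypothetical packed parameter struct; the field names are made up.
struct DeviceParams {
  float read_voltage;
  float pulse_voltage_set;
  float pulse_voltage_reset;
  float pulse_length;
  // ... remaining model constants ...
};

__constant__ DeviceParams c_params; // cached, read-only on the device

__global__ void update_kernel(float *w, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    w[i] += c_params.pulse_length; // stand-in for the real update rule
  }
}

void upload_params(const DeviceParams &p) {
  // One host-to-device copy, amortized over all subsequent update calls.
  cudaMemcpyToSymbol(c_params, &p, sizeof(DeviceParams));
}

Note that constant memory is also limited (64 KB on most GPUs), which is another reason to keep the parameter count small.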

I am not sure whether you want to make your new device a public contribution to the code base, but if you do, you could open a PR and I could take a look at where there might be problems and give some suggestions. CUDA programming can be tricky.

maljoras avatar Aug 12 '22 15:08 maljoras

Dear @maljoras

We do intend to make the new device publicly available, but ideally after we publish it at a conference in October. (It's a shame that the window for adding authors has passed; I would have gladly added you for the contribution.)

I have created a mirror repo on GitHub and sent you an invite. Could you please help me check on it?

Maybe the overloading of the persistent weights causes some trouble, but in principle, as long as you set the weights in update_once, they should be reflected in the Python code. Using the persistent-weight functionality for another purpose can be troublesome, though, since you need to carefully override all base methods that assume the persistent weight is something else. Maybe you forgot something there?

I overrode these functions and made these arrays protected so that I can modify them in the child device class.

Also, having 60+ global parameters is maybe not a good idea; there is actually a limit on the number. Do you really need that many? They will all be transferred to the GPU and will slow down the calculation.

I also double-checked the global parameters for unnecessary ones and shrank the number to 53, but the rest are needed since this is a very complex model. Regarding this, I also see that in the PiecewiseStep device the global_params_count has to be a multiple of two. Is that a CUDA requirement, or is it just for the convenience of the PiecewiseStep look-up tables?
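My own guess is that it is about vectorized reads, i.e. padding the parameter array to a multiple of the float2/float4 chunk size so a chunked read never runs past the end, but that is only an assumption on my part:

#include <vector>

// Sketch of my assumption (not aihwkit's documented reason): if the kernel
// reads the global parameters in float2-sized chunks, the count has to be
// padded up to a multiple of the chunk size.
std::vector<float> pad_params(std::vector<float> params, size_t chunk = 2) {
  while (params.size() % chunk != 0) {
    params.push_back(0.0f); // zero padding, never read by the model
  }
  return params;
}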

I am not sure whether you want to make your new device a public contribution to the code base, but if you do, you could open a PR and I could take a look at where there might be problems and give some suggestions. CUDA programming can be tricky.

In the long term, I do wish to merge this into the master branch of aihwkit to promote ease of use, but there might be a few legal issues to sort out first. We are also going through the steps to establish a collaboration with the Neuromorphic Devices and Systems Group at IBM Research Zürich. We plan to use aihwkit with this new model to fit their RRAM devices and explore better programming algorithms at the edge. I hope the paperwork goes through and all of us can benefit from this eventually.

Thanks for all the help along the way!

Regards,
Zhenming

ZhenmingYu avatar Aug 15 '22 09:08 ZhenmingYu