
Add more examples

Open seanmor5 opened this issue 4 years ago • 7 comments

I'm willing to accept examples on different datasets and models that demonstrate different parts of the Axon API and show Axon's viability in the ecosystem. The TensorFlow guides are a great place to look for datasets and problems. If you're blocked on a specific issue, feel free to comment on the relevant issue with your use case :)

seanmor5 • Apr 08 '21 20:04

Hi @seanmor5, I am trying to create an example that predicts diabetes with Axon and Nx, but I am still learning how it works.

Currently I get the error below. My inputs are:

X = `#Nx.Tensor<f32[615][8]>`

Y = `#Nx.Tensor<s64[615]>`

This is the code I'm working on: https://gist.github.com/tiagodavi/a905abeaf4d1f92c21f9df9043d196fe

```
StreamExecutor device (0): Host, Default Version
** (ArgumentError) expected input shapes to be equal, got {615} != {615, 1}
    (axon 0.1.0-dev) lib/axon/shared.ex:22: anonymous fn/1 in Axon.Shared."__defn:assert_shape!__"/2
    (nx 0.1.0) lib/nx/defn/compiler.ex:114: Nx.Defn.Compiler.__remote__/4
    (axon 0.1.0-dev) lib/axon/losses.ex:122: Axon.Losses."__defn:binary_cross_entropy__"/3
    (axon 0.1.0-dev) lib/axon/loop.ex:325: anonymous fn/5 in Axon.Loop.train_step/3
    (nx 0.1.0) lib/nx/defn/grad.ex:20: Nx.Defn.Grad.transform/3
    (axon 0.1.0-dev) lib/axon/loop.ex:332: anonymous fn/4 in Axon.Loop.train_step/3
    (axon 0.1.0-dev) lib/axon/loop.ex:1135: anonymous fn/4 in Axon.Loop.build_batch_fn/2
    (nx 0.1.0) lib/nx/defn/compiler.ex:101: Nx.Defn.Compiler.runtime_fun/4
```

tiagodavi • Jan 20 '22 20:01

Hi @tiagodavi! Axon's implementation of BCE expects `y_true` to have a last dimension of size 1 (there's an explicit check that `y_true` and `y_pred` have equal shapes). If you add a new axis to your `y_true` with `Nx.new_axis(y, -1)`, the error should go away.
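A minimal sketch of that reshape with plain Nx, assuming a recent Nx release installed via `Mix.install` (the thread itself used Nx 0.1.0). The label tensor here is a hypothetical stand-in with the same `{615}` shape as Y; only the shape matters.

```elixir
# Assumes a recent Nx; the thread used Nx 0.1.0.
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical stand-in for the diabetes labels: 615 binary targets.
y = Nx.tensor(List.duplicate(0, 615), type: {:s, 64})
IO.inspect(Nx.shape(y))
# prints {615}

# Append a trailing axis so y_true matches the {615, 1} sigmoid output.
y = Nx.new_axis(y, -1)
IO.inspect(Nx.shape(y))
# prints {615, 1}
```

`Nx.reshape(y, {615, 1})` would give the same shape; `Nx.new_axis/2` has the advantage of not hard-coding the number of rows.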

We can probably relax the strict shape-equality constraint. Feel free to open a PR; otherwise I will open an issue to track it.

Also, you might be interested in trying out [Explorer](https://github.com/elixir-nx/explorer) for easier Nx/Axon interop with structured data :)

seanmor5 • Jan 20 '22 20:01

Thank you, @seanmor5.

I was able to fix the error, but the accuracy is quite bad. I am probably still doing something wrong.

```elixir
model =
  input
  |> Axon.dense(features, activation: :relu)
  |> Axon.dense(features, activation: :relu)
  |> Axon.dense(1, activation: :sigmoid)

trained_model =
  model
  |> Axon.Loop.trainer(:binary_cross_entropy, :adam)
  |> Axon.Loop.run([{x_train, y_train}], epochs: 10, compiler: EXLA)

# trying to interpret the sigmoid output here (threshold at 0.5)
result =
  model
  |> Axon.predict(trained_model, x_test, compiler: EXLA)
  |> Nx.map([type: {:s, 64}], fn x ->
    if x > 0.5, do: 1, else: 0
  end)

IO.inspect Axon.Metrics.accuracy(y_test, result)
```

```
#Nx.Tensor<
  f32
  0.3464052379131317
>
```
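One thing that may contribute to the low accuracy (an assumption, not a confirmed diagnosis): `Axon.Loop.run` receives `[{x_train, y_train}]`, a list with a single element, so each epoch is likely a single optimizer update and 10 epochs amount to roughly 10 updates. A common alternative is to split the data into mini-batches. The sketch below builds such batches with plain Nx, assuming a recent Nx release (`Nx.to_batched/3` is not in Nx 0.1.0); the tensors are hypothetical stand-ins with the thread's shapes, and the batch size of 32 is an arbitrary choice.

```elixir
# Assumes a recent Nx that provides Nx.to_batched/3.
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical stand-ins with the thread's shapes: 615 rows, 8 features.
x_train = Nx.iota({615, 8}, type: {:f, 32})
y_train = Nx.iota({615, 1}) |> Nx.remainder(2)

batch_size = 32

# Split both tensors along axis 0 into batches of 32 rows.
# leftover: :discard drops the final partial batch (615 = 19 * 32 + 7).
x_batches = Nx.to_batched(x_train, batch_size, leftover: :discard)
y_batches = Nx.to_batched(y_train, batch_size, leftover: :discard)

# Pair the batches as {x, y} tuples, the same element form as [{x_train, y_train}].
train_data = Stream.zip(x_batches, y_batches)

IO.inspect(Enum.count(train_data))
# prints 19
```

The resulting `train_data` could then be passed to the training loop in place of `[{x_train, y_train}]`, giving 19 updates per epoch instead of one.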

tiagodavi • Jan 20 '22 21:01

Axon's accuracy metric should do that thresholding for you. What do you get if you feed the result of `Axon.predict(model, trained_model, x_test, compiler: EXLA)` directly into `Axon.Metrics.accuracy`?
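As a rough illustration of what binary-accuracy thresholding amounts to, here is a sketch with plain Nx. It is not Axon's implementation of `Axon.Metrics.accuracy`, and the four predictions and labels are hypothetical. Using the element-wise `Nx.greater/2` also avoids the `Nx.map` plus `if x > 0.5` pattern, which relies on Elixir's `>` operator and may not compare tensor values as intended.

```elixir
# Assumes a recent Nx release.
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical sigmoid outputs and labels, both shaped {4, 1}.
y_pred = Nx.tensor([[0.9], [0.2], [0.6], [0.4]])
y_true = Nx.tensor([[1], [0], [0], [0]])

# Threshold probabilities at 0.5 element-wise; yields a tensor of 0s and 1s.
predicted_class = Nx.greater(y_pred, 0.5)

# Fraction of rows where the predicted class equals the label.
accuracy =
  predicted_class
  |> Nx.equal(y_true)
  |> Nx.mean()

IO.inspect(Nx.to_number(accuracy))
# prints 0.75
```

Three of the four thresholded predictions match their labels (only 0.6 vs. 0 is wrong), hence 0.75.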

seanmor5 • Jan 20 '22 22:01

Something like this?

```elixir
result =
  model
  |> Axon.predict(trained_model, x_test, compiler: EXLA)
  |> Axon.Metrics.accuracy(y_test)

IO.inspect result
```

```
#Nx.Tensor<
  f32
  0.013071895577013493
>
```
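Two things worth ruling out in the snippet above (assumptions, not a confirmed diagnosis). First, `Axon.Metrics.accuracy` takes `y_true` first and `y_pred` second, so piping the predictions into it passes them as `y_true`. Second, if `y_test` still has shape `{n}` while the predictions have shape `{n, 1}`, Nx broadcasting compares every prediction with every label instead of row by row. The sketch below shows the broadcasting effect with plain Nx on hypothetical three-row tensors, assuming a recent Nx release.

```elixir
# Assumes a recent Nx release.
Mix.install([{:nx, "~> 0.7"}])

# Hypothetical predictions with a trailing axis and labels without one.
y_pred = Nx.tensor([[0.9], [0.2], [0.6]])
y_test = Nx.tensor([1, 0, 0])

# Broadcasting {3, 1} against {3} yields {3, 3}: each prediction is
# compared with every label, not just its own row.
mismatched = Nx.equal(Nx.greater(y_pred, 0.5), y_test)
IO.inspect(Nx.shape(mismatched))
# prints {3, 3}

# Adding the trailing axis to the labels restores the row-wise comparison.
aligned = Nx.equal(Nx.greater(y_pred, 0.5), Nx.new_axis(y_test, -1))
IO.inspect(Nx.shape(aligned))
# prints {3, 1}
```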

I'll take this course to see if I can learn it better: https://grox.io/language/nx/course

tiagodavi • Jan 20 '22 22:01

It's probably a bug; please send me the gist!

seanmor5 • Jan 20 '22 23:01

Sure, this is the most up-to-date version: https://gist.github.com/tiagodavi/a905abeaf4d1f92c21f9df9043d196fe

tiagodavi • Jan 21 '22 12:01