Static types for the Keras API
The recommended API for use with TensorFlow is now Keras, so we should have static type definitions for it.
Hi @shadaj, I'm interested in contributing Keras API facades. Were there any big changes since the last scalapy-tensorflow update, or can I do it in a similar way to the already defined facades?
Hi @Avasil, that's awesome to hear! Not really; you should be able to add facades just like the existing ones.
@shadaj Do you have any advice on figuring out the proper type to return in a facade?
I want to try the MNIST example with static types, but I'm struggling with:
import me.shadaj.scalapy.py
import me.shadaj.scalapy.numpy.NDArray
@py.native trait Mnist extends py.Object {
  def load_data(): ((NDArray[Long], NDArray[Long]), (NDArray[Long], NDArray[Long])) = py.native
}
// somewhere else
py.module("keras.datasets.mnist").as[Mnist].load_data()
I also tried py.module("tensorflow.keras.datasets.mnist").as[Mnist].load_data()
[error] Exception in thread "main" scala.MatchError: jep.NDArray@a64a2421 (of class jep.NDArray)
[error] at me.shadaj.scalapy.py.JepPyValue.getLong(JepInterpreter.scala:193)
[error] at me.shadaj.scalapy.py.JepPyValue.getLong$(JepInterpreter.scala:193)
[error] at me.shadaj.scalapy.py.JepJavaPyValue.getLong(JepInterpreter.scala:258)
[error] at me.shadaj.scalapy.py.Reader$$anon$6.read(Reader.scala:42)
[error] at me.shadaj.scalapy.py.Reader$$anon$6.read(Reader.scala:41)
[error] at me.shadaj.scalapy.py.Any.as(Any.scala:15)
[error] at me.shadaj.scalapy.py.Any.as$(Any.scala:15)
[error] at me.shadaj.scalapy.py.FacadeValueProvider.as(Facades.scala:5)
[error] at me.shadaj.scalapy.numpy.NDArray.apply(NDArray.scala:31)
[error] at me.shadaj.scalapy.numpy.NDArray.$anonfun$iterator$1(NDArray.scala:33)
[error] at me.shadaj.scalapy.numpy.NDArray.$anonfun$iterator$1$adapted(NDArray.scala:33)
[error] at scala.collection.Iterator$$anon$10.next(Iterator.scala:459)
[error] at scala.collection.Iterator.foreach(Iterator.scala:941)
[error] at scala.collection.Iterator.foreach$(Iterator.scala:941)
[error] at scala.collection.AbstractIterator.foreach(Iterator.scala:1429)
[error] at scala.collection.IterableLike.foreach(IterableLike.scala:74)
[error] at scala.collection.IterableLike.foreach$(IterableLike.scala:73)
[error] at me.shadaj.scalapy.numpy.NDArray.foreach(NDArray.scala:6)
[error] at scala.collection.TraversableOnce.addString(TraversableOnce.scala:362)
[error] at scala.collection.TraversableOnce.addString$(TraversableOnce.scala:358)
[error] at me.shadaj.scalapy.numpy.NDArray.addString(NDArray.scala:6)
[error] at scala.collection.TraversableOnce.mkString(TraversableOnce.scala:328)
[error] at scala.collection.TraversableOnce.mkString$(TraversableOnce.scala:327)
[error] at me.shadaj.scalapy.numpy.NDArray.mkString(NDArray.scala:6)
[error] at scala.collection.TraversableLike.toString(TraversableLike.scala:688)
[error] at scala.collection.TraversableLike.toString$(TraversableLike.scala:688)
[error] at scala.collection.SeqLike.toString(SeqLike.scala:693)
[error] at scala.collection.SeqLike.toString$(SeqLike.scala:693)
[error] at me.shadaj.scalapy.numpy.NDArray.toString(NDArray.scala:6)
[error] at java.base/java.lang.String.valueOf(String.java:2951)
[error] at java.base/java.lang.StringBuilder.append(StringBuilder.java:168)
[error] at scala.Tuple2.toString(Tuple2.scala:27)
[error] at java.base/java.lang.String.valueOf(String.java:2951)
[error] at java.base/java.io.PrintStream.println(PrintStream.java:897)
[error] at scala.Console$.println(Console.scala:271)
[error] at scala.Predef$.println(Predef.scala:397)
[error] at me.shadaj.scalapy.tensorflow.Example$.delayedEndpoint$me$shadaj$scalapy$tensorflow$Example$1(Example.scala:12)
[error] at me.shadaj.scalapy.tensorflow.Example$delayedInit$body.apply(Example.scala:7)
[error] at scala.Function0.apply$mcV$sp(Function0.scala:39)
[error] at scala.Function0.apply$mcV$sp$(Function0.scala:39)
[error] at scala.runtime.AbstractFunction0.apply$mcV$sp(AbstractFunction0.scala:17)
[error] at scala.App.$anonfun$main$1$adapted(App.scala:80)
[error] at scala.collection.immutable.List.foreach(List.scala:392)
[error] at scala.App.main(App.scala:80)
[error] at scala.App.main$(App.scala:78)
[error] at me.shadaj.scalapy.tensorflow.Example$.main(Example.scala:7)
[error] at me.shadaj.scalapy.tensorflow.Example.main(Example.scala)
Basically I'm going at it a bit blindly and would appreciate any tips :D
BTW, should Keras be referenced from within TensorFlow, or should it be top-level?
Ah, this is a bug in the old Jep backend for ScalaPy that was fixed in 0.3.0+17-2bfe86de. For now, you should be able to just upgrade to that version. There will likely be a full release soon, so we can upgrade to it before merging.
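As a rough sketch, the upgrade in build.sbt might look like the line below. The version string is the one mentioned above; the organization and artifact name (me.shadaj, scalapy-core) are assumptions, so check them against how your build already pulls in ScalaPy. A pre-release version like this one may also need an extra resolver.

```scala
// build.sbt (sketch)
// Assumption: ScalaPy core is published as "me.shadaj" %% "scalapy-core".
// Only the version string comes from this thread.
libraryDependencies += "me.shadaj" %% "scalapy-core" % "0.3.0+17-2bfe86de"
```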
I think Keras should probably be top-level, since AFAIK most developers import it separately from regular TensorFlow.