tensorwatch
Logging PyTorch scalars
With the latest tensorwatch and PyTorch 1.4, I expected the following code, based on the example in the README, to create a plot of the loss:
import tensorwatch as tw
import time
import torch
w = tw.Watcher(filename='test.log')
s = w.create_stream(name='metric1')
w.make_notebook()
# loss is a pytorch scalar with value 2
loss = torch.tensor(2)
for i in range(1000):
    # Without the next line, plotting does not work
    # loss = loss.reshape((1,1))
    print("loss", loss)
    s.write(loss)
    loss += 1
    time.sleep(1)
However, the plot remains empty when logging the PyTorch scalar. It only updates in the Jupyter notebook when I reshape the scalar to a 1x1 matrix with loss.reshape((1,1)).
I found this confusing, and I think direct support for scalars would be very useful here.
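Until scalars are supported directly, one possible workaround is to unwrap 0-dimensional values into plain Python numbers before writing them. The helper below is a minimal sketch: `to_python_scalar` is a hypothetical name, not part of tensorwatch, and it relies only on the `.ndim` and `.item()` attributes that PyTorch tensors and NumPy arrays expose. I am assuming, based on the README example writing plain Python values, that the unwrapped number plots correctly. I have not verified that.

```python
def to_python_scalar(value):
    """Unwrap a 0-dim tensor-like value into a plain Python number.

    Hypothetical helper, not part of tensorwatch. Anything that is not
    0-dimensional (or has no .item()) is returned unchanged.
    """
    # torch.Tensor and numpy.ndarray both expose .ndim and .item()
    if getattr(value, "ndim", None) == 0 and hasattr(value, "item"):
        return value.item()
    return value


# Intended use in the loop above (assumption: the stream plots plain numbers):
#     s.write(to_python_scalar(loss))
```

Compared with loss.reshape((1,1)), this leaves the original tensor untouched and passes ordinary Python values through unchanged, so the same call could be used for every logged metric.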