Plot PyTorch tensors with matplotlib
Have you ever passed PyTorch tensors straight to matplotlib, for example plt.plot(x, y) where x and y are tensors, and then received the following error?
AttributeError: 'Tensor' object has no attribute 'ndim'
You can work around this by giving every PyTorch tensor an ndim attribute:
torch.Tensor.ndim = property(lambda self: len(self.shape))
This uses Python's built-in property to define ndim as a read-only attribute whose value is computed as the length of self.shape.
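To see why a single assignment is enough, here is a minimal sketch of the same mechanism on a hypothetical ToyTensor class (a stand-in that only stores a shape, not part of PyTorch). Because the property is stored on the class rather than on each object, instances created before and after the assignment both gain ndim, and since no setter is given, the attribute is read-only.

```python
class ToyTensor:
    """Hypothetical stand-in for torch.Tensor: it only stores a shape tuple."""

    def __init__(self, shape):
        self.shape = tuple(shape)


# An instance created BEFORE the patch...
early = ToyTensor((3, 4))

# ...then the same one-line patch the text applies to torch.Tensor.
ToyTensor.ndim = property(lambda self: len(self.shape))

late = ToyTensor((2, 3, 5))

print(early.ndim)  # 2: the property lives on the class, so existing instances see it
print(late.ndim)   # 3

try:
    late.ndim = 7  # property() without a setter rejects assignment
except AttributeError as err:
    print("read-only:", err)
```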
Once this is defined, every PyTorch tensor has ndim, so tensors can be plotted as shown here:
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-5, 5, 100)
x_squared = x * x

plt.plot(x, x_squared)  # Fails: 'Tensor' object has no attribute 'ndim'

torch.Tensor.ndim = property(lambda self: len(self.shape))  # Fix it

plt.plot(x, x_squared)  # Works now
plt.show()
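Newer PyTorch releases define Tensor.ndim themselves (documented as an alias for dim()), so the unconditional patch is only needed on older versions. Below is a minimal sketch of a guarded variant, assuming you want the same script to run on both. The helper add_ndim_if_missing and the two toy classes are hypothetical, chosen so the logic can be checked without PyTorch installed.

```python
def add_ndim_if_missing(cls):
    """Attach a read-only ndim property to cls unless cls already has one.

    Returns True if the property was added, False if cls was left untouched.
    """
    if hasattr(cls, "ndim"):  # also True when ndim is a property/descriptor
        return False
    cls.ndim = property(lambda self: len(self.shape))
    return True


class LegacyTensor:
    """Hypothetical class without ndim, like torch.Tensor in older releases."""

    def __init__(self, shape):
        self.shape = tuple(shape)


class ModernTensor:
    """Hypothetical class that already defines its own ndim."""

    def __init__(self, shape):
        self.shape = tuple(shape)

    @property
    def ndim(self):
        return len(self.shape)


native_ndim = ModernTensor.__dict__["ndim"]  # keep a handle to compare later

added_legacy = add_ndim_if_missing(LegacyTensor)  # property attached
added_modern = add_ndim_if_missing(ModernTensor)  # existing ndim kept

print(added_legacy, LegacyTensor((4, 4)).ndim)  # True 2
print(added_modern, ModernTensor.__dict__["ndim"] is native_ndim)  # False True
```

With PyTorch, you would call add_ndim_if_missing(torch.Tensor) once before plotting; on versions that already provide ndim the call leaves the class unchanged.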