## Plot PyTorch tensors with matplotlib

Have you ever tried to plot a PyTorch tensor with matplotlib like:

```
plt.plot(tensor)
```

and then received the following error?

```
AttributeError: 'Tensor' object has no attribute 'ndim'
```

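You can work around this by giving every PyTorch tensor an `ndim` attribute with a single line. Recent PyTorch releases already define `Tensor.ndim`, so the error and this patch only apply to older versions.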

```
torch.Tensor.ndim = property(lambda self: len(self.shape))
```

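This calls Python's built-in `property` function, which is more often seen as the `@property` decorator. It creates `ndim` as a read-only property whose value is the length of `self.shape`, which is the tensor's number of dimensions.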
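To see the mechanism without PyTorch installed, here is a minimal pure-Python sketch. `FakeTensor` and `DecoratedTensor` are hypothetical stand-ins that only store a `shape` tuple. The sketch shows that assigning `property(...)` to a class after it is defined behaves like writing `@property` inside the class body, and that instances created before the patch pick it up too.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: it only stores a shape tuple."""

    def __init__(self, *shape):
        self.shape = tuple(shape)


t = FakeTensor(3, 4)  # created *before* the patch

# Same one-liner as the PyTorch fix: attach a read-only property to the class
FakeTensor.ndim = property(lambda self: len(self.shape))

print(t.ndim)                # 2: existing instances see the new attribute
print(FakeTensor(100).ndim)  # 1


# Equivalent decorator form, written inside the class body instead
class DecoratedTensor:
    """Hypothetical stand-in with ndim defined via the @property decorator."""

    def __init__(self, *shape):
        self.shape = tuple(shape)

    @property
    def ndim(self):
        return len(self.shape)


print(DecoratedTensor(2, 3, 5).ndim)  # 3
```

The patch works because attribute lookup on an instance falls back to its class. A property stored on the class is therefore visible to every instance, old or new.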

Once this is defined, every PyTorch tensor has `ndim`, so it can be plotted as shown here:

```
import torch
import matplotlib.pyplot as plt

x = torch.linspace(-5, 5, 100)
x_squared = x * x

# On PyTorch versions without Tensor.ndim, this call raises
# AttributeError: 'Tensor' object has no attribute 'ndim'
# plt.plot(x, x_squared)

# Fix it: report the number of dimensions as the length of the shape
torch.Tensor.ndim = property(lambda self: len(self.shape))

plt.plot(x, x_squared)  # Works now
plt.show()
```
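If the same script might run on a PyTorch version whose tensors already have `ndim`, you may prefer to patch only when the attribute is missing, so a native implementation is never replaced. The following is a minimal sketch of that guard. `ensure_ndim`, `OldTensor` and `NewTensor` are hypothetical names used only for illustration, and the classes stand in for tensor types without and with their own `ndim`.

```python
def ensure_ndim(cls):
    """Hypothetical helper: add an ndim property (len(self.shape)) to cls,
    but only if cls does not already provide one."""
    if not hasattr(cls, "ndim"):
        cls.ndim = property(lambda self: len(self.shape))
    return cls


class OldTensor:
    """Hypothetical stand-in for a tensor type that lacks ndim."""

    def __init__(self, *shape):
        self.shape = tuple(shape)


class NewTensor:
    """Hypothetical stand-in for a tensor type that defines its own ndim."""

    def __init__(self, *shape):
        self.shape = tuple(shape)

    @property
    def ndim(self):
        return len(self.shape)


builtin_ndim = NewTensor.__dict__["ndim"]  # remember the original property

ensure_ndim(OldTensor)
ensure_ndim(NewTensor)

print(OldTensor(5, 5).ndim)                        # 2: patched in
print(NewTensor.__dict__["ndim"] is builtin_ndim)  # True: left untouched
```

With PyTorch you would call `ensure_ndim(torch.Tensor)` before plotting. The guard relies only on `hasattr`, which returns `True` for an `ndim` the class already defines.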