Have you ever passed a PyTorch tensor to a matplotlib function such as plt.plot and received the following error?

AttributeError: 'Tensor' object has no attribute 'ndim'

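This error appears when your PyTorch version's tensors have no ndim attribute (recent releases provide Tensor.ndim natively). You can work around it by patching an ndim property onto the torch.Tensor class, which gives every tensor that attribute: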

torch.Tensor.ndim = property(lambda self: len(self.shape))

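This is monkey-patching. The built-in property() turns the lambda into a read-only attribute on the torch.Tensor class, so every tensor, including ones created earlier, now reports ndim as len(self.shape), its number of dimensions.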
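To see the mechanism without PyTorch, here is a minimal sketch of the same idiom on a toy class. FakeTensor and DecoratedTensor are hypothetical stand-ins for torch.Tensor that only carry a shape; the sketch shows that the patched property is visible to instances created before the patch, that it matches the usual @property decorator form, and that it is read-only.

```python
# FakeTensor is a hypothetical stand-in for torch.Tensor:
# it has a shape attribute but no ndim.
class FakeTensor:
    def __init__(self, shape):
        self.shape = tuple(shape)


existing = FakeTensor((2, 3, 4))  # created before the patch

# Same idiom as the torch fix: attach a read-only property to the class.
FakeTensor.ndim = property(lambda self: len(self.shape))

t = FakeTensor((3, 4))
print(t.ndim)         # 2
print(existing.ndim)  # 3: the property lives on the class, so older instances see it too


# Equivalent form using @property as a decorator inside a class body.
class DecoratedTensor:
    def __init__(self, shape):
        self.shape = tuple(shape)

    @property
    def ndim(self):
        return len(self.shape)


print(DecoratedTensor((5,)).ndim)  # 1

# No setter was supplied, so assigning to ndim raises AttributeError.
read_only = False
try:
    t.ndim = 7
except AttributeError:
    read_only = True
print(read_only)  # True
```

Because the property is stored on the class rather than on each object, one assignment is enough to cover every tensor in the program.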

Once this is defined, every PyTorch tensor has an ndim attribute, so matplotlib can plot it:

import torch
import matplotlib.pyplot as plt

x = torch.linspace(-5, 5, 100)
x_squared = x * x

try:
    plt.plot(x, x_squared)  # Fails without Tensor.ndim: 'Tensor' object has no attribute 'ndim'
except AttributeError as err:
    print(err)

torch.Tensor.ndim = property(lambda self: len(self.shape))  # Fix it

plt.plot(x, x_squared)  # Works now
plt.show()

(Figure: resulting plot output)
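Since recent PyTorch releases already define Tensor.ndim, you may prefer to patch only when the attribute is missing, so the native implementation is never shadowed. Below is a minimal sketch assuming a hypothetical helper named add_ndim_if_missing; OldTensor and NewTensor are likewise hypothetical toy classes standing in for tensor types without and with ndim.

```python
def add_ndim_if_missing(cls):
    """Hypothetical helper: give cls an ndim property equal to len(self.shape),
    but only if cls does not already define ndim."""
    if not hasattr(cls, "ndim"):
        cls.ndim = property(lambda self: len(self.shape))
    return cls


# Toy classes standing in for an old and a new tensor type (both hypothetical).
class OldTensor:
    def __init__(self, shape):
        self.shape = tuple(shape)


class NewTensor:
    def __init__(self, shape):
        self.shape = tuple(shape)

    @property
    def ndim(self):
        return len(self.shape)


builtin_ndim = NewTensor.__dict__["ndim"]

add_ndim_if_missing(OldTensor)  # patched: OldTensor had no ndim
add_ndim_if_missing(NewTensor)  # left alone: NewTensor already has ndim

print(OldTensor((4, 5, 6)).ndim)                   # 3
print(NewTensor.__dict__["ndim"] is builtin_ndim)  # True

# With PyTorch installed, the same call would be:
#   import torch
#   add_ndim_if_missing(torch.Tensor)
```

On a PyTorch version that already has Tensor.ndim, hasattr(torch.Tensor, "ndim") is true, so the call does nothing.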