Recently, PyTorch introduced the nested tensor. However, if I create a nested tensor, e.g.,
```python
import torch

a = torch.randn(20, 128)
nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)
```

and then look at its class type, it shows:
```python
type(nt)
torch.Tensor
```

i.e., the class type is just a regular PyTorch `Tensor`. So `type(nt) == torch.Tensor` and `isinstance(nt, torch.Tensor)` will both return `True`.
So, my question is, is there a way to differentiate between a regular tensor and a nested tensor?
One way I can think of relies on the fact that the `size` method of nested tensors (currently) behaves differently from that of regular tensors: it requires a dimension argument, otherwise it raises a `RuntimeError`. So, a solution might be:
```python
def is_nested_tensor(nt):
    if not isinstance(nt, torch.Tensor):
        return False
    try:
        # try calling size without an argument
        nt.size()
        return False
    except RuntimeError:
        return True
```

But is there something simpler that doesn't rely on the behaviour of the `size` method staying the same in future releases?
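For reference, here is a self-contained sketch that exercises the `size()`-based check above on a regular tensor, a nested tensor and a non-tensor. It repeats the helper so it runs on its own, and it assumes the current behaviour in which an argument-less `size()` call on a (default, strided-layout) nested tensor raises `RuntimeError`; that behaviour may differ between PyTorch versions.

```python
import torch


def is_nested_tensor(obj):
    """Heuristic check for a nested tensor.

    Assumes (as of current PyTorch, possibly subject to change) that calling
    .size() with no argument on a nested tensor raises RuntimeError, while it
    succeeds on a regular tensor.
    """
    if not isinstance(obj, torch.Tensor):
        return False  # not a tensor at all
    try:
        obj.size()  # works for regular tensors
    except RuntimeError:
        return True  # nested tensors currently refuse an argument-less size()
    return False


a = torch.randn(20, 128)
nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)

# Both report the same Python class, so type() alone cannot separate them ...
print(type(a), type(nt))

# ... whereas the size()-based heuristic can, under the assumption above.
print(is_nested_tensor(a))       # regular tensor
print(is_nested_tensor(nt))      # nested tensor
print(is_nested_tensor([a, a]))  # plain Python list, not a tensor
```

The trailing `return False` sits after the `try`/`except` here instead of inside the `try`, which keeps the `try` body limited to the single call whose failure is being tested.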