Channel: Active questions tagged python - Stack Overflow

How can I differentiate between a PyTorch Tensor and a nested tensor?


Recently, PyTorch introduced the nested tensor. However, if I create a nested tensor, e.g.,

    import torch
    a = torch.randn(20, 128)
    nt = torch.nested.nested_tensor([a, a], dtype=torch.float32)

and then check its class type, I get:

    >>> type(nt)
    <class 'torch.Tensor'>

i.e., the class type is just a regular PyTorch Tensor. So, type(nt) == torch.Tensor and isinstance(nt, torch.Tensor) will both return True.
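A minimal sketch of the observation above. It builds one dense tensor and one nested tensor and runs both type-based checks on each. The nested tensor here uses two components with different first dimensions, which is the ragged case nested tensors are meant for (an assumption about intent; the question itself uses equal shapes). This assumes the default (strided) layout that `torch.nested.nested_tensor` produces:

```python
import torch

# A regular dense tensor and a nested tensor whose two components
# have different first dimensions (ragged), both with 128 columns
dense = torch.randn(20, 128)
nt = torch.nested.nested_tensor([torch.randn(20, 128), torch.randn(7, 128)])

# Neither an exact type comparison nor isinstance() tells the two apart,
# so this loop prints the same pair of booleans for both objects
for t in (dense, nt):
    print(type(t) is torch.Tensor, isinstance(t, torch.Tensor))
```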

So, my question is, is there a way to differentiate between a regular tensor and a nested tensor?

One way I can think of relies on the fact that the size method of nested tensors (currently) behaves differently from that of regular tensors: it requires a dimension argument and otherwise raises a RuntimeError. So, a solution might be:

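    import torch

    def is_nested_tensor(nt):
        # Anything that is not a tensor cannot be a nested tensor
        if not isinstance(nt, torch.Tensor):
            return False
        try:
            # Regular tensors accept size() with no argument...
            nt.size()
            return False
        except RuntimeError:
            # ...whereas nested tensors (currently) raise a RuntimeError
            return True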

But is there something simpler that doesn't rely on the behaviour of the size method staying the same in future releases?
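For comparison, here is a sketch that avoids relying on how size() fails. It assumes a PyTorch version where `torch.Tensor` exposes the boolean `is_nested` attribute, which is True for nested tensors and False for regular ones. The helper name `is_nested_tensor` is just the one used above, not a PyTorch function:

```python
import torch

def is_nested_tensor(x: object) -> bool:
    # Reject non-tensors first, then read the tensor's is_nested flag
    # instead of probing size() for a RuntimeError
    return isinstance(x, torch.Tensor) and x.is_nested

dense = torch.randn(20, 128)
nt = torch.nested.nested_tensor([dense, dense], dtype=torch.float32)

print(is_nested_tensor(dense), is_nested_tensor(nt), is_nested_tensor([dense]))
```

Because the check reads an attribute rather than catching an exception, a regular tensor is never mistaken for a nested one just because some unrelated operation raised a RuntimeError.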

