Let's consider a subclass of `numpy`'s `ndarray` class:
```python
import numpy as np


class ArraySubClass(np.ndarray):
    def __new__(cls, input_array: np.ndarray):
        obj = np.asarray(input_array).view(cls)
        return obj
```
Then, taking a slice of an `ArraySubClass` returns an object of the same type, as explained in the documentation:
```python
>>> type(ArraySubClass(np.zeros((3, 3)))[:, 0])
<class '__main__.ArraySubClass'>
```
So far so good, but I start getting unexpected behavior when I use the static type checker `pyright`, as in the example below:
```python
def f(x: ArraySubClass):
    print(x)


f(ArraySubClass(np.zeros((3, 3)))[:, 0])
```
The last line raises an error from `pyright`:

```
Argument of type "ndarray[Any, Unknown]" cannot be assigned to parameter "x" of type "ArraySubClass" in function "f"
  "ndarray[Any, Unknown]" is incompatible with "ArraySubClass"
```
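The message suggests that, as far as the type hints are concerned, slicing produces a plain `ndarray`, even though at runtime the object is an `ArraySubClass`. Until the root cause is settled, one call-site workaround is to narrow the type explicitly. The sketch below uses an `isinstance` check, which pyright understands as type narrowing; `as_subclass` is a hypothetical helper name, not part of NumPy. `typing.cast(ArraySubClass, ...)` would also silence the error, but without any runtime check.

```python
import numpy as np


class ArraySubClass(np.ndarray):
    def __new__(cls, input_array: np.ndarray):
        obj = np.asarray(input_array).view(cls)
        return obj


def as_subclass(arr: np.ndarray) -> ArraySubClass:
    """Hypothetical helper: narrow a plain ndarray annotation to ArraySubClass.

    The isinstance check both verifies the claim at runtime and lets the
    type checker treat `arr` as ArraySubClass in the return statement.
    """
    if not isinstance(arr, ArraySubClass):
        raise TypeError(f"expected ArraySubClass, got {type(arr).__name__}")
    return arr


def f(x: ArraySubClass) -> None:
    print(x)


a = ArraySubClass(np.zeros((3, 3)))
column = as_subclass(a[:, 0])  # statically typed as ArraySubClass from here on
```

With this, `f(column)` type-checks, at the cost of wrapping every slice that is passed to a function expecting the subclass.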
What causes this behavior? Is this a `pyright` bug? Is the signature of the `__getitem__` method of `np.ndarray` given by the type hints incorrect? Or should I override this method in `ArraySubClass` with the correct signature?
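For the last option, a minimal sketch of such an override might look like the following. It assumes that the problem is the return annotation in NumPy's stubs (which annotate `__getitem__` as returning `ndarray` rather than the calling subclass), and it has not been checked against every NumPy stub version. The overloads are a simplification of my own: purely integer keys are typed as `Any` (they may give a scalar or a lower-dimensional array), and every other key (slices, `None`, `...`, boolean or integer index arrays) is typed as returning `ArraySubClass`. At runtime the method simply delegates to `ndarray.__getitem__`.

```python
from __future__ import annotations

from typing import Any, SupportsIndex, overload

import numpy as np


class ArraySubClass(np.ndarray):
    def __new__(cls, input_array: np.ndarray) -> ArraySubClass:
        obj = np.asarray(input_array).view(cls)
        return obj

    # Integer-only keys: may return a scalar or an array, so stay loose.
    @overload
    def __getitem__(self, key: SupportsIndex | tuple[SupportsIndex, ...]) -> Any: ...

    # Any other key (e.g. a[:, 0]): NumPy keeps the subclass at runtime.
    @overload
    def __getitem__(self, key: Any) -> ArraySubClass: ...

    def __getitem__(self, key: Any) -> Any:
        # No behavior change: defer to ndarray's own indexing.
        return super().__getitem__(key)


def f(x: ArraySubClass) -> None:
    print(x)


f(ArraySubClass(np.zeros((3, 3)))[:, 0])  # the key is tuple[slice, int]
```

Depending on the pyright configuration, the override itself may still be reported if pyright considers it incompatible with the base class's overloads, so this moves the type-checking burden into the subclass rather than removing it.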