Channel: Active questions tagged python - Stack Overflow

Run parallel function on Apple GPU/metal in Python


I have a rather simple function that runs quite fast in parallel using Numba, and I would like to know if I can run it on my Apple M3 Max GPU. However, I have never worked on GPU code before (coming from Macs), so I am a little lost.

I have included a small use case:

import numpy as np
import numba as nb

N = 15      # size of a and b
K = 127     # size of the result
L = 943     # size of the operator

a = np.random.standard_normal(size=N)
b = np.random.standard_normal(size=N)

operator = np.zeros(shape=(L, 4), dtype=np.int64)
operator[:, 0] = np.random.randint(size=L, low=0, high=N)
operator[:, 1] = np.random.randint(size=L, low=0, high=N)
operator[:, 2] = np.random.randint(size=L, low=0, high=K)
operator[:, 3] = np.random.randint(size=L, low=1, high=10)

@nb.njit(parallel=True)
def shuffle_mul(a: np.ndarray, b: np.ndarray, operator: np.ndarray) -> np.ndarray:
    res = np.zeros(shape=K, dtype=a.dtype)
    for n in nb.prange(len(operator)):
        i, j, k, count = operator[n]
        res[k] += count * a[i] * b[j]
    return res

shuffle_mul(a, b, operator=operator)  # warm-up
%timeit shuffle_mul(a, b, operator=operator)
# 173 µs ± 41.3 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Question 1: Is it even theoretically possible/interesting to run such a function on a GPU? Is it a problem that multiple instances might read a and b, or add to res, at the same time?
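For context on my own question: the loop body is essentially a "scatter-add" (each row of operator contributes one product into res[k]), and concurrent reads of a and b are harmless; only the concurrent writes to res need care. A race-free, vectorized sketch of the same computation in plain NumPy, using np.add.at (an unbuffered scatter-add, which I believe is the same primitive GPU frameworks expose via atomics), would be:

```python
import numpy as np

def shuffle_mul_scatter(a: np.ndarray, b: np.ndarray,
                        operator: np.ndarray, K: int) -> np.ndarray:
    """Vectorized scatter-add version of: res[k] += count * a[i] * b[j]."""
    i, j, k, count = operator[:, 0], operator[:, 1], operator[:, 2], operator[:, 3]
    res = np.zeros(K, dtype=a.dtype)
    # np.add.at accumulates correctly even when k contains repeated indices,
    # whereas res[k] += ... would silently drop duplicate contributions.
    np.add.at(res, k, count * a[i] * b[j])
    return res
```

(This is only to illustrate that the reduction can be expressed without a race; it is not faster than the Numba version on CPU.)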

Question 2: What would be the best approach, and which library should I use: metalcompute, JAX, PyTorch, ... ?
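To make the question concrete, here is the kind of thing I imagine for the PyTorch route (assuming PyTorch is installed; this is just my sketch, not tested on MPS): torch.Tensor.index_add_ is a scatter-add that should run on the Apple GPU via the "mps" device, falling back to CPU otherwise. Note that MPS does not support float64, so I cast to float32.

```python
import numpy as np
import torch

def shuffle_mul_torch(a: np.ndarray, b: np.ndarray,
                      operator: np.ndarray, K: int) -> np.ndarray:
    """Scatter-add via index_add_; uses the Apple GPU when MPS is available."""
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    # MPS has no float64 support, so work in float32 on either device.
    a_t = torch.as_tensor(a, dtype=torch.float32, device=device)
    b_t = torch.as_tensor(b, dtype=torch.float32, device=device)
    op = torch.as_tensor(operator, device=device)
    i, j, k, count = op[:, 0], op[:, 1], op[:, 2], op[:, 3]
    res = torch.zeros(K, dtype=torch.float32, device=device)
    # index_add_ accumulates duplicates in k correctly (atomic adds on GPU).
    res.index_add_(0, k, count.to(torch.float32) * a_t[i] * b_t[j])
    return res.cpu().numpy()
```

Whether the transfer and kernel-launch overhead is worth it at these sizes (L = 943) is part of what I am asking.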

Thanks a lot!

