This week I wanted to start with a "normal" approach that felt pythonic with a particle class and a bunch of for loops / if statements, but still using numpy arrays. Then rewrite it with everything possible vectorized in numpy and see the difference in solve times for each.
Below is a 5000 particle simulation with 1000 time steps that took 0.18 s to solve for (and incidentally much longer to write to a .mp4 or .gif file).
This week I wanted to start with a "normal" approach that felt pythonic with a particle class and a bunch of for loops / if statements. Then rewrite it with everything possible vectorized in numpy and see the difference in solve times for each. Below is a 5000 particle simulation with 1000 time steps that took 0.18 s to solve for (and incidentally much longer to write to a .mp4 or .gif file). There's one nice vectorization step where where you can advance each particle by adding the product of velocity and timestep, but then a lot of np.where statements to check for collisions.
-[Here you can find the fast all numpy code used to make the simulation above.](https://gitlab.cba.mit.edu/davepreiss/nmm/-/blob/a6f0474d8cdb6bec89cc871ceecbd14fa6b71472/week1/maxwells_demon_fast.py)
[](maxwells_demon.py)
-[And here you can find the first slower one.](https://gitlab.cba.mit.edu/davepreiss/nmm/-/blob/558752ea110a91f466dedd88306fcb5f8abbc0e9/week1/maxwells_demon.py)
Interestingly Numba resulted in a ~4.5 s solve time compared to a 0.17 s time without it. From a very helpful conversation with Erik - for low particle counts
Erik
gpu has overhead to move from cpu to gpu
gpu is good at multithreading
optimized for throughput over latency
might be slower on gpu for small numbers of particles
### Numba / Jax
jax and numba will both
Dissapointingly Jax doesn't have native support for windows, but Numba was easy to get up and running with just a single @jit(nopython=True) decorator. On first run it resulted in a 4.5 s solve time, which is more than an order of magnitude increase. After talking with Erik, it seems like JIT compilation adds a lot of time overhead to pre-compile, so you need to make sure you are at a simulation complexity (particle count or time scale) to catch up to the raw numpy. After this I didn't bother trying trying to run anything on the GPU, but will likely wind up in a place where this is necessary in later weeks.