Skip to content
Snippets Groups Projects
Commit 886e4f28 authored by Neil Gershenfeld's avatar Neil Gershenfeld
Browse files

wip

parent ba905fc6
Branches
No related tags found
No related merge requests found
Pipeline #47073 passed
#
# jaxpi.py
# Neil Gershenfeld 12/21/24
# Jax pi calculation benchmark
# pi = 3.14159265358979323846
#
import jax
import jax.numpy as jnp
import numpy as np
import time
#
NPTS = 100000000
#
a = 0.5
b = 0.75
c = 0.25
#
# alternate compilation values to prevent caching
#
a0 = 0.6
b0 = 0.7
c0 = 0.2
#
print("\nNumPy version:")
def num_calcpi(a,b,c):
i = np.arange(1,(NPTS+1),dtype=float)
pi = np.sum(a/((i-b)*(i-c)))
return pi
start_time = time.time()
pi = num_calcpi(a,b,c)
end_time = time.time()
mflops = NPTS*5.0/(1.0e6*(end_time-start_time))
print("NPTS = %d, pi = %f"%(NPTS,pi))
print("time = %f, estimated MFlops = %f"%(end_time-start_time,mflops))
#
print("\ncompile Jax version:")
def jax_calcpi(a,b,c):
i = jnp.arange(1,(NPTS+1),dtype=float)
pi = jnp.sum(a/((i-b)*(i-c)))
return pi
start_time = time.time()
pi = jax_calcpi(a0,b0,c0).block_until_ready()
end_time = time.time()
print("time = %f"%(end_time-start_time))
#
print("\nrun Jax version:")
start_time = time.time()
pi = jax_calcpi(a,b,c).block_until_ready()
end_time = time.time()
mflops = NPTS*5.0/(1.0e6*(end_time-start_time))
print("NPTS = %d, pi = %f"%(NPTS,pi))
print("time = %f, estimated MFlops = %f"%(end_time-start_time,mflops))
#
print("\ncompile Jax Jit version:")
jax_jit_calcpi = jax.jit(jax_calcpi)
start_time = time.time()
pi = jax_jit_calcpi(a0,b0,c0).block_until_ready()
end_time = time.time()
print("time = %f"%(end_time-start_time))
#
print("\nrun Jax Jit version:")
start_time = time.time()
pi = jax_jit_calcpi(a,b,c).block_until_ready()
end_time = time.time()
mflops = NPTS*5.0/(1.0e6*(end_time-start_time))
print("NPTS = %d, pi = %f"%(NPTS,pi))
print("time = %f, estimated MFlops = %f"%(end_time-start_time,mflops))
......@@ -15,6 +15,7 @@
|1,090|[numbapig.py](Python/numbapig.py)|Python, Numba, CUDA, 5120 cores|NVIDIA V100|March, 2020|
|1,062|[taichipi.py](Python/taichipi.py)|Python, Taichi, 5120 cores|NVIDIA V100|March, 2023|
|811|prior|Cray XT4|C, MPI, 2048 processes|prior|
|604|[jaxpi.py](Python/jaxpi.py)|Python, Jax, 5120 cores|NVIDIA V100|December, 2024|
|501|[rayonpi.rs](Rust/rayonpi.rs)|Rust, Rayon, 96 cores<br>cargo run --release|Graviton4|December, 2024|
|484|[threadpi.rs](Rust/threadpi.rs)|Rust, threads, 96 cores<br>cargo run --release -- 96|Graviton4|December, 2024|
|315|[numbapip.py](Python/numbapip.py)|Python, Numba, parallel, fastmath<br>96 cores|Intel 2x Xeon Platinum 8175M|February, 2020|
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment