#
# jaxpi.py
# Neil Gershenfeld 12/21/24
# Jax pi calculation benchmark
# pi = 3.14159265358979323846
#
# Every version evaluates the partial sum
#     sum_{i=1}^{NPTS} a / ((i - b) * (i - c))
# For a = 0.5, b = 0.75, c = 0.25 each term equals 4/(4i-3) - 4/(4i-1), so the
# series is 4 * (1 - 1/3 + 1/5 - 1/7 + ...) and the partial sums approach pi
# from below.
#
import jax
import jax.numpy as jnp
import numpy as np
import time

#
NPTS = 100000000
#
# Floating-point operations per term: (i-b), (i-c), their product, the
# division, and the accumulation performed by the sum.
FLOPS_PER_TERM = 5.0
#
a = 0.5
b = 0.75
c = 0.25
#
# alternate compilation values to prevent caching
# (the warm-up calls use these so the timed calls see different inputs; the
# compiled jit executable is still reused because JAX keys its cache on
# argument shapes/dtypes and static arguments, not on the values of a, b, c)
#
a0 = 0.6
b0 = 0.7
c0 = 0.2


def num_calcpi(a, b, c, npts=None):
    """Return the NumPy partial sum of a/((i-b)*(i-c)) for i = 1..npts.

    npts defaults to the module-level NPTS, looked up at call time. The
    index array is float64 and is materialized in full (8 bytes per term,
    plus temporaries for the intermediate expressions).
    """
    if npts is None:
        npts = NPTS
    i = np.arange(1, npts + 1, dtype=float)
    return np.sum(a / ((i - b) * (i - c)))


def jax_calcpi(a, b, c, npts=None):
    """Return the JAX partial sum of a/((i-b)*(i-c)) for i = 1..npts.

    npts defaults to the module-level NPTS. JAX uses 32-bit floats unless
    jax_enable_x64 is enabled, so dtype=float gives float32 here by default
    (NumPy's version is float64). Integers above 2**24 are not all exactly
    representable in float32, so large-npts results are less precise.
    The result is a JAX array; dispatch is asynchronous, so call
    block_until_ready() before stopping a timer.
    """
    if npts is None:
        npts = NPTS
    i = jnp.arange(1, npts + 1, dtype=float)
    return jnp.sum(a / ((i - b) * (i - c)))


# npts determines the array length, which must be known at trace time, so it
# is a static argument under jit. Wrapping is cheap; compilation happens on
# the first call.
jax_jit_calcpi = jax.jit(jax_calcpi, static_argnames="npts")


def _timed(fn, *args):
    """Call fn(*args); return (result, elapsed seconds) via perf_counter.

    JAX results are waited on with block_until_ready() so that asynchronous
    dispatch does not make the measured time too short.
    """
    start = time.perf_counter()
    result = fn(*args)
    if hasattr(result, "block_until_ready"):
        result = result.block_until_ready()
    return result, time.perf_counter() - start


def _report(npts, pi, seconds):
    """Print the computed value, the elapsed time, and the estimated MFlops."""
    if seconds > 0:
        mflops = npts * FLOPS_PER_TERM / (1.0e6 * seconds)
    else:
        # Guard against a zero-length measurement on a coarse clock.
        mflops = float("inf")
    print("NPTS = %d, pi = %f" % (npts, pi))
    print("time = %f, estimated MFlops = %f" % (seconds, mflops))


def main():
    """Run the NumPy, eager-JAX, and jitted-JAX benchmarks and print timings."""
    print("\nNumPy version:")
    pi, seconds = _timed(num_calcpi, a, b, c)
    _report(NPTS, pi, seconds)
    #
    # Eager JAX has no whole-function compile; the first call includes
    # one-time per-operation setup, so it is timed separately as a warm-up.
    print("\ncompile Jax version:")
    _, seconds = _timed(jax_calcpi, a0, b0, c0)
    print("time = %f" % seconds)
    #
    print("\nrun Jax version:")
    pi, seconds = _timed(jax_calcpi, a, b, c)
    _report(NPTS, pi, seconds)
    #
    # The first jitted call traces and compiles; the second reuses it.
    print("\ncompile Jax Jit version:")
    _, seconds = _timed(jax_jit_calcpi, a0, b0, c0)
    print("time = %f" % seconds)
    #
    print("\nrun Jax Jit version:")
    pi, seconds = _timed(jax_jit_calcpi, a, b, c)
    _report(NPTS, pi, seconds)


if __name__ == "__main__":
    main()