Newer
Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#
# jaxpi.py
# Neil Gershenfeld 12/21/24
# Jax pi calculation benchmark
# pi = 3.14159265358979323846
#
import jax
import jax.numpy as jnp
import numpy as np
import time
#
# number of series terms summed by every version below
NPTS = 100_000_000
#
# series coefficients: sum_{i>=1} a/((i-b)*(i-c)) with these values equals
# sum 8/((4i-3)(4i-1)), the pairwise-grouped Leibniz series for pi
a, b, c = 0.5, 0.75, 0.25
#
# alternate coefficients used only for the warm-up ("compile") calls, so the
# timed runs never repeat the exact inputs of the warm-up runs. Note that
# jax.jit caches compiled code by argument shape/dtype, not by value, so these
# values do not force the timed jit run to recompile.
#
a0, b0, c0 = 0.6, 0.7, 0.2
#
print("\nNumPy version:")
def num_calcpi(a, b, c, npts=None):
    """Sum the series a/((i-b)*(i-c)) for i = 1..npts using NumPy (float64).

    With a=0.5, b=0.75, c=0.25 each term equals 8/((4i-3)(4i-1)), so the
    partial sum approaches pi from below as npts grows.

    Args:
        a, b, c: series coefficients.
        npts: number of terms to sum. When None (the default), the
            module-level NPTS is used, read at call time.

    Returns:
        The partial sum as a NumPy float64 scalar.
    """
    n = NPTS if npts is None else npts
    # float64 index vector 1..n; this materializes all n values at once
    i = np.arange(1, n + 1, dtype=float)
    return np.sum(a / ((i - b) * (i - c)))
# time one NumPy evaluation; perf_counter is monotonic and high-resolution,
# unlike the wall clock returned by time.time()
start_time = time.perf_counter()
pi = num_calcpi(a,b,c)
end_time = time.perf_counter()
elapsed = end_time - start_time
# 5 flops per term: two subtractions, one multiply, one divide, one add in the sum
mflops = NPTS*5.0/(1.0e6*elapsed)
print("NPTS = %d, pi = %f"%(NPTS,pi))
print("time = %f, estimated MFlops = %f"%(elapsed,mflops))
#
print("\ncompile Jax version:")
def jax_calcpi(a, b, c, npts=None):
    """Sum the series a/((i-b)*(i-c)) for i = 1..npts using jax.numpy.

    Same series as the NumPy version: with a=0.5, b=0.75, c=0.25 the partial
    sum approaches pi from below.

    Args:
        a, b, c: series coefficients (may be traced values under jax.jit).
        npts: number of terms to sum. When None (the default), the
            module-level NPTS is used, read at call/trace time. If you wrap
            this function in jax.jit AND pass npts, mark it static, e.g.
            jax.jit(jax_calcpi, static_argnames="npts"), because jnp.arange
            needs a concrete length.

    Returns:
        The partial sum as a 0-d JAX array (dispatched asynchronously; call
        .block_until_ready() before timing).
    """
    n = NPTS if npts is None else npts
    # NOTE: with JAX's default config (jax_enable_x64 disabled) dtype=float is
    # downcast to float32, so this sums in single precision, unlike the
    # float64 NumPy version. Enable x64 for a like-for-like comparison.
    i = jnp.arange(1, n + 1, dtype=float)
    return jnp.sum(a / ((i - b) * (i - c)))
# First call of the un-jitted function. Eager JAX runs each jnp operation as
# a separate primitive and compiles/caches a kernel per primitive on first use,
# so this timing is warm-up overhead plus one execution, not a whole-function
# compile. block_until_ready() waits for JAX's asynchronous dispatch so the
# measurement covers the full computation. perf_counter is a monotonic,
# high-resolution clock suited to interval timing.
start_time = time.perf_counter()
pi = jax_calcpi(a0,b0,c0).block_until_ready()
end_time = time.perf_counter()
print("time = %f"%(end_time-start_time))
#
print("\nrun Jax version:")
# timed run of the un-jitted version with the real coefficients, after warm-up;
# block_until_ready() makes the timing include the asynchronous computation
start_time = time.perf_counter()
pi = jax_calcpi(a,b,c).block_until_ready()
end_time = time.perf_counter()
elapsed = end_time - start_time
# 5 flops per term: two subtractions, one multiply, one divide, one add in the sum
mflops = NPTS*5.0/(1.0e6*elapsed)
print("NPTS = %d, pi = %f"%(NPTS,pi))
print("time = %f, estimated MFlops = %f"%(elapsed,mflops))
#
print("\ncompile Jax Jit version:")
# jax.jit traces jax_calcpi on its first call and compiles it with XLA. The
# compiled executable is cached by argument shape/dtype (not value), so the
# later call with (a,b,c) reuses it. The time below therefore includes tracing,
# compilation, AND one execution of the compiled function.
jax_jit_calcpi = jax.jit(jax_calcpi)
start_time = time.perf_counter()
pi = jax_jit_calcpi(a0,b0,c0).block_until_ready()
end_time = time.perf_counter()
print("time = %f"%(end_time-start_time))
#
print("\nrun Jax Jit version:")
# timed run of the already-compiled jitted function with the real coefficients;
# block_until_ready() makes the timing include the asynchronous computation
start_time = time.perf_counter()
pi = jax_jit_calcpi(a,b,c).block_until_ready()
end_time = time.perf_counter()
elapsed = end_time - start_time
# 5 flops per term: two subtractions, one multiply, one divide, one add in the sum
mflops = NPTS*5.0/(1.0e6*elapsed)
print("NPTS = %d, pi = %f"%(NPTS,pi))
print("time = %f, estimated MFlops = %f"%(elapsed,mflops))