Skip to content
Snippets Groups Projects
jaxpi.py 1.68 KiB
Newer Older
  • Learn to ignore specific revisions
  • Neil Gershenfeld's avatar
    wip
    Neil Gershenfeld committed
    #
    # jaxpi.py
    # Neil Gershenfeld 12/21/24
    # Jax pi calculation benchmark
    # pi = 3.14159265358979323846
    #
    import jax
    import jax.numpy as jnp
    import numpy as np
    import time
    #
    # number of series terms summed by every benchmark variant (i = 1 .. NPTS)
    NPTS = 100000000
    #
    # coefficients of the summand a/((i-b)*(i-c)) used for the timed runs;
    # with these values the series is the Leibniz series and sums to pi
    #
    a = 0.5
    b = 0.75
    c = 0.25
    #
    # alternate coefficients passed to the first ("compile") call of each JAX
    # variant; per the original author's note they differ from a, b, c
    # "to prevent caching"
    # NOTE(review): jax.jit presumably keys its compilation cache on argument
    # shape/dtype rather than on value, so the compiled code from the first
    # call is still reused by the timed call -- confirm that is the intent
    #
    a0 = 0.6
    b0 = 0.7
    c0 = 0.2
    #
    print("\nNumPy version:")
    def num_calcpi(a,b,c,npts=None):
       """Sum the series a/((i-b)*(i-c)) for i = 1..npts using NumPy.

       With a=0.5, b=0.75, c=0.25 each term equals
       1/(i-0.75) - 1/(i-0.25) = 4*(1/(4i-3) - 1/(4i-1)), i.e. the Leibniz
       series, so the partial sum approaches pi as npts grows (the
       truncation error is roughly 0.5/npts).

       Args:
          a: numerator of every term.
          b: first offset subtracted from the index i.
          c: second offset subtracted from the index i.
          npts: number of terms to sum; when None (the default) the
             module-level NPTS is used, which is the original behavior.

       Returns:
          The partial sum as a NumPy float64 scalar (0.0 when npts is 0).
       """
       if npts is None:
          # only touch the global when the caller did not pick a size
          npts = NPTS
       # float64 index vector 1.0, 2.0, ..., npts
       i = np.arange(1,(npts+1),dtype=float)
       return np.sum(a/((i-b)*(i-c)))
    # Time one NumPy evaluation.  perf_counter() is a monotonic,
    # high-resolution clock meant for measuring intervals, whereas
    # time.time() follows the wall clock and can jump if it is adjusted.
    start_time = time.perf_counter()
    pi = num_calcpi(a,b,c)
    end_time = time.perf_counter()
    elapsed = end_time-start_time
    # 5 floating-point operations are counted per term (presumably two
    # subtractions, one multiply, one divide and one add into the sum)
    mflops = NPTS*5.0/(1.0e6*elapsed)
    print("NPTS = %d, pi = %f"%(NPTS,pi))
    print("time = %f, estimated MFlops = %f"%(elapsed,mflops))
    #
    print("\ncompile Jax version:")
    def jax_calcpi(a,b,c,npts=None):
       """Sum the series a/((i-b)*(i-c)) for i = 1..npts using jax.numpy.

       With a=0.5, b=0.75, c=0.25 each term equals
       4*(1/(4i-3) - 1/(4i-1)), i.e. the Leibniz series, so the partial sum
       approaches pi as npts grows.

       Args:
          a: numerator of every term.
          b: first offset subtracted from the index i.
          c: second offset subtracted from the index i.
          npts: number of terms to sum; when None (the default) the
             module-level NPTS is used, which is the original behavior.
             It sets the array length, so when this function is wrapped in
             jax.jit it must be left at its default or declared static
             (e.g. static_argnames="npts").

       Returns:
          The partial sum as a zero-dimensional JAX array.
       """
       if npts is None:
          # only touch the global when the caller did not pick a size
          npts = NPTS
       # NOTE(review): unless jax_enable_x64 is turned on, JAX canonicalizes
       # dtype=float to float32, so this run is presumably single precision
       # while the NumPy version is float64 -- confirm the comparison is
       # meant to be like-for-like
       i = jnp.arange(1,(npts+1),dtype=float)
       return jnp.sum(a/((i-b)*(i-c)))
    # First (un-jitted) JAX call, made with the alternate coefficients so
    # that one-time start-up cost is reported separately from the timed run.
    # block_until_ready() waits for JAX's asynchronous dispatch to finish,
    # so the interval covers the actual computation.  perf_counter() is the
    # monotonic interval clock; time.time() can jump with the wall clock.
    start_time = time.perf_counter()
    pi = jax_calcpi(a0,b0,c0).block_until_ready()
    end_time = time.perf_counter()
    print("time = %f"%(end_time-start_time))
    #
    print("\nrun Jax version:")
    # Timed un-jitted JAX evaluation with the benchmark coefficients.
    # block_until_ready() waits for asynchronous dispatch to complete;
    # perf_counter() is used because it is a monotonic, high-resolution
    # interval clock, unlike the wall-clock time.time().
    start_time = time.perf_counter()
    pi = jax_calcpi(a,b,c).block_until_ready()
    end_time = time.perf_counter()
    elapsed = end_time-start_time
    # 5 floating-point operations are counted per term
    mflops = NPTS*5.0/(1.0e6*elapsed)
    print("NPTS = %d, pi = %f"%(NPTS,pi))
    print("time = %f, estimated MFlops = %f"%(elapsed,mflops))
    #
    print("\ncompile Jax Jit version:")
    # Wrap the JAX series in jax.jit; tracing and compilation happen on the
    # first call, which is therefore timed on its own here using the
    # alternate coefficients.
    jax_jit_calcpi = jax.jit(jax_calcpi)
    # perf_counter() is the monotonic interval clock (time.time() can jump
    # with the wall clock); block_until_ready() waits for the result.
    start_time = time.perf_counter()
    pi = jax_jit_calcpi(a0,b0,c0).block_until_ready()
    end_time = time.perf_counter()
    print("time = %f"%(end_time-start_time))
    #
    print("\nrun Jax Jit version:")
    # Timed call of the jitted function with the benchmark coefficients.
    # block_until_ready() waits for asynchronous dispatch to complete;
    # perf_counter() is used because it is a monotonic, high-resolution
    # interval clock, unlike the wall-clock time.time().
    start_time = time.perf_counter()
    pi = jax_jit_calcpi(a,b,c).block_until_ready()
    end_time = time.perf_counter()
    elapsed = end_time-start_time
    # 5 floating-point operations are counted per term
    mflops = NPTS*5.0/(1.0e6*elapsed)
    print("NPTS = %d, pi = %f"%(NPTS,pi))
    print("time = %f, estimated MFlops = %f"%(elapsed,mflops))