import mpi, string, sys

def f(x):
    """Return the integrand 4 / (1 + x**2).

    Integrated over [0, 1] this gives pi, which is the value the
    midpoint-rule sum in this script approximates.
    """
    denominator = 1.0 + x ** 2
    return 4.0 / denominator

# first processor reads in the number of
# intervals to use in the midpoint rule
# and then sends it to all the other processors
# who receive it.

# Interval count used when rank 0 gets no usable command-line argument.
DEFAULT_INTERVALS = 1000

if mpi.rank == 0:
    # sys.argv is the list of command-line arguments (like argv in C;
    # there is no sys.argc -- use len(sys.argv)). sys.argv[0] is the script
    # name, so the number of intervals is sys.argv[1].
    #
    # Only rank 0 parses the argument. If it raised here, the other ranks
    # would presumably stay blocked in mpi.recv(0) waiting for n (depending
    # on how the MPI launcher handles a failed rank). A count below 1 would
    # also break the computation (1.0 / 0, or an empty range giving 0.0).
    # So warn on stderr and fall back to a default instead of failing.
    try:
        n = int(sys.argv[1])
    except (IndexError, ValueError):
        n = 0
    if n < 1:
        sys.stderr.write("usage: %s N  (N = positive number of intervals); "
                         "using N=%d\n" % (sys.argv[0], DEFAULT_INTERVALS))
        n = DEFAULT_INTERVALS

# Record the starting time from the MPI "wall clock"; rank 0 reports the
# elapsed time at the end of the run.
t1 = mpi.wtime()


# Distribute the interval count: rank 0 sends n to every other rank, and
# each of those ranks receives it (plus a status object) from rank 0.
if mpi.rank != 0:
    n, status = mpi.recv(0)
else:
    for dest in range(1, mpi.size):
        mpi.send(n, dest)

# Every rank computes the interval width h and its own partial sum.
h = 1.0 / n

# Intervals are dealt out round-robin: this rank takes intervals
# rank+1, rank+1+size, rank+1+2*size, ... up to n, evaluates f at each
# interval's midpoint h*(k - 0.5), and sums them in that order.
my_intervals = range(mpi.rank + 1, n + 1, mpi.size)
midpoint_values = (f(h * (k - 0.5)) for k in my_intervals)

# Scaling by the interval width turns the sum of heights into an area.
local_sum = h * sum(midpoint_values, 0.0)

# Combine the partial sums on rank 0. Every other rank sends its
# local_sum to rank 0, which receives them in rank order (1, 2, ...)
# and adds each one to its own.
if mpi.rank != 0:
    mpi.send(local_sum, 0)
else:
    global_sum = local_sum
    for src in range(1, mpi.size):
        partial, status = mpi.recv(src)
        global_sum += partial


# call the "wall clock" again and subtract t1 to get the time elapsed
# since the last call
if mpi.rank == 0:
    print "time: " , mpi.wtime() - t1

# first processor prints out the result
if mpi.rank == 0:
    print 'PI is about ' , global_sum


