import mpi, string, sys

# Integrand for the pi estimate: the integral of 4 / (1 + x^2) over [0, 1]
# is exactly pi.
def f(x):
    """Return 4 / (1 + x**2) evaluated at ``x``."""
    denominator = 1.0 + x ** 2
    return 4.0 / denominator

# Rank 0 reads the number of intervals for the midpoint rule from the
# command line and broadcasts it; every other rank receives it from that
# broadcast.  A missing or non-integer argument is broadcast as 0 instead of
# raising on rank 0: bcast is a collective call, so if rank 0 died before
# reaching it, the other ranks would block in mpi.bcast() indefinitely.

t1 = mpi.wtime()

if mpi.rank == 0:
    try:
        # int() is the portable replacement for the deprecated string.atoi,
        # which no longer exists in Python 3.
        n = int(sys.argv[1])
    except (IndexError, ValueError):
        n = 0
    mpi.bcast(n)
else:
    n = mpi.bcast()

# Every rank holds the same n here, so all of them stop together on bad
# input.  This also prevents the ZeroDivisionError that n == 0 would cause
# when the step width h = 1.0 / n is computed.
# NOTE(review): confirm sys.exit() shuts pyMPI down cleanly on every rank.
if n <= 0:
    if mpi.rank == 0:
        sys.stderr.write("usage: %s N   (N = positive number of intervals)\n"
                         % sys.argv[0])
    sys.exit(1)

def _partial_midpoint_sum(func, width, first, last, stride):
    """Return ``width`` times the sum of ``func`` at interval midpoints.

    Interval ``k`` (1-based) has midpoint ``width * (k - 0.5)``; the intervals
    visited are ``first, first + stride, ...`` up to and including ``last``.
    """
    total = 0.0
    index = first
    while index <= last:
        total += func(width * (index - 0.5))
        index += stride
    return total * width


# Each rank computes its own step width and partial sum.  Intervals are dealt
# out round-robin: rank r handles intervals r+1, r+1+size, r+1+2*size, ...
h = 1.0 / n
local_sum = _partial_midpoint_sum(f, h, mpi.rank + 1, n, mpi.size)

# Combine the per-rank partial sums with a sum-reduction onto rank 0.
# On rank 0 global_sum holds the pi estimate; on the other ranks it is None.
global_sum = mpi.reduce(local_sum, mpi.SUM, 0)

# Only rank 0 reports.  Each print(...) takes a single pre-formatted string,
# so the output is identical under Python 2 (print statement applied to a
# parenthesised expression) and Python 3 (print function).
if mpi.rank == 0:
    elapsed = mpi.wtime() - t1
    print("time:  %s" % elapsed)
    print("PI is about  %s" % global_sum)


