#!/usr/bin/env python
from __future__ import division
from numpy import *
import argparse
from pyframe3dd.frame3dd import write_frame3dd_file, read_lowest_mode, read_frame3dd_displacements, compute_mass
from pyframe3dd.util import magnitudes, close
import subprocess
import matplotlib.pyplot as plt
# Use matplotlib's 'bmh' style sheet for every figure this script draws (set once, at import time).
plt.style.use('bmh')

def plot_connections(nodes, beamsets):
    """Draw every beam of every beam set as a 3D line plot (debug helper).

    This is slow: each beam is drawn with its own ``ax.plot`` call.

    Parameters
    ----------
    nodes : array
        Node coordinates, indexed as ``nodes[seg, 0..2]`` (presumably an
        (n_nodes, 3) array -- TODO confirm).
    beamsets : sequence
        Each entry is an iterable of beams; each beam ``seg`` is a pair of
        node indices. Every beam set is drawn in its own colour.
    """
    import matplotlib.pyplot as plt
    # Importing Axes3D registers the '3d' projection on older matplotlib
    # releases; the name itself is not used.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401

    fig = plt.figure()
    # One discrete colour per beam set (plus one spare, as before).
    cmap = plt.get_cmap('Dark2', len(beamsets) + 1)
    # Figure.gca(projection=...) was removed in matplotlib 3.6; add_subplot
    # with projection='3d' works on both old and new releases.
    ax = fig.add_subplot(111, projection='3d')
    for i, beamset in enumerate(beamsets):
        colour = cmap(i)
        for seg in beamset:
            ax.plot(nodes[seg, 0], nodes[seg, 1], nodes[seg, 2], c=colour)
    plt.show()


def get_rotation_from_nodes(nodes, axis, disps, targets):
    """Estimate how far target nodes turned about an axis from their displacements.

    Parameters
    ----------
    nodes : array
        Node coordinates; ``nodes[targets]`` must give 3-vectors.
    axis : pair of 3-vectors
        ``(base, tip)``: the axis passes through ``base`` and points towards
        ``tip``. ``tip - base`` need not be a unit vector.
    disps : array
        Per-node displacements; only the first three columns (translations)
        of the ``targets`` rows are used.
    targets : sequence of int
        Indices of the nodes to evaluate.

    Returns
    -------
    (angles, radii)
        ``radii`` holds each node's distance r from the axis. ``angles`` holds
        arctan(t / r) in radians, where t is the node's displacement along
        ``d x a`` (d = the node's perpendicular offset from the axis, a = the
        unit axis direction). ``d x a`` points opposite to a right-handed
        rotation about ``a``, so a right-handed rotation gives a negative angle.
        A node lying on the axis gets angle 0.
    """
    base = asarray(axis[0], dtype=float)
    a = asarray(axis[1], dtype=float) - base
    a = a / sqrt(dot(a, a))  # unit axis direction
    x = asarray(nodes, dtype=float)[targets] - base
    # Perpendicular offset of every node from the axis, and its length r.
    d = x - dot(x, a)[..., None] * a
    radii = sqrt((d * d).sum(axis=-1))
    # |d x a| == r because d is perpendicular to the unit vector a, so this
    # projection is y = r * t, with t the tangential displacement.
    y = (asarray(disps, dtype=float)[targets, :3] * cross(d, a)).sum(axis=-1)
    # arctan2(y, r**2) == arctan(t / r): an actual angle. The previous
    # arctan2(y, r) reduced to arctan(t), which has units of length. Using r**2
    # rather than dividing y by r also keeps on-axis nodes at arctan2(0, 0) == 0.
    return arctan2(y, radii ** 2), radii

def run_frame3dd(args, nodes, global_args, beam_sets, constraints, loads):
    """Write a Frame3DD input file for the given model and run the solver on it.

    The input file ``global_args['frame3dd_filename'] + '.csv'`` is written by
    ``write_frame3dd_file`` and then passed to the ``frame3dd`` executable with
    ``-i``; ``-q`` is added when ``args.quiet`` is set. The command line is
    echoed to stdout before it runs.

    Returns
    -------
    int
        The exit status of ``frame3dd``. A non-zero status is also reported on
        stderr, so a failed solve is not silently followed by reading output
        files left over from an earlier run.
    """
    import sys

    write_frame3dd_file(nodes, global_args, beam_sets, constraints, loads)
    cmd = ["frame3dd", "-i", global_args['frame3dd_filename'] + '.csv']
    if args.quiet:
        cmd.append("-q")
    command_line = ' '.join(cmd)
    # Single-argument print(...) prints the same text under Python 2 and 3.
    print(command_line)
    status = subprocess.call(cmd)
    if status != 0:
        sys.stderr.write("warning: %s exited with status %d\n" % (command_line, status))
    return status

def clean_up_frame3dd(filename):
    """Delete the files Frame3DD produced for the job named ``filename``.

    Removes ``filename`` followed by each of ``_out.csv``, ``.csv.out``,
    ``.csv.plt``, ``.csv.if01`` and ``.csv`` (the input file itself).

    Files that do not exist are skipped silently; on the first run none of
    them exist yet. Any other removal error is reported on stderr but does not
    stop the caller, so cleanup remains best-effort. This replaces shelling out
    to ``rm``, which is not portable and complained about the missing files.
    """
    import errno
    import os
    import sys

    for end in ["_out.csv", ".csv.out", ".csv.plt", ".csv.if01", ".csv"]:
        path = filename + end
        try:
            os.remove(path)
        except OSError as err:
            if err.errno != errno.ENOENT:
                sys.stderr.write("could not remove %s: %s\n" % (path, err))

def _flexure_paths(args, corners):
    """Return the flexure centre-lines of one plate as ``(corner_index, points)`` pairs.

    ``points`` are the successive vertices of one flexure in the z = 0 plane.
    They start next to plate corner ``corner_index`` (a row of ``corners``) and
    end at the flexure's free end. Flexures are listed in the node-numbering
    order used by ``build``: the two designed flexures first, then copies of
    them rotated 180 degrees about z.

    Raises ValueError when ``args.flexure_type`` is not 'cyclic' or 'mirrored'.
    """
    l = args.l
    if args.flexure_type == 'cyclic':
        if args.chamfer > 0:
            ch = args.chamfer
            # The 90-degree bend is replaced by a diagonal joining the points
            # ch before and ch after it.
            first = array([[0, l - ch, 0], [-ch, l, 0], [-l, l, 0]])
            second = array([[l - ch, 0, 0], [l, ch, 0], [l, l, 0]])
        else:
            first = l * array([[0, 1, 0], [-1, 1, 0]])
            second = l * array([[1, 0, 0], [1, 1, 0]])
        # Corner 3's flexure is corner 0's rotated by -90 degrees about z.
        designed = [(0, first + corners[0]), (3, second + corners[3])]
    elif args.flexure_type == 'mirrored':
        # Both flexures leave corner 0 as mirror images about the line x = y.
        # The chamfer option does not apply to this layout.
        first = l * array([[0, 1, 0], [-1, 1, 0]])
        second = l * array([[1, 0, 0], [1, -1, 0]])
        designed = [(0, first + corners[0]), (0, second + corners[0])]
    else:
        raise ValueError("unknown flexure_type %r (expected 'cyclic' or 'mirrored')"
                         % (args.flexure_type,))
    # Rotating by 180 degrees about z maps (x, y) -> (-x, -y) and maps plate
    # corner i onto corner (i + 2) % 4.
    flip = array([-1, -1, 1])
    return designed + [((corner + 2) % 4, points * flip) for corner, points in designed]


def build(args):
    """Build the node coordinates and beam connectivity of the flexure stage.

    Node numbering:
      * 0-3:   corners of the top plate (z = +sep/2), counter-clockwise from (+x, +y);
      * 4-7:   the same corners on the bottom plate (z = -sep/2);
      * 8-11:  sensor nodes at ``args.sensor_radius`` on the +x, +y, -x, -y axes (z = 0);
      * 12-:   flexure vertices, all of the top plate's first, then the bottom plate's.

    Returns
    -------
    nodes : array, shape (n_nodes, 3)
    flexure_beams : int array, shape (n, 2)
        Node-index pairs of the flexure segments (top plate first).
    solid_beams : int array, shape (40, 2)
        Node-index pairs of the frame: plate edges and diagonals, plate-to-plate
        struts, and the arms carrying the sensor nodes.
    fixed_nodes : list of int
        The free end of every flexure; run_simulation clamps these.

    Raises
    ------
    ValueError
        If ``args.flexure_type`` is not 'cyclic' or 'mirrored'.
    """
    dxy = args.attach_radius / sqrt(2)
    corners = array([[dxy, dxy, 0], [-dxy, dxy, 0], [-dxy, -dxy, 0], [dxy, -dxy, 0]])
    z_os = array([0, 0, .5 * args.sep])
    r = args.sensor_radius
    sensors = array([[r, 0, 0], [0, r, 0], [-r, 0, 0], [0, -r, 0]])
    nodes = vstack((corners + z_os, corners - z_os, sensors))

    plate_frame = array([[0, 1], [1, 2], [2, 3], [3, 0], [0, 2], [1, 3]])
    # Each top corner i is joined to bottom corners i, i+1 and i+3 (mod 4).
    struts = array([
        [0, 4], [1, 5], [2, 6], [3, 7],
        [0, 5], [1, 6], [2, 7], [3, 4],
        [0, 7], [1, 4], [2, 5], [3, 6]])
    sensor_arms = array([[0, 8], [3, 8], [0, 9], [1, 9], [1, 10], [2, 10], [2, 11], [3, 11]])
    solid_beams = vstack((plate_frame, plate_frame + 4, struts,
                          sensor_arms, sensor_arms + [4, 0]))

    paths = _flexure_paths(args, corners)
    plate_nodes = vstack([points for _, points in paths])
    n_plate = len(plate_nodes)
    first_flexure_node = len(nodes)

    # Chain each flexure from its plate corner through its vertices. The
    # bottom plate repeats the top plate's layout, shifted by 4 corner indices
    # and by one plate's worth of flexure nodes.
    flexure_beams = []
    fixed_nodes = []
    for plate_index in (0, 1):  # 0: top plate, 1: bottom plate
        node = first_flexure_node + plate_index * n_plate
        for corner, points in paths:
            prev = corner + 4 * plate_index
            for _ in range(len(points)):
                flexure_beams.append([prev, node])
                prev = node
                node += 1
            fixed_nodes.append(prev)  # free end of this flexure

    nodes = vstack((nodes, plate_nodes + z_os, plate_nodes - z_os))
    return nodes, array(flexure_beams), solid_beams, fixed_nodes

def run_simulation(args):
    """Run six Frame3DD load cases on the flexure built from ``args``.

    Every case writes and solves a Frame3DD model named
    ``args.base_filename + "_frame3dd"``. Leftover files from a previous run
    are deleted once, before the first case. The free ends of all flexures
    (``fixed_nodes`` from ``build``) are clamped in all six DOFs. Every load
    is shared equally by the eight plate corner nodes 0-7.

    Returns
    -------
    list of dict
        Six dicts, ordered by ``'dof'`` 0-5. Each has the keys ``'dof'``,
        ``'force/torque'`` (``args.force`` or ``args.torque``) and
        ``'displacement'``:

        * dof 0-2: a total force ``args.force`` along x / y / z; the value is
          the mean displacement of the loaded nodes along that axis.
        * dof 3-4: a couple of size ``args.torque``, made of opposite x (dof 3)
          or y (dof 4) forces on the two plates; the value is the mean
          absolute z displacement of the two sensor nodes lying on the
          loading axis.
        * dof 5: a torque ``args.torque`` about z, applied as tangential
          forces; the value is the average of ``magnitudes`` of the four
          sensor nodes' translations.
    """
    nodes, beams, solid_beams, fixed_nodes = build(args)
    global_args = {
        'n_modes': args.n_modes, 'length_scaling': args.length_scaling, 'exagerration': 10,
        'zoom_scale': 2., 'node_radius': zeros(shape(nodes)[0]),
        'frame3dd_filename': args.base_filename + "_frame3dd"
    }
    clean_up_frame3dd(global_args['frame3dd_filename'])
    beam_sets = [
        # Flexures: the user's material and rectangular section.
        (beams, {'E': args.E, 'nu': args.nu, 'rho': args.rho, 'cross_section': 'rectangular',
                 'd2': args.w, 'd1': args.t, 'roll': 0., 'loads': [],
                 'beam_divisions': args.bd, 'prestresses': []}),
        # Frame: a .003 x .003 section with 10x the Young's modulus --
        # presumably to make the plates act nearly rigidly (confirm).
        (solid_beams, {'E': 10 * args.E, 'nu': args.nu, 'rho': args.rho,
                       'cross_section': 'rectangular', 'd1': .003, 'd2': .003, 'roll': 0.,
                       'loads': [], 'beam_divisions': 1, 'prestresses': []}),
    ]
    # Clamp every DOF of every flexure end (same ordering as before: DOF-major).
    constraints = [{'node': node, 'DOF': dof, 'value': 0}
                   for dof in [0, 1, 2, 3, 4, 5] for node in fixed_nodes]

    loaded_nodes = list(range(8))  # the corner nodes of both plates
    sensor_nodes = [8, 9, 10, 11]
    n_loaded = len(loaded_nodes)

    def solve(loads):
        # Solve one load case and return the displacement table Frame3DD wrote.
        run_frame3dd(args, nodes, global_args, beam_sets, constraints, loads)
        return read_frame3dd_displacements(global_args['frame3dd_filename'])

    results = []

    # Translations: dof 0-2.
    for force_dof in [0, 1, 2]:
        loads = [{'node': n, 'DOF': force_dof, 'value': args.force / n_loaded}
                 for n in loaded_nodes]
        disps = solve(loads)
        results.append({'dof': force_dof, 'force/torque': args.force,
                        'displacement': average(disps[loaded_nodes, force_dof])})

    # Tilts: dof 3-4. Equal and opposite forces on the two plates, which lie
    # sep/2 above and below z = 0, form a couple, so F = torque / (sep / 2),
    # shared by the loaded nodes.
    torque_force = args.torque / (.5 * args.sep) / n_loaded
    for torque_dof in [0, 1]:
        loads = [{'node': n, 'DOF': torque_dof,
                  'value': torque_force if nodes[n][2] > 0 else -torque_force}
                 for n in loaded_nodes]
        disps = solve(loads)
        # NOTE(review): x-forces at +/- z (torque_dof 0) give a moment about y,
        # which tilts the sensors on the x axis (8, 10), yet the case is stored
        # as dof 3 and plotted as 'rX'. Confirm the intended axis labelling.
        moving_sensor_nodes = [8, 10] if torque_dof == 0 else [9, 11]
        results.append({'dof': torque_dof + 3, 'force/torque': args.torque,
                        'displacement': average(abs(disps[moving_sensor_nodes, 2]))})

    # Twist about z: dof 5. Tangential forces at radius attach_radius.
    torque_force = args.torque / args.attach_radius / n_loaded

    def node_to_force(n):
        # Unit vector in the xy plane perpendicular to node n's radius,
        # pointing counter-clockwise about +z.
        return array([-nodes[n, 1], nodes[n, 0], 0]) / sqrt(nodes[n, 0]**2 + nodes[n, 1]**2)

    loads = [{'node': n, 'DOF': i, 'value': torque_force * node_to_force(n)[i]}
             for n in loaded_nodes for i in [0, 1]]
    disps = solve(loads)
    results.append({'dof': 5, 'force/torque': args.torque,
                    'displacement': average(magnitudes(disps[sensor_nodes, :3]))})

    # TODO: plot displacements vs. design parameters
    return results

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-M', '--mode', choices=('simulate', 'graph', 'visualize'), required=True)
    parser.add_argument('-flexure_type', '--flexure_type', choices=('cyclic', 'mirrored'), required=True)
    parser.add_argument('-Q', '--quiet', action='store_true', help='Whether to suppress frame3dd output')
    parser.add_argument("-f", "--force", type=double, default=.1, help="force to apply (N)")
    parser.add_argument("-torque", "--torque", type=double, default=1., help="torque to apply (Nm)")

    parser.add_argument("-w", "--w", type=double, default=.0005, help="width of flexure (m)")
    parser.add_argument("-t", "--t", type=double, default=.0023, help="thickness of flexure material (m)")
    parser.add_argument("-l", "--l", type=double, default=.0068, help="length of flexure segment (m)")
    parser.add_argument("-attach_radius", "--attach_radius", type=double, default=.0043, help="distance from z axis to flexure attachment (m)")
    parser.add_argument("-sep", "--sep", type=double, default=.0185, help="flexure plate z separation (m)")
    parser.add_argument("-sensor_radius", "--sensor_radius", type=double, default=.012, help="distance from rotation axis to sensor (m)")
    parser.add_argument("-chamfer", "--chamfer", type=double, default=.001, help="chamfer length for flexure (m), zero for no chamfer")

    parser.add_argument("-bd", "--bd", type=int, default=1, help='how many divisions for each rod, useful in buckling analysis')
    parser.add_argument("-E", "--E", type=double, default=70e9, help="Young's Modulus of laminate")
    parser.add_argument("-nu", "--nu", type=double, default=.33, help="Poisson Ratio")
    parser.add_argument("-base_filename", "--base_filename", default='buckle', help="Base filename for segments and frame3dd")
    parser.add_argument("-rho", "--rho", type=double, default=2700., help='density of beam material, kg/m^3')
    parser.add_argument("-n_modes", "--n_modes", type=int, default=0, help='number of dynamic modes to compute')
    parser.add_argument("-ls", "--length_scaling", type=double, default=1., help="Scale factor to keep numbers commesurate")
    args = parser.parse_args()

    if args.mode == 'graph':
        # Sweep the chamfer length from 0 to 90% of the flexure segment length.
        # NOTE: build() ignores the chamfer for the 'mirrored' flexure type.
        ws = linspace(.000, .9 * args.l, 10)
        sweep = []
        for wi in ws:
            args.chamfer = wi
            sweep.append(run_simulation(args))

        # run_simulation returns one result per load case, ordered by 'dof' 0-5.
        labels = ['X', 'Y', 'Z', 'rX', 'rY', 'rZ']
        # Displacement in microns: one list (over the sweep) per load case.
        curves = [[1e6 * results[k]['displacement'] for results in sweep]
                  for k in range(len(labels))]
        # Prints the same line as the old Python 2 "print ws,X,Y,Z".
        print('%s %s %s %s' % (ws, curves[0], curves[1], curves[2]))
        for label, curve in zip(labels, curves):
            plt.plot(1e3 * ws, curve, label=label)
        plt.ylabel('displacement at sensor (microns)')
        plt.xlabel('chamfer width (mm)')
        plt.xlim([1e3 * ws[0], 1e3 * ws[-1]])
        plt.legend(loc='upper right')
        plt.show()
    elif args.mode == 'simulate':
        res = run_simulation(args)
        print(res)
    elif args.mode == 'visualize':
        nodes, rods, solid_beams, fixed_nodes = build(args)
        plot_connections(nodes, [rods, solid_beams])
    else:
        # argparse's `choices` should make this unreachable. Unlike assert,
        # parser.error is not stripped under `python -O`.
        parser.error('unsupported mode %r' % (args.mode,))



# NOTE(review): The triple-quoted string below is disabled legacy code, kept
# for reference only. It is a bare module-level string, so it never runs.
# It bisects on args.force to find the load at which the fundamental
# frequency turns negative (the buckling load). It would not work as-is:
#   * it reads res['fundamental_frequency'], but run_simulation() returns a
#     list of per-DOF result dicts;
#   * it uses args.force_res, which the argument parser does not define.
'''
def find_stability_threshold(args):
	#out loop of simulations to determine the buckling load
	lower = 0 #lower bound
	upper = 10*args.force_res #initial upper bound before bracketing
	bracketed=False
	#actually not necessary, but fun to have the unloaded frequency
	args.force = lower
	res = run_simulation(args)	
	freqs = [res['fundamental_frequency']]
	forces = [args.force]

	i = 0
	while not bracketed:
		print lower,upper,bracketed,res['fundamental_frequency']
		args.force = upper
		res = run_simulation(args); i += 1
		if res['fundamental_frequency']<0:
			bracketed=True
		else:
			freqs.append(res['fundamental_frequency'])
			forces.append(args.force)
			lower = upper
			upper = 2*upper
	while (upper-lower > args.force_res):
		print lower,upper,bracketed
		args.force = .5*(upper+lower)
		res = run_simulation(args); i += 1
		if res['fundamental_frequency']>0:
			freqs.append(res['fundamental_frequency'])
			forces.append(args.force)
			lower = .5*(upper+lower)
		else:
			upper = .5*(upper+lower)
	return forces,freqs,res
'''
