package amadeus;

import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Vector;

import edu.tau.compbio.math.Stat;

/**
 * Compares two position weight matrices (PWMs) by sliding one along the other
 * and reporting the best average per-column score over all tested offsets.
 * Can be used as a library (see {@code alignMatrices}) or from the command line
 * (see {@code main}).
 */
public class ComparePwms {

	/**
	 * The column-comparison measure used when aligning two matrices.
	 */
	public enum ComparisonPolicy
	{
		/** Scored with Stat.pearsonCorrCoeff; the HIGHEST average wins. */
		PEARSON_SIM,
		/** Scored with Stat.euclidianDist (then divided by Stat.SQRT2); the LOWEST average wins. */
		EUCLIDIAN_DIST,
		/** Scored with Stat.klDivergence; the LOWEST average wins. */
		KL_DIVERGENCE
	}
	
	private final static String USAGE = 
		"Usage:java ComparePwms  <policy P=PEARSON E=EUCLIDIAN K=KL divergence> <Pwm1File> <Pwm2File> <minAlignLen>";
	private final static String HELP = "-h";
	/** Number of rows expected in a PWM file (one per alphabet symbol). */
	private final static int ALPHABET = 4;
	// NOTE(review): PREFIX is not referenced anywhere in this class - confirm whether it can be dropped.
	private final static int PREFIX = 3;
	
	/**
	 * Result of aligning two matrices: the best score found, the number of
	 * columns it was computed on, and a strand flag maintained by the caller.
	 */
	public static class MotifMatrixAlignmentScore {

		double   corrScore;
		int      alignLength;
		boolean  revCompFlag;
		
		/**
		 * CTOR.
		 * @param score   - the alignment score (a correlation between -1 and 1 for PEARSON_SIM;
		 *                  a distance/divergence for the other policies)
		 * @param length  - the alignment length (on which the score was computed)
		 * @param revComp - true if the score was computed when one of the matrices was reverse-complemented
		 */
		public MotifMatrixAlignmentScore (double score, int length, boolean revComp)
		{
			corrScore = score;
			alignLength = length;
			revCompFlag = revComp;
		}
		
		/**
		 * Overrides the reverse-complement flag (alignMatrices always reports false).
		 * @param revComp - the new flag value
		 */
		protected void setRevCompFlag (boolean revComp)
		{
			revCompFlag = revComp;
		}
		
		/**
		 * @return the alignment's score, exactly as passed to the constructor
		 */
		public double getScore()  
		{ 
			return (corrScore); 
		}
		
		/**
		 * @return the alignment's length
		 */
		public int getLength()
		{
			return (alignLength);
		}
		
		/**
		 * @return true if the score was computed when one of the matrices was reverse-complemented,
		 * i.e., on the opposite strand
		 */
		public boolean getRevCompFlag()
		{
			return (revCompFlag);
		}
	}
	
	/**
	 * The algorithm implementation: tries every un-gapped offset of one matrix against
	 * the other whose overlap is at least the effective minimum length, averages the
	 * per-column score over the overlap and keeps the best average.
	 * <p>
	 * The effective minimum overlap is min(mat1.length-2, mat2.length-2, minLen),
	 * but never less than 1, so that degenerate inputs (very short matrices or a
	 * non-positive minLen) cannot produce empty or negative-length "alignments".
	 * <p>
	 * Note: the revComp flag in the returned MotifMatrixAlignmentScore object is always false,
	 * it's up to the caller to change it if needed.
	 *
	 * @param policy       - the column comparison measure
	 * @param mat1         - first matrix, indexed [position][symbol]
	 * @param mat2         - second matrix, indexed [position][symbol]
	 * @param minLen       - the requested minimal overlap (in columns)
	 * @param averageGiven - if true, average1/average2 are forwarded to Stat.pearsonCorrCoeff
	 *                       (only relevant for PEARSON_SIM)
	 * @param average1     - forwarded to Stat.pearsonCorrCoeff for columns of mat1 when averageGiven is true
	 * @param average2     - forwarded to Stat.pearsonCorrCoeff for columns of mat2 when averageGiven is true
	 * @return the best score and the overlap length it was computed on; if no offset could be
	 *         scored (e.g. an empty matrix) the score is the initial sentinel (-1.0 for
	 *         PEARSON_SIM, Double.MAX_VALUE otherwise) and the length is 0
	 */
	public static MotifMatrixAlignmentScore alignMatrices (ComparisonPolicy policy, double[][] mat1, double[][] mat2, 
			int minLen, boolean averageGiven, double average1, double average2)
	{
		// Clamp to 1: a non-positive minimum would let the offset loops run past the end of a
		// matrix, producing overlaps of length <= 0 whose "score" (0.0 / length) is -0.0 or NaN.
		int minAlgnLen = Math.max(1, Stat.MIN3(mat1.length-2, mat2.length-2, minLen));

		boolean higherIsBetter = (policy == ComparisonPolicy.PEARSON_SIM);
		double bestScore = higherIsBetter ? -1.0 : Double.MAX_VALUE;
		int alignmentLength = 0;

		int maxStart1Idx = mat1.length - minAlgnLen;
		int maxStart2Idx = mat2.length - minAlgnLen;

		// Pass 1: mat1 starts at its first column, the start within mat2 slides.
		for (int startIdx2 = 0; startIdx2 <= maxStart2Idx; startIdx2++)
		{
			int currAlgnLength = Math.min(mat1.length, mat2.length - startIdx2);
			if (currAlgnLength <= 0)
			{
				continue; // nothing overlaps (only possible when a matrix is empty)
			}
			double currAlgnScore = scoreAlignment(policy, mat1, 0, mat2, startIdx2, currAlgnLength,
					averageGiven, average1, average2);
			if (isBetter(higherIsBetter, currAlgnScore, bestScore))
			{
				bestScore = currAlgnScore;
				alignmentLength = currAlgnLength;
			}
		}
		
		// Pass 2: mat2 starts at its first column, the start within mat1 slides.
		// Offset 0 was already covered by pass 1, hence the loop starts at 1.
		for (int startIdx1 = 1; startIdx1 <= maxStart1Idx; startIdx1++)
		{
			int currAlgnLength = Math.min(mat2.length, mat1.length - startIdx1);
			if (currAlgnLength <= 0)
			{
				continue;
			}
			double currAlgnScore = scoreAlignment(policy, mat1, startIdx1, mat2, 0, currAlgnLength,
					averageGiven, average1, average2);
			if (isBetter(higherIsBetter, currAlgnScore, bestScore))
			{
				bestScore = currAlgnScore;
				alignmentLength = currAlgnLength;
			}
		}
		return (new MotifMatrixAlignmentScore(bestScore, alignmentLength, false));
	}
	
	/**
	 * Computes the average per-column score of one un-gapped alignment:
	 * mat1[start1 + k] is compared with mat2[start2 + k] for k in [0, length).
	 * For EUCLIDIAN_DIST the average is additionally divided by Stat.SQRT2
	 * (presumably to normalize the distance between two frequency columns - TODO confirm).
	 */
	private static double scoreAlignment (ComparisonPolicy policy, double[][] mat1, int start1,
			double[][] mat2, int start2, int length,
			boolean averageGiven, double average1, double average2)
	{
		double sum = 0.0;
		for (int k = 0; k < length; k++)
		{
			double[] col1 = mat1[start1 + k];
			double[] col2 = mat2[start2 + k];
			// The mat1 column is always the first argument (the measures need not be symmetric).
			switch (policy) {
			case PEARSON_SIM:
				if (averageGiven)
				{
					sum += Stat.pearsonCorrCoeff(col1, col2, average1, average2);
				}
				else
				{
					sum += Stat.pearsonCorrCoeff(col1, col2);
				}
				break;
			case EUCLIDIAN_DIST:
				sum += Stat.euclidianDist(col1, col2);
				break;
			case KL_DIVERGENCE:
				sum += Stat.klDivergence(col1, col2);
				break;
			}
		}
		
		double average = sum / length;
		if (policy == ComparisonPolicy.EUCLIDIAN_DIST)
		{
			average /= Stat.SQRT2;
		}
		return average;
	}
	
	/**
	 * Tells whether a candidate score strictly improves on the best one so far.
	 * Strictness means that on ties the earliest tested offset is kept, and that
	 * a NaN candidate is never accepted.
	 */
	private static boolean isBetter (boolean higherIsBetter, double candidate, double best)
	{
		return higherIsBetter ? (candidate > best) : (candidate < best);
	}
	
	/**
	 * Command-line entry point. Reads two PWM files, aligns them with the requested
	 * policy and prints the best score to standard output. On invalid arguments or
	 * unreadable/malformed input an error is printed to standard error and nothing is scored.
	 *
	 * @param args - policy (P, E or K), first PWM file, second PWM file, minimal alignment length
	 */
	public static void main(String[] args) {
	
		// checking input
		if (args.length < 4 || args[0].equals(HELP)) {
			System.err.println(USAGE);
			return;
		}

		String compPolicyStr = args[0];
		String PwmFile1 = args[1];
		String PwmFile2 = args[2];
		
		int minAlignLen;
		try
		{
			minAlignLen = Integer.parseInt(args[3]);
		}
		catch (NumberFormatException e)
		{
			System.err.println("ERROR: minAlignLen must be an integer, got " + args[3]);
			System.err.println(USAGE);
			return;
		}
	
		ComparisonPolicy compPolicy;
		if (compPolicyStr.equals("P"))
		{
			compPolicy = ComparisonPolicy.PEARSON_SIM;
		}
		else if (compPolicyStr.equals("K"))
		{
			compPolicy = ComparisonPolicy.KL_DIVERGENCE;
		}
		else if (compPolicyStr.equals("E"))
		{
			compPolicy = ComparisonPolicy.EUCLIDIAN_DIST;
		}
		else
		{
			System.err.println("ERROR: Worng comparison policy " + compPolicyStr);
			return;
		}
		
		Vector<Vector<Double > > pwm1 = new Vector<Vector<Double > >(ALPHABET);
		Vector<Vector<Double > > pwm2 = new Vector<Vector<Double > >(ALPHABET);
		try
		{
			readPwmFile(pwm1, PwmFile1);
			readPwmFile(pwm2, PwmFile2);
		}
		catch (IOException e)
		{
			// Do not go on with a partially read matrix: the score would be meaningless.
			System.err.println("ERROR: Failed to read PWM file: " + e.getMessage());
			return;
		}

		double score = comparePwms(compPolicy, minAlignLen, pwm1, pwm2);
		System.out.println(score);
	}
	
	/**
	 * Reads a PWM file into the given (expected to be empty) vector, one inner vector per line.
	 * The file must contain exactly ALPHABET tab-separated lines; the first field of every
	 * line is skipped (presumably a row label - TODO confirm against the file format) and the
	 * remaining fields are parsed as doubles. All lines must hold the same, non-zero, number
	 * of values.
	 *
	 * @param pwm      - receives the parsed rows, in file order
	 * @param filename - path of the file to read
	 * @throws IOException if the file cannot be read, has too few or too many lines,
	 *                     rows of different lengths, no values, or a non-numeric value
	 */
	private static void readPwmFile(Vector<Vector<Double > > pwm, String filename) throws IOException
	{
		System.out.println("Read file " + filename);
		File file = new File(filename);
		
		// ISO-8859-1 maps every byte to one char, like the DataInputStream.readLine() used before.
		BufferedReader reader = new BufferedReader(
				new InputStreamReader(new FileInputStream(file), "ISO-8859-1"));
		try {
			int expectedSize = -1;
			for (int i = 0; i < ALPHABET; i++)
			{
				String line = reader.readLine();
				if (line == null)
				{
					throw(new IOException("Expected " + ALPHABET + " lines but found only " + i 
							+ " in " + filename));
				}
				
				String[] freqs = line.split("\t");
				Vector<Double> freqs_vec = new Vector<Double>();
				for (int j = 1; j < freqs.length; j++) {
					try
					{
						freqs_vec.add(Double.parseDouble(freqs[j]));
					}
					catch (NumberFormatException e)
					{
						IOException ioe = new IOException("Invalid frequency '" + freqs[j] 
								+ "' on line " + (i + 1) + " of " + filename);
						ioe.initCause(e);
						throw ioe;
					}
				}
				
				if (i > 0 && freqs_vec.size() != expectedSize)
				{
					throw(new IOException("Different number of frequencies"));
				}
				expectedSize = freqs_vec.size();
				pwm.add(freqs_vec);
			}
			
			if (expectedSize == 0)
			{
				throw(new IOException("No frequencies found in " + filename));
			}
			
			// Any byte left after the expected lines means the file is longer than allowed.
			if (reader.read() != -1) {
				throw(new IOException("More than 4 lines"));
			}
		} finally {
			// Closing the reader also closes the wrapped file stream, on success and on failure.
			reader.close();
		}
	}
	
	/**
	 * Transposes a PWM read by readPwmFile (indexed [symbol][position]) into the
	 * [position][symbol] array layout that alignMatrices consumes.
	 *
	 * @param pwm - ALPHABET rows; the number of positions is taken from the first row
	 * @return a new array of size [positions][ALPHABET]
	 */
	private static double[][] convertToDouble(Vector<Vector<Double > > pwm) {
		int positions = pwm.get(0).size();
		double[][] rc = new double[positions][ALPHABET];
		for (int i = 0; i < ALPHABET; i++) {
			Vector<Double> row = pwm.get(i);
			for (int j = 0; j < row.size(); j++) {
				rc[j][i] = row.get(j);
			}
		}
		return rc;
	}
	
	/**
	 * Aligns two PWMs given in file layout and returns the best alignment score.
	 * No pre-computed averages are supplied to the Pearson computation.
	 *
	 * @param compPolicy - the column comparison measure
	 * @param minLen     - the requested minimal overlap (in columns)
	 * @param pwm1       - first PWM, indexed [symbol][position]
	 * @param pwm2       - second PWM, indexed [symbol][position]
	 * @return the score of the best alignment found by alignMatrices
	 */
	private static double comparePwms(ComparisonPolicy compPolicy, int minLen,
			Vector<Vector<Double > > pwm1, Vector<Vector<Double > > pwm2) 
	{
		double[][] doublePwm1 = convertToDouble(pwm1);
		double[][] doublePwm2 = convertToDouble(pwm2);
		MotifMatrixAlignmentScore score = alignMatrices(compPolicy,
				doublePwm1, doublePwm2, minLen, false, 0.0, 0.0);
		
		return score.getScore();
	}
}

