#!/usr/bin/perl

#
# Program finds the posterior sample alignment that has the minimum average distance to all other samples.
# Calculating the pairwise distances between alignments is O(N^2) where N is the number of samples.  Substantial
# thinning is advised.
#
# Command-line: pairwise-alignment-distances [size of thinned posterior] < [posterior sample file]
#
# By: Marc A. Suchard (2/26/06)
#

# Target size of the thinned posterior.  The first command-line argument is
# used as given unless it is missing or numerically zero, in which case the
# default of 100 applies.
$thinto = ($ARGV[0] == 0) ? 100 : $ARGV[0]; # Must be an integer

print STDERR "Reading posterior from <STDIN>.\n";
# Parse posterior samples.  Save all thinned samples in an array
@samplearray = ();      # parsed alignments (hash refs returned by parse_fasta)
@samplepairsarray = (); # pairwise homology statements, one array ref per sample
@iterarray = ();        # iteration number recorded for each kept sample
@samplealignment = ();  # raw alignment text of each kept sample
{
    # Slurp all of STDIN in a single read.  $bigstring is lexical to this
    # bare block, so only @chunks outlives it.
    local $/ = undef;
    my $bigstring = <STDIN>;
    $bigstring = '' unless defined $bigstring;
    # Matching in list context yields the captures directly and leaves the
    # target undefined on failure, instead of silently reusing a stale $1.
    ($start) = $bigstring =~ /iterations = (\d+)/s;   # first record
    ($end)   = $bigstring =~ /.*iterations = (\d+)/s; # greedy .* => last record
    # Records are separated by a blank line followed by a line starting "it".
    @chunks = split("\n\n(?=it)",$bigstring);
}
unless( defined $start && defined $end ) {
    die "No 'iterations = <number>' records were found on <STDIN>.\n";
}
print STDERR "Successfully read in <STDIN>.\n";
# NOTE(review): the divisor of 10 presumably reflects the sampler writing one
# record every 10 iterations -- confirm against the program that produced the
# posterior file.
$approxsamples = int (($end - $start)/10);
print STDERR "There are approximately $approxsamples alignments.\n";
# Keep every $thin-th alignment.  When fewer alignments are available than
# were requested the integer division gives 0; clamp to 1 (keep everything)
# so that neither the division below nor the later modulus divides by zero.
$thin = int ($approxsamples / $thinto);
$thin = 1 if $thin < 1;
# Round the interval up if the truncated one would still leave too many samples.
$predict = $approxsamples / $thin;
if( $predict > $thinto ) {
    $thin++;
}
print STDERR "Subsampling every $thin alignments to generate $thinto thinned samples.\n";
#exit(0);
$pick = 0; # number of alignment-bearing chunks seen so far

print STDERR "Creating all pairwise homologies statements...\n";
foreach my $chunk (@chunks) { # Iterate over samples
    # Chunks without an alignment record take no part in the thinning count.
    next unless $chunk =~ /align\[sample/;

    # Keep every $thin-th alignment-bearing chunk, starting with the first.
    my $keep = ( ($pick % $thin) == 0 );
    $pick++;
    next unless $keep;

    # Captures are taken in list context so that a failed match yields undef
    # rather than whatever an earlier successful match left in $1.
    my ($iter) = $chunk =~ /iterations = (\d+)/;
    #$chunk =~ /align.*? =.?\n(.*)\n\n\n/s; # returns whole alignment (including internals)
    my ($alignment) = $chunk =~ /align.*? =.?\n(.*?)>A\d+\s+?\n/s;  # returns only observed alignment
    unless( defined $alignment ) {
        # Previously the iteration number left in $1 would have been parsed
        # as if it were the alignment text.
        my $where = defined $iter ? " at iteration $iter" : "";
        print STDERR "Warning: could not extract the observed alignment$where; sample skipped.\n";
        next;
    }

    my $sample = parse_fasta($alignment);
    my $samplepairs = pairwise_homologies( $sample->{"array"} );
    # The four arrays are kept in step: index n describes the n-th kept sample.
    push(@samplealignment, $alignment);
    push(@samplearray, $sample);
    push(@samplepairsarray, $samplepairs);
    push(@iterarray, $iter);
}
print STDERR "Done.\n";

# Retrieve number of taxa and lengths

$nsamples = scalar(@iterarray);
if( $nsamples == 0 ) {
    # Without this check the script would go on to report a meaningless
    # "minimum" for an empty set of samples.
    die "No alignment samples were extracted from the posterior; nothing to compare.\n";
}
# Sequence lengths and the number of sequence pairs are read from the first
# kept sample only.
# NOTE(review): presumably every sample aligns the same sequences -- confirm.
@all_lengths = @{$samplearray[0]->{"lengths"}};
$ntaxa  = scalar(@all_lengths);
$npairs = scalar(@{$samplepairsarray[0]});
#print "Ntaxa = $ntaxa\nNpairs = $npairs\nNSamples = $nsamples\n";
print STDERR "Processing distance matrix for $nsamples samples...\n";
# distance() omits the "+ length_x + length_y" term of each pair's true
# distance.  Every sequence belongs to ntaxa-1 pairs, so summed over all pairs
# that omitted term equals (ntaxa - 1) * (sum of all sequence lengths).
$standard = 0;
$standard += $_ for @all_lengths;
$standard *= $ntaxa - 1;
if( $standard <= 0 ) {
    # The distance matrix divides by $standard; fail here with a clear message
    # instead of an "Illegal division by zero" later on.
    die "Need at least two non-empty sequences per alignment to compute distances.\n";
}


# Generate distance matrix
#
# Entry [i][j] is the normalised distance between kept samples i and j: the
# per-pair distance() values summed over all sequence pairs, plus $standard
# (the length term distance() leaves out), divided by $standard.

for(my $i=0; $i<$nsamples; $i++) {
    # A sample is at distance zero from itself.  Setting the diagonal
    # explicitly keeps every row the same length rather than leaving
    # undefined holes that only happen to format as 0.
    $distance[$i][$i] = 0;
    for(my $j=$i+1; $j<$nsamples; $j++) {
    #for(my $j=0; $j<$nsamples; $j++) {
        # Calculate distance between sample i and j
        my $sum = $standard;
        for my $k (0 .. $npairs - 1) {
            $sum += distance( $samplepairsarray[$i]->[$k], $samplepairsarray[$j]->[$k]);
        }
        # Only the upper triangle is computed; mirror it into the lower one.
        $distance[$i][$j] = $distance[$j][$i] = ($sum / $standard);
        #$distance[$i][$j] = $sum; # used to check for symmetry of distance measure
    }
}
print_matrix( \@distance );

# Find sample with min average distance to all others
@rowtotal = ();
$minsum = 0;
$minindex = 0;
# Each row is averaged over the *other* samples.  With a single sample there
# are no others, so divide by 1 instead of 0 to avoid a fatal division by zero.
my $others = ($nsamples > 1) ? $nsamples - 1 : 1;
print "Average distances:";
for my $i (0 .. $nsamples - 1) {
    my $total = 0;
    for my $j (0 .. $nsamples - 1) {
        $total += $distance[$i][$j];
    }
    my $average = $total / $others;
    $rowtotal[$i] = $average;
    # The first row seeds the minimum; later rows replace it only when
    # strictly smaller, so ties keep the earliest sample.
    if( $i == 0 || $average < $minsum ) {
        $minsum = $average;
        $minindex = $i;
    }
    printf " %0.4f", $average;
}
print "\n";
$reportindex = $minindex + 1; # reported 1-based
printf "Minimum average distance = %0.4f (%s)\n", $minsum, $reportindex;
print "\n\n$samplealignment[$minindex]\n";

#
# Subroutine prints a square matrix of floating-points
#
# Argument: reference to an array of row array-references.
# Output:   a "Matrix: R-by-C" header line, then one line per row with every
#           entry formatted as %0.4f and followed by a single space.
# The column count is taken from the first row; missing or undefined entries
# are printed as 0.0000.  Nothing outside the subroutine is modified.
#
sub print_matrix {
    my $x = shift;
    my $ylen = scalar(@$x);
    # Guard the first-row lookup so that an empty matrix is simply 0-by-0.
    my $xlen = $ylen ? scalar(@{ $x->[0] || [] }) : 0;
    print "Matrix: $ylen-by-$xlen\n";
    for my $i (0 .. $ylen - 1) {
        my $row = $x->[$i] || [];
        for my $j (0 .. $xlen - 1) {
            my $value = $row->[$j];
            # Formatted straight to output: no package-level scratch variable
            # is written, so the caller's globals are left alone.
            printf "%0.4f ", defined $value ? $value : 0;
        }
        print "\n";
    }
}


     
#
# Subroutine walks two aligned rows column by column and returns a reference
# to a hash of homology statements.
#
# Each row is an array reference in which 0 marks a gap and a positive number
# is the position of the residue within its own sequence.  The result may hold
# up to three inner hashes (a type with no events has no key at all):
#   'I' => { position in second row => 1 }   gap in first row only
#   'D' => { position in first row  => 1 }   gap in second row only
#   'M' => { "first:second"         => 1 }   residue in both rows
# Columns that are gaps in both rows contribute nothing.
#

sub get_homologies {
    my ($x, $y) = @_;
    my %events;
    foreach my $col (0 .. $#{$x}) {
        my ($top, $bottom) = ($x->[$col], $y->[$col]);

        if( $top == 0 ) {
            # Gap in the first row: an insertion when the second row has a
            # residue here, otherwise a gap-gap column that is ignored.
            $events{'I'}{$bottom} = 1 if $bottom > 0;
            next;
        }

        if( $bottom == 0 ) {
            # Residue in the first row against a gap in the second: deletion.
            $events{'D'}{$top} = 1;
            next;
        }

        # Residues in both rows: match.
        $events{'M'}{"$top:$bottom"} = 1;
    }
    return \%events;
}

#
# Subroutine to return distance between two homology hashes
# NB: true-distance = distance + length_x + length_y
#
# The value returned is never positive: each shared match statement counts
# -2, and each shared deletion or insertion statement counts -1.
#

sub distance {
    my ($x, $y) = @_;
    # Statement type paired with the weight of one shared statement.
    my @terms = ( ['M', 2], ['D', 1], ['I', 1] );
    my $score = 0;
    foreach my $term (@terms) {
        my ($type, $weight) = @{$term};
        $score -= $weight * card_intersection( $x->{$type}, $y->{$type} );
    }
    return $score;
}
					 
#
# Subroutine to return the cardinality of the intersection of two hashes
#
# Both arguments are hash references.  A key of the first hash is counted
# when the second hash holds a true value under the same key.
#

sub card_intersection {
    my ($x, $y) = @_;
    my @shared = grep { $y->{$_} } keys %{$x};
    return scalar(@shared);
}

#
# Subroutine to return the cardinality of a hash
#
# Takes a hash reference and returns its number of keys.
#

sub cardinality {
    my $x = shift;
    my @members = keys %{$x};
    return scalar(@members);
}

#
# Subroutine parses a string containing an entire fasta alignment
#
# Argument: the alignment text, one record per '>' header.
# Returns:  a reference to a hash with three parallel array references:
#   "names"   - the leading word characters of each header line
#   "array"   - one row per sequence; 0 for a '-' column, otherwise the
#               1-based position of that residue within its own sequence
#   "lengths" - number of non-gap characters in each sequence
# Only the first line after each header is read as sequence data, so a
# sequence wrapped over several lines is cut off at its first line break.
# Text that does not form a header-plus-line record (for example anything in
# front of the first '>') is skipped.
#

sub parse_fasta {
    my ($x) = shift;
    my @names  = ();
    my @seq    = ();
    # Must be lexical: as a package global every result returned by this
    # subroutine shared one lengths array, and each call overwrote the
    # lengths reported by earlier calls.
    my @length = ();
    foreach my $block (split("(?=>)",$x)) {
        # Skip text that is not a record instead of reading stale capture
        # variables left over from an earlier match.
        next unless $block =~ />(\w+).*?\n(.*)\n?/m;
        my ($name, $residues) = ($1, $2);
        my @row = ();
        my $index = 1;
        foreach my $char (split("",$residues)) {
            push(@row, ($char eq '-') ? 0 : $index++);
        }
        push(@names, $name);
        push(@seq, \@row);
        push(@length, $index - 1);
    }
    my %alignment = ();
    $alignment{"names"} = \@names;
    $alignment{"array"} = \@seq;
    $alignment{"lengths"} = \@length;
    return \%alignment;
}

#
# Subroutine prints out a set of homology statements
#
# Takes a reference to a hash keyed by statement type and prints one line per
# type in the fixed order M, I, D: the type, a colon, then every key of that
# type's inner hash, each followed by a space.
#

sub print_homologies {
    my $x = shift;
    # Work on a shallow copy: dereferencing a type that is absent then adds an
    # entry to the copy only.
    my %by_type = %{$x};
    foreach my $label (qw(M I D)) {
        my @statements = keys %{ $by_type{$label} };
        print "$label: " . join('', map { "$_ " } @statements) . "\n";
    }
}

#
# Subroutine returns an array of all pairwise homology statements in an alignment
#
# Takes a reference to an array of aligned rows and returns a reference to an
# array holding one homology hash per unordered pair of rows, in the order
# (0,1), (0,2), ..., (1,2), ...
#

sub pairwise_homologies {
    my $x = shift;
    my @rows = @{$x};
    my @pairs = ();
    foreach my $first (0 .. $#rows) {
        foreach my $second ($first + 1 .. $#rows) {
            push(@pairs, get_homologies( \@{$rows[$first]}, \@{$rows[$second]} ));
        }
    }
    return \@pairs;
}
