/*

File: "decipher.c"

A program which uses MCMC to decode text that has been encoded using a
simple substitution cipher (e.g. from the companion program scramble.c).

Copyright (c) 2008 by Jeffrey S. Rosenthal (http://probability.ca/jeff/)

Licensed for general copying, distribution and modification according to
the GNU General Public License (http://www.gnu.org/copyleft/gpl.html).

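Compile with:  "cc decipher.c -o decipher -lm"
(-lm goes after the source file, since many linkers resolve libraries in
command-line order.)
Run with:  "decipher trainingfile testfile"
The decoded text is written to a new file named testfile with "out"
appended (e.g. "cipher.txt" -> "cipher.txtout").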

*/

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/time.h>
#include <math.h>

/* Alphabet size.  Letter indices run 0..NUMLETT-1 (case-insensitive, see
   thenum()); the extra index NUMLETT is the single slot shared by every
   non-alphabetic character, which is why the tables below are
   (NUMLETT+1) x (NUMLETT+1). */
#define NUMLETT 26
/* Number of Metropolis iterations run by main(). */
#define NUMITS 4000
/* Exponent applied to each training pair count in the score.  Because
   thelogscore() multiplies the log-score by PIPOWER, values below 1 shrink
   score differences, so main() accepts worse proposals more often. */
#define PIPOWER 0.25

/* Uncomment to enable debug output: VERBOSE prints the pair-count tables
   once; VVERBOSE traces every proposed swap and its accept/reject outcome. */
/*
#define VERBOSE true
#define VVERBOSE true
*/

/* main() prints the current decoding every PRINTEVERY iterations (and after
   the last one), showing the first PRINTAMOUNT characters of the test file. */
#define PRINTEVERY 100
#define PRINTAMOUNT 1350

/* traincounts[i][j] / testcounts[i][j]: one plus the number of times letter
   index i is immediately followed by letter index j in the training / test
   file (filled by getcounts()). */
int traincounts[NUMLETT+1][NUMLETT+1];
int testcounts[NUMLETT+1][NUMLETT+1];
/* Natural log of traincounts, precomputed once in main() so that
   thelogscore() does not call log() on every evaluation. */
double logtraincounts[NUMLETT+1][NUMLETT+1];

main(argc, argv)
int argc;
char *argv[];
{
    int t, i, j, a, b, itnum, newrand(), newnum;
    void getcounts();
    int f[NUMLETT+1], newf[NUMLETT+1], displaynum[PRINTAMOUNT];
    char c, displaydata[PRINTAMOUNT], outname[160];
    double thelogscore(), prevlogscore, newlogscore;
    FILE *fp, *fpout;

    /* Begin reading file. */
    if (argc < 3) {
	fprintf(stderr, "Too few arguments.\n");
	exit(1);
    }

    /* Compute the actual counts. */
    getcounts(argv[1], traincounts);
    getcounts(argv[2], testcounts);
    printf("Computing logarithms of pair counts ... ");
    for (i=0; i<=NUMLETT; i++)
        for (j=0; j<=NUMLETT; j++)
	    logtraincounts[i][j] = log(traincounts[i][j]);
    printf("done.\n\n");

#ifdef VERBOSE
    /* Output the counts. */
    for (i=0; i<=NUMLETT; i++) {
        for (j=0; j<=NUMLETT; j++) {
	    printf("%d[%f](%d) ", traincounts[i][j], logtraincounts[i][j],
							testcounts[i][j]);
	}
	printf("\n");
    }
    printf("\n");
#endif

    /* Read in data for later displaying. */
    if ((fp = fopen(argv[2], "r"))) {
    } else {
	fprintf(stderr, "Cannot open file %s.\n", argv[2]);
	exit(1);
    }
    for (i=0; i<PRINTAMOUNT; i++) {
        t = getc(fp);
        if (t==EOF) {
	    displaydata[i] = ' ';
        } else {
	    displaydata[i] = t;
        }
	displaynum[i] = thenum(displaydata[i]);
    }
    fclose(fp);

    /* Display initial data. */
    printf("-----------------------------\n");
    printf("INITIAL DATA:\n\n");
    for (i=0; i<PRINTAMOUNT; i++)
	printf("%c", displaydata[i]);
    printf("\n\n");

    /* Initialise decoding function f. */
    for (i=0; i<=NUMLETT; i++)
	f[i] = i;
    prevlogscore = thelogscore(f);

    /* Run iterative algorithm. */
    seedrand();
    for (itnum=1; itnum<=NUMITS; itnum++) {

	prevlogscore = thelogscore(f); /* Unnecessary ... */

	/* Create function newf by swapping. */
	a = newrand();
	b = newrand();
	for (i=0; i<=NUMLETT; i++) {
	    if (i==a)
	      newf[i] = f[b];
	    else if (i==b)
	      newf[i] = f[a];
	    else
	      newf[i] = f[i];
	}

	/* Consider whether to accept new function. */
#ifdef VVERBOSE
for (i=0; i<=NUMLETT; i++)
    printf("f[%d]=%d ", i, f[i]);
printf("\nSwapping %d and %d ... ", a, b);
#endif
	newlogscore = thelogscore(newf);
	if ( log(drand48()) < newlogscore - prevlogscore ) {
	    /* Accept the new function. */
#ifdef VVERBOSE
printf("accepted; news=%f, prevs=%f\n", newlogscore, prevlogscore);
#endif
	    for (i=0; i<=NUMLETT; i++)
		f[i] = newf[i];
	    prevlogscore = newlogscore;
	} else {
	    /* Reject the new function. */
#ifdef VVERBOSE
printf("rejected; news=%f, prevs=%f\n", newlogscore, prevlogscore);
#endif
	}
	
	/* Provide output every once in a while. */
	if (divisible(itnum, PRINTEVERY) || (itnum==NUMITS) ) {
	    printf("-----------------------------\n");
	    printf("ITERATION %d (score=%f):\n\n", itnum, prevlogscore);
	    for (i=0; i<PRINTAMOUNT; i++) {
	        if (displaynum[i] == NUMLETT) {
		    /* Non-alphabetic character. */
		    printf("%c", displaydata[i]);
	        } else {
		    printf("%c", 'A' + f[displaynum[i]]);
	        }
	    }
	    printf("\n\n");
	}

    }

    /* Output final deciphered file. */
    strcpy(outname, argv[2]);
    strcat(outname, "out");
    if ((fp = fopen(argv[2], "r"))) {
    } else {
	fprintf(stderr, "Cannot open file %s.\n", argv[2]);
	exit(1);
    }
    if ((fpout = fopen(outname, "w"))) {
    } else {
	fprintf(stderr, "Cannot open file %s.\n", outname);
	exit(1);
    }
    while ( (t=getc(fp)) != EOF ) {
	if ( (newnum=thenum(t)) < NUMLETT )
	    fprintf(fpout, "%c", 'A' + f[newnum]);
	else
	    /* Non-alphabetic character. */
	    putc(t, fpout);
    }
    fclose(fpout);
    fclose(fp);
    return(0);

}


void getcounts(filename, M)
char *filename;
int M[NUMLETT+1][NUMLETT+1];
/*
Note: this routine computes M[i][j] to be one plus the number of
consecutive pairs ij in the file called <filename>.
*/
{

    int t, i, j;
    int prevnum, newnum;
    FILE *fp;

    printf("Getting pair counts for file '%s' ...", filename);

    /* Open file to read. */
    if ((fp = fopen(filename, "r"))) {
    } else {
	fprintf(stderr, "Cannot open file %s.\n", filename);
	exit(1);
    }

    /* Initialise counts. */
    prevnum = thenum( getc(fp) );
    for (i=0; i<=NUMLETT; i++)
        for (j=0; j<=NUMLETT; j++)
	    M[i][j] = 1;

    /* Add up the counts. */
    while ( (t=getc(fp)) != EOF ) {
	newnum = thenum(t);
	M[prevnum][newnum]++;
	prevnum = newnum;
    }

    /* Finish up. */
    fclose(fp);
    printf("done.\n");

}


int thenum(c)
char c;
/*
Maps a character to its case-insensitive letter index: 'A'/'a' -> 0, ...,
'Z'/'z' -> 25.  Every other character (digits, punctuation, whitespace,
and the EOF value that getcounts() can pass in) maps to NUMLETT, the shared
non-alphabetic slot.  Assumes the letters are contiguous in the execution
character set, as they are in ASCII.
*/
{
    int index = NUMLETT;	/* default: non-alphabetic */

    if (c >= 'a' && c <= 'z')
	index = c - 'a';
    else if (c >= 'A' && c <= 'Z')
	index = c - 'A';
    return index;
}


/* SEEDRAND: SEED RANDOM NUMBER GENERATOR.
   Seeds the drand48() family from the wall clock and always returns 0.
   srand48() keeps only the low 32 bits of its argument, so the seed is laid
   out as (low 12 bits of the seconds) << 20 | microseconds (which are below
   2^20).  Two runs started less than 4096 seconds apart therefore get
   different seeds, whereas seeding from the microseconds alone repeats
   whenever two runs share a microsecond value. */
int seedrand(void)
{
    /* XSI function; declared here because <stdlib.h> may hide it in strict
       ISO compilation modes. */
    void srand48(long);
    struct timeval now = {0};	/* stays zero if gettimeofday() fails */
    unsigned long mix;

    gettimeofday(&now, NULL);
    /* Unsigned arithmetic so the shift can never overflow. */
    mix = ((unsigned long) now.tv_sec << 20) ^ (unsigned long) now.tv_usec;
    srand48((long) mix);
    return(0);
}

/* NEWRAND: returns a random letter index in 0 .. NUMLETT-1.
   drand48() returns values in [0.0, 1.0), so the scaled value lies in
   [0, NUMLETT) and its floor is a valid index. */
int newrand(void)
{
    double drand48(void);	/* XSI; may be hidden by strict ISO modes */
    int ifloor(double);		/* defined later in this file */

    return( ifloor(drand48() * NUMLETT) );
}

/* IFLOOR */
ifloor(double x)  /* returns floor(x) as an integer */
{
    double floor();
    return((int)floor(x));
}

/* DIVISIBLE: returns 1 if aaa is an exact multiple of bbb, else 0.
   The two cases where C division is undefined are handled explicitly:
   bbb == 0 (only 0 is a multiple of 0) and bbb == -1 (INT_MIN % -1
   overflows, but every integer is a multiple of -1). */
int divisible(int aaa, int bbb)
{
    if (bbb == 0)
	return aaa == 0;
    if (bbb == -1)
	return 1;
    return aaa % bbb == 0;
}


double thelogscore(f)
int f[NUMLETT+1];
/*
Returns the log of the plausibility score of decoding function f.

The score is the product, over every consecutive pair "ab" in the test
text, of count[f[a]][f[b]]^PIPOWER, where count[c][d] is one plus the
number of occurrences of the pair "cd" in the training text.  Grouping
equal pairs turns the log of that product into
    PIPOWER * sum over i,j of testcounts[i][j] * logtraincounts[f[i]][f[j]],
so each call costs (NUMLETT+1)^2 terms regardless of the length of the
test text.  Every f[i] must lie in 0..NUMLETT.
*/
{
    double total = 0.0;
    int row, col;

    for (row = 0; row <= NUMLETT; row++) {
	/* Test pair counts beginning with letter `row`, and the training
	   log-counts for the letter that `row` decodes to. */
	const int *pairs = testcounts[row];
	const double *logrow = logtraincounts[f[row]];

	for (col = 0; col <= NUMLETT; col++)
	    total += pairs[col] * logrow[f[col]];
    }
    return total * PIPOWER;
}