#!/usr/bin/env python3
#
# Plot sector entropy of disk image
# Copyright polprog 2021
# Released under the 3-clause BSD Licence
# Version 1.2, 12.12.2021
# 1.1: Fix entropy calculation routine to reflect Shannon's entropy
# 1.2: Add histograms, autosizing and offset on xlabel in hexadecimal
#      or CHS mode
import argparse
import math
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import FuncFormatter


# Command-line interface: geometry of the image (bytes/sector, sectors/track)
# and display options.
parser = argparse.ArgumentParser(description='Plot entropy of given file. Author: polprog, https://polprog.net')
parser.add_argument('-s', '--sectors', dest="sectors", type=int, default=18, help='Sectors per track')
parser.add_argument('-c', '--bytes', dest="bps", type=int, default=512, help='Bytes per sector')
parser.add_argument('-z', '--zeros', dest="zeros", action="store_true", help='Pad with zeros')
parser.add_argument('-H', '--histogram', dest="histogram", action="store_true", help='Draw a histogram')
parser.add_argument('-A', '--auto-size', dest="autosize", action="store_true", help='Automatically pick sectors/track')
parser.add_argument('-C', '--chs', dest="chs", action="store_true", help='Display offset as sectors')


parser.add_argument("filename",
                    help="disk image file to analyze")

args = parser.parse_args()

# Non-positive sizes would later cause a division by zero (size check,
# progress, padding) or an impossible reshape. Reject them up front.
# The sectors-per-track value is only checked when -A will not replace it.
if args.bps < 1:
    parser.error("bytes per sector (-c) must be a positive integer")
if args.sectors < 1 and not args.autosize:
    parser.error("sectors per track (-s) must be a positive integer")
print(args)

# Open the image for binary reading and validate its size against the
# sector size. Every failure path exits with status 1.
try:
    image = open(args.filename, "rb")
except OSError as err:
    # OSError covers PermissionError, IsADirectoryError, etc., not only a
    # missing file, so none of them end in a raw traceback.
    print(f"Error opening image: {err}", file=sys.stderr)
    sys.exit(1)

# Size of the file we actually opened (no second lookup by path).
imagesize = os.fstat(image.fileno()).st_size
print("File size: ", imagesize, "bytes")
if imagesize == 0:
    # Nothing to analyze or plot.
    print("Error: Image is empty", file=sys.stderr)
    sys.exit(1)
# Check against the requested sector size, not a hard-coded 512.
if imagesize % args.bps != 0 and not args.zeros:
    print(f"Error: Not a multiple of {args.bps} bytes. Use -z to display anyway", file=sys.stderr)
    sys.exit(1)

if args.autosize:
    # sectors = sqrt(N/3) gives tracks = N/sectors = sqrt(3N). The plot
    # (tracks across, sectors down) is then about three times wider than
    # tall. Never go below one sector per track.
    args.sectors = max(1, int(math.sqrt(imagesize / args.bps / 3)))

# Shannon entropy (bits per byte) of each sector, in file order.
entropy = []
sectornum = 0
# Number of (possibly partial) sectors. Used only for progress reporting.
total_sectors = max(1, math.ceil(imagesize / args.bps))
report_every = max(1, total_sectors // 100)  # roughly every 1%

with image:  # close the image once every sector has been read
    sector = image.read(args.bps)
    while sector:
        # Histogram of the 256 possible byte values in this sector.
        counts = np.bincount(np.frombuffer(sector, dtype=np.uint8), minlength=256)
        # A short final sector (only possible with -z) is treated as padded
        # with zero bytes, so the probabilities below sum to 1.
        counts[0] += args.bps - len(sector)
        p = counts[counts > 0] / args.bps
        # H = sum(p * log2(1/p)). Written this way so a constant sector
        # gives 0.0 rather than -0.0.
        entropy.append(float(np.sum(p * np.log2(1.0 / p))))
        sectornum += 1
        if sectornum % report_every == 0 or sectornum == total_sectors:
            print("Calculating entropy. %d%% done...\r" % (sectornum * 100 // total_sectors), end='')
        sector = image.read(args.bps)


# Append zero-entropy entries (in place) until the sector count is a whole
# number of tracks, so the list can be reshaped into a 2-D grid.
missing = -sectornum % args.sectors
if missing:
    entropy.extend([0] * missing)
else:
    print("\nNo padding necesary.")
# Reshape to one row per track, then transpose so that sectors run down
# the y axis and tracks across the x axis. This is identical to
# np.rot90(np.fliplr(...)).
entropy_tracks = np.reshape(entropy, (-1, args.sectors)).T
print("\nDrawing...")

# Draw either a histogram of sector entropies or a sector x track heat map
# (colour range fixed to 0..8 bits per byte), then show it interactively.
fig, ax = plt.subplots()
if args.histogram:
    # Use only the sectors actually read. The zero entries appended to
    # `entropy` to fill the last track are not real data and would skew
    # the histogram, mean and stdev.
    sector_entropy = np.asarray(entropy[:sectornum], dtype=float)
    plt.hist(sector_entropy, bins=16)
    stdev = np.std(sector_entropy)
    avg = np.mean(sector_entropy)
    stats = "Average %.2f, stdev %.2f" % (avg, stdev)
    plt.title("Sector entropy histogram of " + args.filename + "\n" + str(args.bps) + " bytes per sector.\n" + stats)
    plt.xlim([0, 8])

else:
    im = plt.imshow(entropy_tracks, cmap='viridis', interpolation='nearest', vmin=0, vmax=8)
    plt.title("Sector entropy of " + args.filename + "\n" + str(args.bps) + " bytes per sector")
    plt.ylabel("Sector")
    if args.chs:
        plt.xlabel("Track")
        # Label x ticks as integer track numbers. A formatter (rather than
        # fixed labels) stays correct when the auto-placed ticks move on
        # zoom or pan.
        ax.xaxis.set_major_formatter(FuncFormatter(lambda t, pos: "%d" % int(t)))
        # Step of at least 1 (sectors // 6 is 0 for fewer than 6 sectors).
        plt.yticks(np.arange(0, args.sectors, max(1, args.sectors // 6)))

    else:
        plt.xlabel("Offset")
        tracks = imagesize / args.bps / args.sectors
        # About 8 ticks, placed on whole tracks.
        plt.xticks(np.arange(0, tracks, int(tracks / 8 + 1)))
        # Label each track index with its byte offset in the image.
        track_bytes = args.bps * args.sectors
        ax.xaxis.set_major_formatter(FuncFormatter(lambda t, pos: "0x%08X" % int(t * track_bytes)))

    plt.xticks(rotation = 30, ha='right')
    plt.subplots_adjust(bottom=0.025, top=0.98)
    fig.colorbar(im, ax=ax, orientation='horizontal')
plt.show()
