import os
import tempfile

import idaapi
import idc
import idautils

from miasm2.core.bin_stream_ida import bin_stream_ida
from miasm2.core.asmblock import is_int
from miasm2.expression.simplifications import expr_simp
from miasm2.analysis.data_flow import dead_simp
from miasm2.ir.ir import AssignBlock, IRBlock

from utils import guess_machine, expr2colorstr


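# Replacement constructor and __str__ for Miasm asmblock labels, meant to
# override Miasm's default label naming convention and shrink block size in IDA.
# NOTE(review): nothing in this file assigns these functions to a Miasm class;
# confirm they are still installed elsewhere.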

def label_init(self, name="", offset=None):
    """Initialise a label, shortening integer names to ``loc_<HEX>``.

    :param name: label name; when ``is_int(name)`` holds, the name becomes
        ``loc_%X`` of the value masked to 64 bits (so negative values render
        as their unsigned 64-bit form)
    :param offset: optional offset; stored as ``int`` when given, else None
    """
    self.fixedblocs = False
    label_name = name
    if is_int(label_name):
        label_name = "loc_%X" % (int(label_name) & 0xFFFFFFFFFFFFFFFF)
    self.name = label_name
    self.attrib = None
    self.offset = None if offset is None else int(offset)


def label_str(self):
    """Render a label as ``name:offset``.

    Integer offsets are printed in hexadecimal (``name:0x1234``); any other
    offset value (e.g. ``None``) is printed through ``str``.
    """
    offset = self.offset
    if not isinstance(offset, (int, long)):
        return "%s:%s" % (self.name, str(offset))
    return "%s:0x%x" % (self.name, offset)


def color_irblock(irblock, ir_arch):
    """Render an IR block as IDA color-tagged text for a graph node.

    The text is the block's location name, then one indented ``dst = src``
    line per assignment (assignment blocks separated by a blank line), and
    finally a ``Dst:`` line showing ``irblock.dst``.

    :param irblock: IR block to render
    :param ir_arch: IR architecture whose ``symbol_pool`` names locations
    :return: newline-joined, color-tagged string
    """
    symbol_pool = ir_arch.symbol_pool
    out = [idaapi.COLSTR(symbol_pool.str_loc_key(irblock.loc_key),
                         idaapi.SCOLOR_INSN)]
    for index, assignblk in enumerate(irblock):
        if index:
            # Blank line between consecutive assignment blocks. Emitting the
            # separator before each block (instead of after, then popping the
            # last one) keeps the label intact when irblock has no assignblks.
            out.append("")
        for dst, src in sorted(assignblk.iteritems()):
            dst_f = expr2colorstr(dst, symbol_pool=symbol_pool)
            src_f = expr2colorstr(src, symbol_pool=symbol_pool)
            line = idaapi.COLSTR("%s = %s" % (dst_f, src_f), idaapi.SCOLOR_INSN)
            out.append('    %s' % line)
    dst = str('    Dst: %s' % irblock.dst)
    out.append(idaapi.COLSTR(dst, idaapi.SCOLOR_RPTCMT))
    return "\n".join(out)


class GraphMiasmIR(idaapi.GraphViewer):
    """IDA graph viewer showing the IR blocks of ``ir_arch``.

    Every IR block becomes a node. Each non-empty block gets an edge to every
    location destination returned by ``ir_arch.dst_trackback`` that is itself
    a key of ``ir_arch.blocks``.
    """

    def __init__(self, ir_arch, title, result):
        idaapi.GraphViewer.__init__(self, title)
        self.ir_arch = ir_arch
        self.result = result
        self.names = {}

    def OnRefresh(self):
        """Clear the graph and rebuild its nodes and edges."""
        self.Clear()
        blocks = self.ir_arch.blocks
        node_of = {
            irblock: self.AddNode(color_irblock(irblock, self.ir_arch))
            for irblock in blocks.values()
        }

        for irblock in blocks.values():
            # Empty blocks still have a node, but no outgoing edges.
            if not irblock:
                continue
            src_node = node_of[irblock]
            for dst in self.ir_arch.dst_trackback(irblock):
                if dst.is_loc() and dst.loc_key in blocks:
                    self.AddEdge(src_node, node_of[blocks[dst.loc_key]])
        return True

    def OnGetText(self, node_id):
        """Return the text stored for ``node_id``."""
        return str(self[node_id])

    def OnSelect(self, node_id):
        """Allow every node to be selected."""
        return True

    def OnClick(self, node_id):
        """Accept clicks on any node."""
        return True

    def Show(self):
        """Display the viewer; return True on success, False otherwise."""
        return bool(idaapi.GraphViewer.Show(self))


def build_graph(verbose=False, simplify=False):
    start_addr = idc.ScreenEA()

    machine = guess_machine(addr=start_addr)
    mn, dis_engine, ira = machine.mn, machine.dis_engine, machine.ira

    if verbose:
        print "Arch", dis_engine

    fname = idc.GetInputFile()
    if verbose:
        print fname

    bs = bin_stream_ida()
    mdis = dis_engine(bs)
    ir_arch = ira(mdis.symbol_pool)

    # populate symbols with ida names
    for addr, name in idautils.Names():
        # print hex(ad), repr(name)
        if name is None:
            continue
        if (mdis.symbol_pool.getby_offset(addr) or
            mdis.symbol_pool.getby_name(name)):
            # Symbol alias
            continue
        mdis.symbol_pool.add_location(name, addr)

    if verbose:
        print "start disasm"
    if verbose:
        print hex(addr)

    asmcfg = mdis.dis_multiblock(start_addr)

    if verbose:
        print "generating graph"
        open('asm_flow.dot', 'w').write(asmcfg.dot())

        print "generating IR... %x" % start_addr

    for block in asmcfg.blocks:
        if verbose:
            print 'ADD'
            print block
        ir_arch.add_block(block)

    if verbose:
        print "IR ok... %x" % start_addr

    for irb in ir_arch.blocks.itervalues():
        irs = []
        for assignblk in irb:
            new_assignblk = {
                expr_simp(dst): expr_simp(src)
                for dst, src in assignblk.iteritems()
            }
            irs.append(AssignBlock(new_assignblk, instr=assignblk.instr))
        ir_arch.blocks[irb.loc_key] = IRBlock(irb.loc_key, irs)

    if verbose:
        out = ir_arch.graph.dot()
        open(os.path.join(tempfile.gettempdir(), 'graph.dot'), 'wb').write(out)
    title = "Miasm IR graph"

    if simplify:
        dead_simp(ir_arch)

        ir_arch.simplify(expr_simp)
        modified = True
        while modified:
            modified = False
            modified |= dead_simp(ir_arch)
            modified |= ir_arch.remove_empty_assignblks()
            modified |= ir_arch.remove_jmp_blocks()
            modified |= ir_arch.merge_blocks()
        title += " (simplified)"

    g = GraphMiasmIR(ir_arch, title, None)

    g.Show()

if __name__ == "__main__":
    # Script entry point: show the raw (unsimplified) IR graph with verbose
    # progress output.
    build_graph(simplify=False, verbose=True)
