package crystalpalace.btf;

import crystalpalace.coff.*;
import crystalpalace.util.*;

import java.util.*;
import java.io.*;

import com.github.icedland.iced.x86.*;
import com.github.icedland.iced.x86.asm.*;
import com.github.icedland.iced.x86.enc.*;
import com.github.icedland.iced.x86.dec.*;
import com.github.icedland.iced.x86.fmt.*;
import com.github.icedland.iced.x86.fmt.gas.*;

/**
 * Link-time dead-code elimination for a parsed COFF object.
 *
 * <p>Starting from the {@code go} (x64) or {@code _go} (x86) entry point, walks the decoded
 * instructions of each reachable function, following near calls, RIP-relative references and
 * relocations to other {@code .text} symbols. Every function symbol never reached is dropped from
 * the function map and removed from the COFF object.
 *
 * <p>{@code funcs} is expected to map symbol names (keys are used via {@code toString()}) to a
 * {@code List} of decoded {@link Instruction}s.
 */
public class LinkTimeOptimizer {
	/* x64 instructions (by iced op-code string) whose RIP-relative operand may reference a function we must keep */
	private static final Set<String> X64_REFERENCING_INSTRUCTIONS = Collections.unmodifiableSet(
		new HashSet<String>(Arrays.asList("LEA r64, m", "MOV r64, r/m64", "CALL r/m64")));

	/* prefix of the indirection symbols whose target we resolve in the x64 walk */
	private static final String REFPTR_PREFIX = ".refptr.";

	protected COFFObject          object  = null;
	protected Set                 touched = new HashSet();
	protected Map                 funcs   = null;
	protected Code                code    = null;

	public LinkTimeOptimizer(Code code) {
		this.code   = code;
		this.object = code.getObject();
	}

	/* true when the symbol exists, lives in a section, and that section is .text */
	private static boolean isTextSymbol(Symbol symbol) {
		return symbol != null && symbol.getSection() != null && ".text".equals(symbol.getSection().getName());
	}

	/* true when the symbol exists and we have not walked it yet */
	private boolean isUnvisited(Symbol symbol) {
		return symbol != null && !touched.contains(symbol.getName());
	}

	/**
	 * x64 call analysis: marks {@code function} as used, then recursively walks every function it
	 * references via near calls, RIP-relative LEA/MOV/CALL operands, or {@code .refptr.} relocations.
	 *
	 * @param function symbol name of the function to walk
	 */
	protected void walk_x64(String function) {
		/* if we're walking the function, it's referenced/called and we want to keep it */
		touched.add(function);

		/* a label can be reachable without having its own instruction list; nothing more to follow then */
		List instructions = (List) funcs.get(function);
		if (instructions == null)
			return;

		Iterator i = instructions.iterator();
		while (i.hasNext()) {
			Instruction inst = (Instruction) i.next();

			if (inst.isCallNear()) {
				/* NOTE(review): the displacement is used as the call target here; presumably iced stores the near-branch target in this field */
				Symbol temp = code.getLabel(inst.getMemoryDisplacement32());
				if (isUnvisited(temp))
					walk_x64(temp.getName());
			}
			else if (inst.isIPRelativeMemoryOperand() && X64_REFERENCING_INSTRUCTIONS.contains(inst.getOpCode().toInstructionString())) {
				Symbol temp = code.getLabel(inst.getMemoryDisplacement32());
				if (isUnvisited(temp))
					walk_x64(temp.getName());
			}

			/* handle .refptr labels as a special case: resolve the symbol they point at */
			Relocation r = code.getRelocation(inst);
			if (r != null && r.getSymbolName().startsWith(REFPTR_PREFIX)) {
				/* the target may have no section (the x86 walk guards the same case) -- don't NPE on it */
				Symbol temp = object.getSymbol(r.getSymbolName().substring(REFPTR_PREFIX.length()));
				if (isTextSymbol(temp) && isUnvisited(temp))
					walk_x64(temp.getName());
			}
		}
	}

	/**
	 * x86 call analysis: marks {@code function} as used, then recursively walks every function it
	 * references via near calls or relocations (either {@code .text}-relative or to a named
	 * {@code .text} symbol).
	 *
	 * @param function symbol name of the function to walk
	 */
	protected void walk_x86(String function) {
		/* if we're walking the function, it's referenced/called and we want to keep it */
		touched.add(function);

		/* a label can be reachable without having its own instruction list; nothing more to follow then */
		List instructions = (List) funcs.get(function);
		if (instructions == null)
			return;

		Iterator i = instructions.iterator();
		while (i.hasNext()) {
			Instruction inst = (Instruction) i.next();

			/* a near call lands on a local label: walk that function */
			if (inst.isCallNear()) {
				Symbol temp = code.getLabel(inst.getMemoryDisplacement32());
				if (isUnvisited(temp))
					walk_x86(temp.getName());
			}

			Relocation r = code.getRelocation(inst);
			if (r == null)
				continue;

			if (".text".equals(r.getSymbolName())) {
				/* relocation against the section itself: the offset identifies the label */
				Symbol temp = code.getLabel(r.getOffsetAsLong());
				if (isUnvisited(temp))
					walk_x86(temp.getName());
			}
			else {
				/* same idea as the x64 .refptr case: a relocation naming a local .text symbol must be walked */
				Symbol temp = object.getSymbol(r.getSymbolName());
				if (isTextSymbol(temp) && isUnvisited(temp))
					walk_x86(temp.getName());
			}
		}
	}

	/**
	 * Removes every function not reachable from the entry point.
	 *
	 * <p>The given map is modified in place (unreachable function entries are removed) and returned;
	 * the corresponding symbols are removed from the COFF object. Keys without a COFF symbol, and
	 * non-function symbols, are always kept.
	 *
	 * @param _funcs map of symbol name to its decoded instruction list
	 * @return the same map, with unreachable functions removed
	 * @throws RuntimeException if the required entry point ({@code go} on x64, {@code _go} otherwise) is missing
	 */
	public Map apply(Map _funcs) {
		funcs = _funcs;

		if ("x64".equals(object.getMachine())) {
			/* check that we have an entry point of go() */
			if (!funcs.containsKey("go"))
				throw new RuntimeException("+optimize requires go() function as entrypoint");

			/* build up our set of touched functions. We presume "go" is our entrypoint */
			walk_x64("go");
		}
		else {
			/* check that we have an entry point of _go() */
			if (!funcs.containsKey("_go"))
				throw new RuntimeException("+optimize requires _go() function as entrypoint");

			/* build up our set of touched functions. We presume "_go" is our entrypoint */
			walk_x86("_go");
		}

		/* symbol names to get rid of from our COFF */
		Set removeme = new HashSet();

		/* let's do a little *snip* *snip* for anything that's not used by our program */
		Iterator i = funcs.entrySet().iterator();
		while (i.hasNext()) {
			Map.Entry entry = (Map.Entry) i.next();
			String name = entry.getKey().toString();

			/* if it's not a (known) function, we're not interested in cutting it out */
			Symbol symbol = object.getSymbol(name);
			if (symbol == null || !symbol.isFunction())
				continue;

			if (!touched.contains(name)) {
				i.remove();
				removeme.add(name);
			}
		}

		/* get that stuff out of our COFF now */
		object.removeSymbols(removeme);

		/* and as simple as that... return our modified function map */
		return funcs;
	}
}
