#include <sys/capability.h>
#include <sys/types.h>
#include <net/bpf.h>
#include <net/bpf_jitter.h>
#include <map>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <unistd.h>

#include <llvm/LinkAllPasses.h>
#include <llvm/Bitcode/ReaderWriter.h>
#include <llvm/Constants.h>
#include <llvm/DerivedTypes.h>
#include <llvm/Linker.h>
#include <llvm/ExecutionEngine/ExecutionEngine.h>
#include <llvm/ExecutionEngine/JIT.h>
#include <llvm/ExecutionEngine/JITEventListener.h>
#include <llvm/ExecutionEngine/GenericValue.h>
#include <llvm/GlobalVariable.h>
#include <llvm/Module.h>
#include <llvm/LLVMContext.h>
#include <llvm/PassManager.h>
#include "llvm/Analysis/Verifier.h"
#include <llvm/Support/ErrorHandling.h>
#include <llvm/Support/IRBuilder.h>
#include <llvm/Support/MemoryBuffer.h>
#include "llvm/Transforms/IPO/PassManagerBuilder.h"
#include <llvm/Target/TargetData.h>
#include <llvm/Support/system_error.h>
#include <llvm/Support/TargetSelect.h>

//#define DEBUG_CODEGEN

using namespace llvm;

class BPFJIT
{
	private:
	LLVMContext c;
	Module *mod;
	Type *int8Ty;
	Type *int16Ty;
	Type *int32Ty;
	PointerType *int8PtrTy;
	PointerType *int16PtrTy;
	PointerType *int32PtrTy;
	Function *func;
	IRBuilder<> B;
	// BPF programs contain only forward branches.  For each branch, we
	// insert a new basic block for the target and then an unconditional
	// branch to that block when we encounter the relevant instruction.
	std::map<unsigned long, BasicBlock*> branchTargets;
	// Accumulator
	Value *A;
	// Index register
	Value *X;
	// Scratch memory
	Value *M;
	// Buffer containing the packet
	Value *P;
	// Length of the packet
	Value *len;

	class Listener : public JITEventListener
	{
		public:
		void *functionAddress;
		size_t functionSize;
		Listener() : functionAddress(0), functionSize(0) {}
		virtual void NotifyFunctionEmitted(const Function &F, void *addr,
				size_t size, const EmittedFunctionDetails &details)
		{
			functionSize = size;
			functionAddress = addr;
		}
	} listener;

	Value *loadFromMemory(unsigned int v)
	{
		auto gep = B.CreateConstInBoundsGEP1_32(M, v);
		return (B.CreateLoad(gep));
	}
	Value *loadFromPacket(Value *v, int size)
	{
		auto gep = B.CreateGEP(P, v);
		switch (size) {
		default: llvm_unreachable("Invalid size for load");
		case BPF_B: {
			Value *val = B.CreateLoad(gep);
			return (B.CreateZExt(val, int32Ty));
		}
		case BPF_H: {
			gep = B.CreateBitCast(gep, int16PtrTy);
			Value *val = B.CreateLoad(gep);
			return (B.CreateZExt(val, int32Ty));
		}
		case BPF_W: {
			gep = B.CreateBitCast(gep, int32PtrTy);
			return (B.CreateLoad(gep));
		}
		}
	}

	void emit_ld(bpf_insn *isn)
	{
		Value *v;
		switch(BPF_MODE(isn->code)) {
		default: llvm_unreachable("Invalid opcode for load");
		case BPF_IMM:
			v = ConstantInt::get(int32Ty, isn->k);
			break;
		case BPF_LEN:
			v = len;
			break;
		case BPF_MEM:
			v = loadFromMemory(isn->k);
			break;
		case BPF_IND:
			v = B.CreateLoad(X);
			v = B.CreateAdd(v, ConstantInt::get(int32Ty, isn->k));
			v = loadFromPacket(v, BPF_SIZE(isn->code));
			break;
		case BPF_ABS:
			v = loadFromPacket(ConstantInt::get(int32Ty, isn->k),
			                   BPF_SIZE(isn->code));
			break;
		}
		B.CreateStore(v, A);
	}

	void emit_ldx(bpf_insn *isn)
	{
		Value *v;
		switch(BPF_MODE(isn->code)) {
		default: llvm_unreachable("Invalid opcode for ldx");
		case BPF_IMM:
			v = ConstantInt::get(int32Ty, isn->k);
			break;
		case BPF_LEN:
			v = len;
			break;
		case BPF_MEM:
			v = loadFromMemory(isn->k);
			break;
		case BPF_MSH:
			// p[k:1]
			v = B.CreateLoad(P, B.CreateConstInBoundsGEP1_32(P, isn->k));
			// (P[k:1]&0xf)
			v = B.CreateAnd(v, 0xf);
			// 4*(P[k:1]&0xf)
			v = B.CreateShl(v, 2);
			break;
		}
		B.CreateStore(v, X);
	}

	void emit_st(bpf_insn *isn)
	{
		auto gep = B.CreateConstInBoundsGEP1_32(M, isn->k);
		B.CreateStore(A, gep);
	}

	void emit_stx(bpf_insn *isn)
	{
		auto gep = B.CreateConstInBoundsGEP1_32(M, isn->k);
		B.CreateStore(X, gep);
	}

	void emit_alu(bpf_insn *isn)
	{
		Value *lhs = B.CreateLoad(A);
		Value *rhs;
		Value *v;

		if (BPF_SRC(isn->code) == BPF_K)
			rhs = ConstantInt::get(int32Ty, isn->k);
		else if (BPF_SRC(isn->code) == BPF_X)
			rhs = B.CreateLoad(X);
		else
			assert(BPF_OP(isn->code) == BPF_NEG);
		switch (BPF_OP(isn->code)) {
		case BPF_ADD:
			v = B.CreateAdd(lhs, rhs);
			break;
		case BPF_SUB:
			v = B.CreateSub(lhs, rhs);
			break;
		case BPF_MUL:
			v = B.CreateMul(lhs, rhs);
			break;
		case BPF_DIV:
			v = B.CreateUDiv(lhs, rhs);
			break;
		case BPF_AND:
			v = B.CreateAnd(lhs, rhs);
			break;
		case BPF_OR:
			v = B.CreateOr(lhs, rhs);
			break;
		case BPF_LSH:
			v = B.CreateShl(lhs, rhs);
			break;
		case BPF_RSH:
			v = B.CreateAShr(lhs, rhs);
			break;
		case BPF_NEG:
			v = B.CreateSub(ConstantInt::get(int32Ty, 0), lhs);
			break;
		}
		B.CreateStore(v, A);
	}

	void emit_jmp(unsigned long pc, bpf_insn *isn)
	{
		if (BPF_OP(isn->code) == BPF_JA) {
			BasicBlock *&target = branchTargets[pc + isn->k];
			if (0 == target)
				target = BasicBlock::Create(c, "unconditional", func);
			B.CreateBr(target);
			B.ClearInsertionPoint();
			return;
		}
		BasicBlock *&jt = branchTargets[pc + isn->jt];
		if (0 == jt)
			jt = BasicBlock::Create(c, "true", func);
		BasicBlock *&jf = branchTargets[pc + isn->jf];
		if (0 == jf)
			jf = BasicBlock::Create(c, "false", func);
		Value *rhs;
		Value *v;
		Value *lhs = B.CreateLoad(A);

		if (BPF_SRC(isn->code) == BPF_K)
			rhs = ConstantInt::get(int32Ty, isn->k);
		else if (BPF_SRC(isn->code) == BPF_X)
			rhs = B.CreateLoad(X);
		else
			llvm_unreachable("Invalid rvalue for comparison");

		switch (BPF_OP(isn->code)) {
		case BPF_JGT:
			v = B.CreateICmpUGT(lhs, rhs);
			break;
		case BPF_JGE:
			v = B.CreateICmpUGE(lhs, rhs);
			break;
		case BPF_JEQ:
			v = B.CreateICmpEQ(lhs, rhs);
			break;
		case BPF_JSET:
			v = B.CreateAnd(lhs, rhs);
			v = B.CreateICmpNE(v, ConstantInt::get(int32Ty, 0));
			break;
		}
		B.CreateCondBr(v, jt, jf);
		B.ClearInsertionPoint();
	}

	void emit_ret(bpf_insn *isn)
	{
		Value *v;
		if (BPF_SRC(isn->code) == BPF_K)
			v = ConstantInt::get(int32Ty, isn->k);
		else if (BPF_RVAL(isn->code) == BPF_A)
			v = B.CreateLoad(A);
		else
			llvm_unreachable("Invalid return");
		B.CreateRet(v);
		B.ClearInsertionPoint();
	}

	void emit_misc(bpf_insn *isn)
	{
		switch (BPF_MISCOP(isn->code)) {
		default: llvm_unreachable("Invalid size for load");
		case BPF_TAX:
			B.CreateStore(B.CreateLoad(A), X);
			break;
		case BPF_TXA:
			B.CreateStore(B.CreateLoad(X), A);
			break;
		}
	}

	public:
	BPFJIT(): B(c)
	{
		mod = new Module("bpf", c);
		int8Ty = Type::getInt8Ty(c);
		int16Ty = Type::getInt16Ty(c);
		int32Ty = Type::getInt32Ty(c);
		int8PtrTy = PointerType::getUnqual(int8Ty);
		int16PtrTy = PointerType::getUnqual(int16Ty);
		int32PtrTy = PointerType::getUnqual(int32Ty);
		// Int is always 32 bits on FreeBSD
		Type *argTypes[] = {
			PointerType::getUnqual(Type::getInt8Ty(c)), int32Ty,
			int32Ty };
		// u_int filter(u_char *pkt, u_int wirelen, u_int buflen);
		FunctionType *bpfTy =
			FunctionType::get(int32Ty, argTypes, false);
		func = Function::Create(bpfTy, GlobalValue::ExternalLinkage,
		                        "bpf_filter", mod);
		auto args = func->arg_begin();
		P = args++;
		len = args++;
		BasicBlock *entry = BasicBlock::Create(c, "entry", func);
		B.SetInsertPoint(entry);
		// Accumulator
		A = B.CreateAlloca(int32Ty);
		// Index register
		X = B.CreateAlloca(int32Ty);
		// Scratch memory
		M = B.CreateAlloca(int32Ty, 
		                   ConstantInt::get(int32Ty, BPF_MEMWORDS));
	}
	~BPFJIT()
	{
		if (mod)
			delete mod;
	}

	void generateIRFromBPF(unsigned long n, bpf_insn *isns)
	{

		for (unsigned long i=0 ; i<n ; i++) {
			auto block = branchTargets.find(i);
			if (block != branchTargets.end()) {
				if (B.GetInsertBlock() != 0)
					B.CreateBr(block->second);
				B.SetInsertPoint(block->second);
			}
			// Skip unreachable instructions
			if (0 == B.GetInsertBlock())
				continue;

			bpf_insn *isn = &isns[i];
			switch (BPF_CLASS(isn->code)) {
			case BPF_LD:
				emit_ld(isn); break;
			case BPF_LDX:
				emit_ldx(isn); break;
			case BPF_ST:
				emit_st(isn); break;
			case BPF_STX:
				emit_stx(isn); break;
			case BPF_ALU:
				emit_alu(isn); break;
			case BPF_JMP:
				emit_jmp(i+1, isn); break;
			case BPF_RET:
				emit_ret(isn); break;
			case BPF_MISC:
				emit_misc(isn); break;
			}
		}
		for (auto block : branchTargets) {
			if (!block.second->getTerminator()) {
				B.SetInsertPoint(block.second);
				B.CreateUnreachable();
			}
		}
		mod->setTargetTriple("x86_64-unknown-freebsd10.0");
	}

	/**
	 * Run the standard set of optimisation passes on the generated bitcode.
	 */
	void optimise(void)
	{
#ifdef DEBUG_CODEGEN
		mod->dump();
		verifyModule(*mod);
#endif
		PassManagerBuilder PMBuilder;
		PMBuilder.OptLevel = 3;
		PMBuilder.Inliner = createFunctionInliningPass(275);
		FunctionPassManager *PerFunctionPasses =
		         new FunctionPassManager(mod);
		PerFunctionPasses->add(new TargetData(mod));

		PMBuilder.populateFunctionPassManager(*PerFunctionPasses);

		for (auto &I : *mod)
			if (!I.isDeclaration())
				PerFunctionPasses->run(I);

		PerFunctionPasses->doFinalization();
		delete PerFunctionPasses;
		// Run the per-module passes
		PassManager *PerModulePasses = new PassManager();
		PerModulePasses->add(new TargetData(mod));
		PMBuilder.populateModulePassManager(*PerModulePasses);
		PerModulePasses->run(*mod);
		delete PerModulePasses;
#ifdef DEBUG_CODEGEN
		mod->dump();
		verifyModule(*mod);
#endif
	}

	/**
	 * JIT the packet filter.  This function is intended for debugging.
	 */
	bpf_filter_func jit(size_t &size)
	{
		std::string error;
		ExecutionEngine *EE = ExecutionEngine::create(mod, false, &error);
		//ExecutionEngine *EE = MCJIT::createJIT(mod, false, &error);
		mod = 0;
		if (!EE) {
			fprintf(stderr, "Error: %s\n", error.c_str());
			exit(-1);
		}
		EE->RegisterJITEventListener(&listener);
		bpf_filter_func ret =
			(bpf_filter_func)EE->getPointerToFunction(func);
		size = listener.functionSize;
		return ret;
	}

	void writeBitcode(void)
	{
		llvm::raw_fd_ostream os(STDOUT_FILENO, false);
		WriteBitcodeToFile(mod, os);
		fflush(stdout);
	}
};

int main(int argc, char **argv)
{
	bool optimise = false;
	bool debug = false;
	int c;

	cap_enter();
	InitializeNativeTarget();
	LLVMLinkInJIT();
	while ((c = getopt(argc, argv, "dO")) != -1)
		switch (c) {
		case 'O':
			optimise = true;
			break;
		case 'd':
			debug = true;
			break;
		}
	while (!feof(stdin))
	{
		int32_t size;
		BPFJIT j;
		size_t funcSize;
		bpf_filter_func addr;

		fread(&size, sizeof(size), 1, stdin);
		if (feof(stdin)) break;
		bpf_insn *prog = (bpf_insn*)calloc(size, sizeof(bpf_insn));
		//if (size != fread(prog, size, sizeof(bpf_insn), stdin)) {

		size_t ret = fread(prog, sizeof(bpf_insn), size, stdin);
		if (size != ret) {
			fprintf(stderr, "Read %d instructions, expected %d\n", ret, size);
			return (EXIT_FAILURE);
		}

		j.generateIRFromBPF(size, prog);
		if (optimise)
			j.optimise();

		addr = j.jit(funcSize);
		fwrite((void*)addr, funcSize, 1, stdout); 

		if (debug)
			for (unsigned i=0; i<funcSize ; i++)
			{
				fprintf(stderr, "0x%hhx ", ((char*)addr)[i]);
			}
		free(prog);
	}
	return (EXIT_FAILURE);
}
