/*
 * Click here to Skip to main content
 * 15,895,777 members
 * Articles / Programming Languages / Visual C++ 9.0
 *
 * Compiling C++ code at runtime
 *
 * Rate me:
 * Please sign up or sign in to vote.
 * 4.53/5 (19 votes)
 * 8 Oct 2008 · CPOL · 4 min read · 77.3K · 812 · 37
 * Optimizing algorithms at runtime with a domain-specific embedded language (DSEL) and LLVM.
 */
#include "DynCompiler.h"

// Optimization level consumed by DynHelper::SetupOptimizer():
//   0           -> no function passes are added;
//   1           -> pipeline built around mem2reg;
//   any other   -> pipeline built around scalarrepl.
// Defaults to 2; override at build time without editing this file,
// e.g. -DDYN_OPT_LEVEL=0 to disable optimization.
#ifndef DYN_OPT_LEVEL
#define DYN_OPT_LEVEL 2
#endif
static const int OPT_LEVEL = DYN_OPT_LEVEL;

// Performs no LLVM setup itself: Init() creates the module, execution engine
// and pass manager the first time it is called.
DynHelper::DynHelper() {}

// Fills the function pass manager FPM according to OPT_LEVEL.
//   OPT_LEVEL == 0 : no passes at all.
//   OPT_LEVEL == 1 : CFG simplification, mem2reg, instcombine, reassociate, GVN.
//   otherwise      : same, but scalar replacement of aggregates instead of mem2reg.
// Must run after FPM has been created (see Init()).
void DynHelper::SetupOptimizer() {
    if (OPT_LEVEL == 0)
        return;

    FPM->add(createCFGSimplificationPass());

    // Turn the entry-block allocas created by AllocLocalVar into SSA values.
    if (OPT_LEVEL == 1)
        FPM->add(createPromoteMemoryToRegisterPass());
    else
        FPM->add(createScalarReplAggregatesPass());

    // Added once: a second, back-to-back instcombine run would typically find
    // nothing left to combine and only costs compile time.
    FPM->add(createInstructionCombiningPass());
    FPM->add(createReassociatePass());
    FPM->add(createGVNPass());
}

// Returns the function that owns the builder's current insertion block.
// Requires that an insertion point has been set on Builder.
Function* DynHelper::CurrentFunction() {
    BasicBlock* current = Builder.GetInsertBlock();
    return current->getParent();
}

 // Creates a stack slot (alloca) named `var` of type `type` at the very start of
 // the current function's entry block, regardless of where Builder is currently
 // emitting. Keeping allocas in the entry block lets the mem2reg / scalarrepl
 // passes added by SetupOptimizer promote them to SSA registers.
 AllocaInst* DynHelper::AllocLocalVar(const std::string& var, const Type* type) {
    BasicBlock& entry = CurrentFunction()->getEntryBlock();
    IRBuilder<> entryBuilder(&entry, entry.begin());
    return entryBuilder.CreateAlloca(type, 0, var.c_str());
}

void DynHelper::Init() {
    if (!M) {
        M = new Module("JIT");
        EE = ExecutionEngine::create(M);

        FPM = new FunctionPassManager(new ExistingModuleProvider(M));
        FPM->add(new TargetData(M));

        SetupOptimizer();

        // Set up "standard library"
        {
            std::vector<const Type*> args;
            args.push_back(Type::DoubleTy);
            FunctionType *FT = FunctionType::get(Type::Int32Ty, args, false);
            Function::Create(FT, Function::ExternalLinkage, "print_double", M);
        }
        {
            std::vector<const Type*> args;
            args.push_back(Type::Int32Ty);
            FunctionType *FT = FunctionType::get(Type::Int32Ty, args, false);
            Function::Create(FT, Function::ExternalLinkage, "print_int", M);
        }
    }
}

// Single global helper: the expression nodes and builders in this file emit IR
// through TheHelper.Builder, and generated functions are created in TheHelper.M.
// Being a global, its pointer members start zero-initialized, which is what
// Init()'s `if (!M)` check relies on.
// NOTE(review): shared mutable state with no synchronization — presumably meant
// for single-threaded use only; confirm before using from several threads.
DynHelper TheHelper;

// Conversion node producing a double from `expr` (presumably an integer-typed
// expression; it is converted as a signed integer).
DynDoubleCast::DynDoubleCast(PDynExprBase expr) : DynExprBase(Type::DoubleTy), expr_(expr) {}

// Emits the operand, then a signed int-to-floating-point conversion.
Value* DynDoubleCast::eval() const {
    Value* operand = expr_->eval();
    return TheHelper.Builder.CreateSIToFP(operand, Type::DoubleTy);
}

// Conversion node producing a 32-bit int from `expr` (presumably a double-typed
// expression).
DynIntCast::DynIntCast(PDynExprBase expr) : DynExprBase(Type::Int32Ty), expr_(expr) {}

// Emits the operand, then a floating-point to signed-int conversion (fptosi,
// which truncates toward zero).
Value* DynIntCast::eval() const {
    Value* operand = expr_->eval();
    return TheHelper.Builder.CreateFPToSI(operand, Type::Int32Ty);
}

// Builds (but does not emit) a binary-operation node. The node's type is the
// common type chosen by DynBestType, and each operand is wrapped in an
// int->double conversion where needed so both sides share that type.
// The base class is initialized first, so type_ is already set when the
// operand initializers read it.
DynBinaryOp::DynBinaryOp(Op o, PDynExprBase l, PDynExprBase r)
    : DynExprBase(DynBestType(l, r)), op_(o),
      lhs_(DynCastAsNeeded(l, type_)), rhs_(DynCastAsNeeded(r, type_))
{}

// Emits IR for this binary operation and returns the resulting value.
//
// The constructor already converted both operands to the common type type_
// (see DynCastAsNeeded), so lhs_->type_ selects floating-point vs. integer
// instructions.
//
// Arithmetic (ADD, SUB, MUL, DIV, MOD) yields a value of the operand type;
// integer DIV/MOD are signed. Comparisons yield 1 or 0 in the operand type
// (1.0 / 0.0 for doubles), as a relational expression does in C++.
//
// Returns 0 (null) for an unrecognized operator.
Value* DynBinaryOp::eval() const {
    const bool is_fp = (lhs_->type_ == Type::DoubleTy);

    // Evaluate each operand exactly once, left before right. Passing two
    // eval() calls as arguments would leave the IR emission order to the
    // compiler, since C++ does not specify argument evaluation order.
    Value* l = lhs_->eval();
    Value* r = rhs_->eval();

    // Double comparisons use the *ordered* predicates so that any comparison
    // involving NaN is false, as in C++. NE is the exception: it uses the
    // unordered UNE because `NaN != x` is true in C++.
    Value* cmp = 0;
    switch (op_) {
        case ADD: return TheHelper.Builder.CreateAdd(l, r);
        case SUB: return TheHelper.Builder.CreateSub(l, r);
        case MUL: return TheHelper.Builder.CreateMul(l, r);
        case DIV:
            if (is_fp)
                return TheHelper.Builder.CreateFDiv(l, r);
            return TheHelper.Builder.CreateSDiv(l, r);
        case MOD:
            if (is_fp)
                return TheHelper.Builder.CreateFRem(l, r);
            return TheHelper.Builder.CreateSRem(l, r);

        case LT:
            if (is_fp)
                cmp = TheHelper.Builder.CreateFCmpOLT(l, r, "cmptmp");
            else
                cmp = TheHelper.Builder.CreateICmpSLT(l, r, "cmptmp");
            break;

        case GT:
            if (is_fp)
                cmp = TheHelper.Builder.CreateFCmpOGT(l, r, "cmptmp");
            else
                cmp = TheHelper.Builder.CreateICmpSGT(l, r, "cmptmp");
            break;

        case GE:
            if (is_fp)
                cmp = TheHelper.Builder.CreateFCmpOGE(l, r, "cmptmp");
            else
                cmp = TheHelper.Builder.CreateICmpSGE(l, r, "cmptmp");
            break;

        case LE:
            if (is_fp)
                cmp = TheHelper.Builder.CreateFCmpOLE(l, r, "cmptmp");
            else
                cmp = TheHelper.Builder.CreateICmpSLE(l, r, "cmptmp");
            break;

        case EQ:
            if (is_fp)
                cmp = TheHelper.Builder.CreateFCmpOEQ(l, r, "cmptmp");
            else
                cmp = TheHelper.Builder.CreateICmpEQ(l, r, "cmptmp");
            break;

        case NE:
            if (is_fp)
                cmp = TheHelper.Builder.CreateFCmpUNE(l, r, "cmptmp");
            else
                cmp = TheHelper.Builder.CreateICmpNE(l, r, "cmptmp");
            break;

        default:
            return 0;
    }

    // Widen the i1 comparison result to the operand type. Zero-extension (not
    // sign-extension) makes "true" 1 rather than -1, matching C++ semantics.
    if (is_fp)
        return TheHelper.Builder.CreateUIToFP(cmp, Type::DoubleTy, "booltmp");
    return TheHelper.Builder.CreateZExt(cmp, Type::Int32Ty, "booltmp");
}

// Declares a local named `name` of type t and initializes it with `val`:
// the stack slot is allocated and the (converted) value is stored at the
// current insertion point.
DynLocalVar::DynLocalVar(const std::string name, PDynExprBase val, const Type* t)
    : DynExprBase(t), name_(name), alloc_(0)
{
    // alloc_ is still 0 here, so assign() allocates the slot before storing.
    assign(val);
}

// Declares a local without emitting any IR; the stack slot is created lazily
// later (assign() allocates it on first use).
DynLocalVar::DynLocalVar(const std::string name, const Type* t)
    : DynExprBase(t), name_(name), alloc_(0)
{}

// Stores `val`, converted to this variable's type, into the variable's stack
// slot at the current insertion point. Allocates the slot on first use.
void DynLocalVar::assign(PDynExprBase val) {
    if (!alloc_)
        doAlloc();
    // Complete statement - we emit code
    Value* rhs = DynUpcastAsNeeded(val, type_)->eval();
    TheHelper.Builder.CreateStore(rhs, alloc_);
}

// Stores an already-emitted IR value into the slot. No type conversion is
// applied: the caller must pass a value of this variable's type.
// Returns the emitted store instruction.
Value* DynLocalVar::assign(Value* val) {
    if (!alloc_)
        doAlloc();
    return TheHelper.Builder.CreateStore(val, alloc_);
}

// Emits a load of the variable's current value.
// A variable that was never assigned has no slot yet. Allocate one instead of
// emitting a load through a null pointer. The load then yields an indeterminate
// value, like reading an uninitialized local in C++; this also covers a
// ".retval" that no add_return() ever stored into.
// DynLocalVar objects in this file are created with `new` (non-const), so
// casting away const to lazily allocate is well-defined.
Value* DynLocalVar::eval() const {
    if (!alloc_)
        const_cast<DynLocalVar*>(this)->doAlloc();
    return TheHelper.Builder.CreateLoad(alloc_);
}

// Creates the variable's stack slot in the entry block of the current function.
void DynLocalVar::doAlloc() {
    alloc_ = TheHelper.AllocLocalVar(name_, type_);
}

// Wraps a fresh DynLocalVar named `name` of type t. The two-argument
// DynLocalVar constructor emits no IR; the stack slot appears on first assignment.
DynLocalVarWrapper::DynLocalVarWrapper(const std::string name, const Type* t)
    : var_(new DynLocalVar(name, t))
{
}

// Exposes the wrapped variable as a generic expression handle so it can be
// used in DSEL expressions.
DynLocalVarWrapper::operator PDynExprBase() const {
    return PDynExprBase(var_);
}

// Copy construction creates a *new* DSEL local variable initialized with the
// current value of `o`, mirroring `T b = a;` in C++ and the value semantics of
// the wrapper's copy-assignment operator.
DynLocalVarWrapper::DynLocalVarWrapper(const DynLocalVarWrapper& o) {
    // In a copy constructor var_ starts out empty, so a slot must be created
    // before anything can be stored into it. The copy reuses o's name and type;
    // LLVM uniquifies the duplicate IR name automatically.
    var_.reset(new DynLocalVar(o.var_->name_, o.var_->type_));
    var_->assign(o.var_->eval());
}

// Value assignment between DSEL variables: emits a load of `o` followed by a
// store into this variable.
void DynLocalVarWrapper::operator=(const DynLocalVarWrapper& o) {
    Value* current = o.var_->eval();
    var_->assign(current);
}

// Postfix increment / decrement: forward to the integer compound assignments,
// i.e. `x += 1` / `x -= 1`.
void DynLocalVarWrapper::operator++(int) {
    *this += 1;
}

void DynLocalVarWrapper::operator--(int) {
    *this -= 1;
}

// Assignment and compound assignment from a DSEL expression. Each builds the
// combined expression tree first, then emits one store into this variable.
void DynLocalVarWrapper::operator=(PDynExprBase val) {
    var_->assign(val);
}

void DynLocalVarWrapper::operator+=(PDynExprBase val) {
    const PDynExprBase updated = var_ + val;
    var_->assign(updated);
}

void DynLocalVarWrapper::operator-=(PDynExprBase val) {
    const PDynExprBase updated = var_ - val;
    var_->assign(updated);
}

void DynLocalVarWrapper::operator*=(PDynExprBase val) {
    const PDynExprBase updated = var_ * val;
    var_->assign(updated);
}

void DynLocalVarWrapper::operator/=(PDynExprBase val) {
    const PDynExprBase updated = var_ / val;
    var_->assign(updated);
}

void DynLocalVarWrapper::operator%=(PDynExprBase val) {
    const PDynExprBase updated = var_ % val;
    var_->assign(updated);
}

// Assignment and compound assignment from an int literal: wrap the literal
// with Const() and forward to the expression-based overloads.
void DynLocalVarWrapper::operator=(int val) {
    *this = Const(val);
}

void DynLocalVarWrapper::operator+=(int val) {
    *this += Const(val);
}

void DynLocalVarWrapper::operator-=(int val) {
    *this -= Const(val);
}

void DynLocalVarWrapper::operator*=(int val) {
    *this *= Const(val);
}

void DynLocalVarWrapper::operator/=(int val) {
    *this /= Const(val);
}

void DynLocalVarWrapper::operator%=(int val) {
    *this %= Const(val);
}

// Assignment and compound assignment from a double literal: wrap the literal
// with Const() and forward to the expression-based overloads.
void DynLocalVarWrapper::operator=(double val) {
    *this = Const(val);
}

void DynLocalVarWrapper::operator+=(double val) {
    *this += Const(val);
}

void DynLocalVarWrapper::operator-=(double val) {
    *this -= Const(val);
}

void DynLocalVarWrapper::operator*=(double val) {
    *this *= Const(val);
}

void DynLocalVarWrapper::operator/=(double val) {
    *this /= Const(val);
}

void DynLocalVarWrapper::operator%=(double val) {
    *this %= Const(val);
}

const std::string DynLocalVarWrapper::name() const {
    return var_->name_;
}

// Records the function name and the caller's output pointer `c`. The LLVM
// function itself is created by build_prototype(), and the destructor stores
// the compiled entry point through compiled_.
DynFuncBuilder::DynFuncBuilder(const std::string& n, void*& c) : name_(n), compiled_(c) {}

// Finishes the function under construction and JIT-compiles it when the
// builder goes out of scope:
//  - the current block branches to the shared return block, which is then
//    appended to the function and returns the value loaded from ".retval";
//  - the function is verified and optimized with TheHelper.FPM;
//  - the native entry point is stored in compiled_ (initialized from the
//    constructor's void*& argument — presumably a reference member, so the
//    caller's pointer is updated; confirm against the header).
// NOTE(review): verifyFunction's result is ignored; whether a malformed function
// aborts depends on the LLVM version's default failure action — confirm.
// NOTE(review): if add_return() was never called, ".retval" was never stored
// into; what ret_val_->eval() then does depends on DynLocalVar::eval.
DynFuncBuilder::~DynFuncBuilder() {
    // Return block
    // Terminate whatever block is current with a jump to the single exit block.
    TheHelper.Builder.CreateBr(ReturnBlock);

    func_->getBasicBlockList().push_back(ReturnBlock);
    TheHelper.Builder.SetInsertPoint(ReturnBlock);

    // Load the value stored by add_return() and return it.
    TheHelper.Builder.CreateRet(ret_val_->eval());

    verifyFunction(*func_);
    TheHelper.FPM->run(*func_);

    compiled_ = TheHelper.EE->getPointerToFunction(func_);
}

// Creates the LLVM function for this builder and prepares it for body emission:
//  1. creates `name_` in TheHelper.M with signature return_type_(arg_types_...),
//     non-variadic and with external linkage;
//  2. names the LLVM arguments after the wrapped DSEL argument variables;
//  3. opens the "entry" block and copies each incoming argument into its
//     DynLocalVar stack slot, so the body can read and reassign it;
//  4. creates the shared "return" block (detached; the destructor appends it)
//     and the ".retval" local that add_return() stores into.
// NOTE(review): both loops advance the LLVM argument iterator in lockstep with
// args_ and stop only at args_.end(); this assumes args_ and arg_types_ have the
// same length and order — not checked here.
void DynFuncBuilder::build_prototype()
{
    FunctionType *FT = FunctionType::get(return_type_, arg_types_, false);
    func_ = Function::Create(FT, Function::ExternalLinkage, name_, TheHelper.M);

    // Give each LLVM argument the name of its DSEL variable (readable IR).
    Function::arg_iterator ai = func_->arg_begin();
    std::vector<DynLocalVarWrapper*>::iterator arg_it = args_.begin();
    for (; arg_it != args_.end(); ++arg_it, ++ai) {
        DynLocalVarWrapper& arg = **arg_it;
        ai->setName(arg.name().c_str());
    }

    // All body code is emitted starting from a fresh entry block.
    BasicBlock *BB = BasicBlock::Create("entry", func_);
    TheHelper.Builder.SetInsertPoint(BB);

    // Store each incoming argument into its variable's stack slot. Relies on the
    // arg_iterator converting implicitly to the Argument value for assign(Value*).
    ai = func_->arg_begin();
    arg_it = args_.begin();
    for (; arg_it != args_.end(); ++arg_it, ++ai) {
        DynLocalVarWrapper& arg = **arg_it;
        arg.var_->assign(ai);
    }

    // Single exit point: add_return() stores into .retval and branches here.
    ReturnBlock    = BasicBlock::Create("return");
    ret_val_.reset(new DynLocalVar(".retval", return_type_));
}

void DynFuncBuilder::add_return(PDynExprBase val) {
    ret_val_->assign(DynUpcastAsNeeded(val, return_type_)->eval());
    TheHelper.Builder.CreateBr(ReturnBlock);
    BasicBlock* BB = BasicBlock::Create("after_ret", func_);
    TheHelper.Builder.SetInsertPoint(BB);
}

void DynFuncBuilder::add_return(int val) {
    add_return(Const(val));
}

void DynFuncBuilder::add_return(double val) {
    add_return(Const(val));
}

void DynFuncBuilder::print(PDynExprBase expr) {
    Function* pr;
    if (expr->type_ == Type::DoubleTy)
        pr = TheHelper.M->getFunction("print_double");
    else
        pr = TheHelper.M->getFunction("print_int");

    std::vector<Value*> args;
    args.push_back(expr->eval());

    TheHelper.Builder.CreateCall(pr, args.begin(), args.end());
}

void DynFuncBuilder::print(int val) {
    print(Const(val));
}

void DynFuncBuilder::print(double val) {
    print(Const(val));
}


// Opens an `if (expr)` construct. Evaluates expr at the current insertion point
// and tests it against zero. Then branches to a new "then" block (appended to
// the function now) or to the "else" block. The "else" and "ifcont" blocks are
// created detached and appended later by add_else() / the destructor.
// Afterwards the builder emits into "then".
DynIfBuilder::DynIfBuilder(PDynExprBase expr) 
    : no_else_(true)
{
    Value* ifcond = expr->eval();
    // Truthiness follows C++: any non-zero value is true. For doubles the
    // unordered predicate UNE is used so that a NaN condition counts as true,
    // exactly like `if (nan)` in C++; the ordered ONE would treat NaN as false.
    if (expr->type_ == Type::DoubleTy)
        ifcond = TheHelper.Builder.CreateFCmpUNE(ifcond, ConstantFP::get(APFloat(0.0)), "ifcond");
    else
        ifcond = TheHelper.Builder.CreateICmpNE(ifcond, ConstantInt::get(APInt(32, 0, true)), "ifcond");

    func_ = TheHelper.CurrentFunction();
    ThenBB = BasicBlock::Create("then", func_);
    ElseBB = BasicBlock::Create("else"); // May be empty
    MergeBB = BasicBlock::Create("ifcont");

    TheHelper.Builder.CreateCondBr(ifcond, ThenBB, ElseBB);

    TheHelper.Builder.SetInsertPoint(ThenBB);
}

// Closes the `if` construct when the builder goes out of scope.
// Without add_else(): the end of "then" branches to "ifcont", and the empty
// "else" block is appended and made current. In both cases the current block
// (the end of the else branch) then branches to "ifcont", which is appended to
// the function and becomes the new insertion point.
DynIfBuilder::~DynIfBuilder() {
    if (no_else_) {
        // Close "then" and materialize the empty "else" block.
        TheHelper.Builder.CreateBr(MergeBB);
        func_->getBasicBlockList().push_back(ElseBB);
        TheHelper.Builder.SetInsertPoint(ElseBB);
    }
    // Close the else branch (empty or user-filled).
    TheHelper.Builder.CreateBr(MergeBB);
    // NOTE(review): ElseBB is updated to the block ending the else branch but is
    // not read again here (no PHI is built); apparently a leftover, harmless.
    ElseBB = TheHelper.Builder.GetInsertBlock();

    // merge
    func_->getBasicBlockList().push_back(MergeBB);
    TheHelper.Builder.SetInsertPoint(MergeBB);
}

// Switches emission from the "then" branch to the "else" branch. The current
// block is terminated with a branch to "ifcont", then the "else" block is
// appended and made current. Clearing no_else_ stops the destructor from
// synthesizing an empty else.
// NOTE(review): nothing prevents a second call, which would append ElseBB to
// the function again — callers presumably call this at most once.
void DynIfBuilder::add_else() {
    no_else_ = false;

    // Close 'then'
    TheHelper.Builder.CreateBr(MergeBB);
    // Recorded but not read again in this file (no PHI is built).
    ThenBB = TheHelper.Builder.GetInsertBlock();

    // prepare for 'else'
    func_->getBasicBlockList().push_back(ElseBB);
    TheHelper.Builder.SetInsertPoint(ElseBB);
}


// Opens a `while (expr)` loop. Control branches into a new "loop" header block,
// where expr is evaluated and tested against zero. Non-zero enters "loop_body",
// which becomes the insertion point. Zero exits to "after_loop", which the
// destructor appends.
DynWhileBuilder::DynWhileBuilder(PDynExprBase expr) {
    func_ = TheHelper.CurrentFunction();

    LoopBB    = BasicBlock::Create("loop", func_);
    LoopBody  = BasicBlock::Create("loop_body");
    AfterLoop = BasicBlock::Create("after_loop");

    TheHelper.Builder.CreateBr(LoopBB);
    TheHelper.Builder.SetInsertPoint(LoopBB);

    // The condition is emitted inside the header block, so it is recomputed
    // every time the end of the body branches back to LoopBB.
    Value* cond = expr->eval();
    // Truthiness follows C++: any non-zero value is true. For doubles the
    // unordered predicate UNE is used so that NaN counts as true, exactly like
    // `while (nan)` in C++; the ordered ONE would treat NaN as false.
    if (expr->type_ == Type::DoubleTy)
        cond = TheHelper.Builder.CreateFCmpUNE(cond, ConstantFP::get(APFloat(0.0)), "loopcond");
    else
        cond = TheHelper.Builder.CreateICmpNE(cond, ConstantInt::get(APInt(32, 0, true)), "loopcond");

    TheHelper.Builder.CreateCondBr(cond, LoopBody, AfterLoop);
    func_->getBasicBlockList().push_back(LoopBody);
    TheHelper.Builder.SetInsertPoint(LoopBody);
}

// Closes the loop when the builder goes out of scope. The current block (the
// end of the body) branches back to the "loop" header, where the condition is
// re-evaluated. Then "after_loop" is appended and becomes the insertion point
// for code following the loop.
DynWhileBuilder::~DynWhileBuilder() {
    TheHelper.Builder.CreateBr(LoopBB);

    // merge
    func_->getBasicBlockList().push_back(AfterLoop);
    TheHelper.Builder.SetInsertPoint(AfterLoop);
}




// Widening-only conversion used when combining operands: wraps e in an
// int->double conversion when the target type t is double and e is not;
// otherwise returns e unchanged (never narrows).
PDynExprBase DynCastAsNeeded(PDynExprBase e, const Type* t) {
    const bool needsWidening = (t == Type::DoubleTy) && (e->type_ != t);
    return needsWidening ? PDynExprBase(new DynDoubleCast(e)) : e;
}

// Converts e to exactly type t, in either direction (despite the name, this
// also narrows double -> int, which truncates). Returns e if it already has type t.
PDynExprBase DynUpcastAsNeeded(PDynExprBase e, const Type* t) {
    if (e->type_ != t) {
        if (t == Type::DoubleTy)
            return PDynExprBase(new DynDoubleCast(e));
        return PDynExprBase(new DynIntCast(e));
    }
    return e;
}

// Common type for a binary operation: the shared type when both operands
// agree, double otherwise.
const Type* DynBestType(PDynExprBase lhs, PDynExprBase rhs) {
    return (lhs->type_ != rhs->type_) ? Type::DoubleTy : lhs->type_;
}

// Shared factory for the expression-on-expression operators below. It wraps a
// new DynBinaryOp node. No IR is emitted here; emission happens when the
// node's eval() is called.
static PDynExprBase MakeBinaryNode(DynBinaryOp::Op op, const PDynExprBase& lhs, const PDynExprBase& rhs) {
    return PDynExprBase(new DynBinaryOp(op, lhs, rhs));
}

// Arithmetic operators.
PDynExprBase operator+(const PDynExprBase lhs, const PDynExprBase rhs) {
    return MakeBinaryNode(DynBinaryOp::ADD, lhs, rhs);
}

PDynExprBase operator-(const PDynExprBase lhs, const PDynExprBase rhs) {
    return MakeBinaryNode(DynBinaryOp::SUB, lhs, rhs);
}

PDynExprBase operator*(const PDynExprBase lhs, const PDynExprBase rhs) {
    return MakeBinaryNode(DynBinaryOp::MUL, lhs, rhs);
}

PDynExprBase operator/(const PDynExprBase lhs, const PDynExprBase rhs) {
    return MakeBinaryNode(DynBinaryOp::DIV, lhs, rhs);
}

PDynExprBase operator%(const PDynExprBase lhs, const PDynExprBase rhs) {
    return MakeBinaryNode(DynBinaryOp::MOD, lhs, rhs);
}

// Comparison operators (see DynBinaryOp::eval for the emitted predicates).
PDynExprBase operator>(const PDynExprBase lhs, const PDynExprBase rhs) {
    return MakeBinaryNode(DynBinaryOp::GT, lhs, rhs);
}

PDynExprBase operator<(const PDynExprBase lhs, const PDynExprBase rhs) {
    return MakeBinaryNode(DynBinaryOp::LT, lhs, rhs);
}

PDynExprBase operator==(const PDynExprBase lhs, const PDynExprBase rhs) {
    return MakeBinaryNode(DynBinaryOp::EQ, lhs, rhs);
}

PDynExprBase operator!=(const PDynExprBase lhs, const PDynExprBase rhs) {
    return MakeBinaryNode(DynBinaryOp::NE, lhs, rhs);
}

PDynExprBase operator>=(const PDynExprBase lhs, const PDynExprBase rhs) {
    return MakeBinaryNode(DynBinaryOp::GE, lhs, rhs);
}

PDynExprBase operator<=(const PDynExprBase lhs, const PDynExprBase rhs) {
    return MakeBinaryNode(DynBinaryOp::LE, lhs, rhs);
}

// Expression-op-int overloads: lift the int literal with Const() and reuse the
// expression-op-expression operators.
PDynExprBase operator+(const PDynExprBase lhs, int rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs + rhsExpr;
}

PDynExprBase operator-(const PDynExprBase lhs, int rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs - rhsExpr;
}

PDynExprBase operator*(const PDynExprBase lhs, int rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs * rhsExpr;
}

PDynExprBase operator/(const PDynExprBase lhs, int rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs / rhsExpr;
}

PDynExprBase operator%(const PDynExprBase lhs, int rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs % rhsExpr;
}

PDynExprBase operator>(const PDynExprBase lhs, int rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs > rhsExpr;
}

PDynExprBase operator<(const PDynExprBase lhs, int rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs < rhsExpr;
}

PDynExprBase operator==(const PDynExprBase lhs, int rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs == rhsExpr;
}

PDynExprBase operator!=(const PDynExprBase lhs, int rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs != rhsExpr;
}

PDynExprBase operator>=(const PDynExprBase lhs, int rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs >= rhsExpr;
}

PDynExprBase operator<=(const PDynExprBase lhs, int rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs <= rhsExpr;
}

// Int-op-expression overloads: lift the int literal with Const() and reuse the
// expression-op-expression operators.
PDynExprBase operator+(int lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr + rhs;
}

PDynExprBase operator-(int lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr - rhs;
}

PDynExprBase operator*(int lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr * rhs;
}

PDynExprBase operator/(int lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr / rhs;
}

PDynExprBase operator%(int lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr % rhs;
}

PDynExprBase operator>(int lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr > rhs;
}

PDynExprBase operator<(int lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr < rhs;
}

PDynExprBase operator==(int lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr == rhs;
}

PDynExprBase operator!=(int lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr != rhs;
}

PDynExprBase operator>=(int lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr >= rhs;
}

PDynExprBase operator<=(int lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr <= rhs;
}

// Expression-op-double overloads: lift the double literal with Const() and
// reuse the expression-op-expression operators.
PDynExprBase operator+(const PDynExprBase lhs, double rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs + rhsExpr;
}

PDynExprBase operator-(const PDynExprBase lhs, double rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs - rhsExpr;
}

PDynExprBase operator*(const PDynExprBase lhs, double rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs * rhsExpr;
}

PDynExprBase operator/(const PDynExprBase lhs, double rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs / rhsExpr;
}

PDynExprBase operator%(const PDynExprBase lhs, double rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs % rhsExpr;
}

PDynExprBase operator>(const PDynExprBase lhs, double rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs > rhsExpr;
}

PDynExprBase operator<(const PDynExprBase lhs, double rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs < rhsExpr;
}

PDynExprBase operator==(const PDynExprBase lhs, double rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs == rhsExpr;
}

PDynExprBase operator!=(const PDynExprBase lhs, double rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs != rhsExpr;
}

PDynExprBase operator>=(const PDynExprBase lhs, double rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs >= rhsExpr;
}

PDynExprBase operator<=(const PDynExprBase lhs, double rhs) {
    const PDynExprBase rhsExpr = Const(rhs);
    return lhs <= rhsExpr;
}

// Double-op-expression overloads: lift the double literal with Const() and
// reuse the expression-op-expression operators.
PDynExprBase operator+(double lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr + rhs;
}

PDynExprBase operator-(double lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr - rhs;
}

PDynExprBase operator*(double lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr * rhs;
}

PDynExprBase operator/(double lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr / rhs;
}

PDynExprBase operator%(double lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr % rhs;
}

PDynExprBase operator>(double lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr > rhs;
}

PDynExprBase operator<(double lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr < rhs;
}

PDynExprBase operator==(double lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr == rhs;
}

PDynExprBase operator!=(double lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr != rhs;
}

PDynExprBase operator>=(double lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr >= rhs;
}

PDynExprBase operator<=(double lhs, const PDynExprBase rhs) {
    const PDynExprBase lhsExpr = Const(lhs);
    return lhsExpr <= rhs;
}



// Wraps an int literal in a DynInt expression node.
PDynExprBase Const(int v) {
    DynInt* node = new DynInt(v);
    return PDynExprBase(node);
}

// Wraps a double literal in a DynDouble expression node.
PDynExprBase Const(double v) {
    DynDouble* node = new DynDouble(v);
    return PDynExprBase(node);
}


// Export annotation for the runtime helpers below. DynHelper::Init() only
// *declares* "print_int" / "print_double" in the JIT module, so their native
// definitions must be findable by name at run time. __declspec(dllexport)
// exists only on Windows toolchains; elsewhere fall back to GCC/Clang default
// visibility (or nothing), so this file also builds outside MSVC.
// NOTE(review): the JIT presumably resolves these by their plain names; unless
// DynCompiler.h declares them extern "C", the exported symbols are C++-mangled —
// confirm against the header.
#if defined(_WIN32)
#  define DYN_RUNTIME_EXPORT __declspec( dllexport )
#elif defined(__GNUC__)
#  define DYN_RUNTIME_EXPORT __attribute__((visibility("default")))
#else
#  define DYN_RUNTIME_EXPORT
#endif

// Runtime helper called from JIT-compiled code: prints `val` and a newline to
// std::cout. Always returns 0 (Init() declares it with an i32 return type).
DYN_RUNTIME_EXPORT
int print_int(int val) {
    std::cout << val << "\n";
    return 0;
}

// Runtime helper called from JIT-compiled code: prints `val` and a newline to
// std::cout. Always returns 0.
DYN_RUNTIME_EXPORT
int print_double(double val) {
    std::cout << val << "\n";
    return 0;
}

#undef DYN_RUNTIME_EXPORT

/*
 * By viewing downloads associated with this article you agree to the Terms of Service and the article's licence.
 *
 * If a file you wish to view isn't highlighted, and is a text file (not binary), please let us know and we'll add colourisation support for it.
 *
 * License
 *
 * This article, along with any associated source code and files, is licensed under The Code Project Open License (CPOL).
 *
 *
 * Written By
 * United States
 * This member has not yet provided a Biography. Assume it's interesting and varied, and probably something to do with programming.
 *
 * Comments and Discussions
 */