Advertisement
Guest User

Untitled

a guest
Feb 21st, 2017
54
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 7.72 KB | None | 0 0
  1. #include "llvm/Pass.h"
  2. #include "llvm/ADT/SmallString.h"
  3. #include "llvm/ADT/SmallVector.h"
  4. #include "llvm/ADT/StringRef.h"
  5. #include "llvm/Analysis/LoopPass.h"
  6. #include "llvm/Analysis/LoopInfo.h"
  7. #include "llvm/Analysis/OrderedBasicBlock.h"
  8. #include "llvm/IR/BasicBlock.h"
  9. #include "llvm/IR/Instruction.h"
  10. #include "llvm/IR/Instructions.h"
  11. #include "llvm/IR/InstVisitor.h"
  12. #include "llvm/IR/IRBuilder.h"
  13. #include "llvm/IR/MDBuilder.h"
  14. #include "llvm/IR/Metadata.h"
  15. #include "llvm/Transforms/Utils/BasicBlockUtils.h"
  16. #include "llvm/Transforms/Utils/Cloning.h"
  17. #include "llvm/Transforms/Utils/LoopUtils.h"
  18. #include "llvm/Support/Debug.h"
  19. #include "llvm/Support/raw_ostream.h"
  20.  
  // Debug type string for this file.
  #define DEBUG_TYPE "predictive"

  // ---- Metadata kind names shared by the two passes in this file ----

  // Kind of indexed expression attached to a GEP; the payload string is
  // either "RHS" or "LHS".
  #define IE_TYPE    "loop.type"

  // Metadata kind "loop.replace". Not referenced by the code visible in this
  // file.
  #define IE_REPLACE "loop.replace"

  // Carries the constant which replaces the induction variable.
  #define IE_ITER    "loop.iter"

  // Tags the array affected by the transformation.
  #define IE_ARRAY   "loop.array"

  // Tags an instruction which will be useless after the transformation.
  #define IE_DEAD    "loop.dead"
  37.  
  38. using namespace llvm;
  39.  
  40. namespace {
  41. class Pass1 : public LoopPass {
  42. public:
  43. static char ID;
  44. LoopInfo *LI;
  45. Pass1() : LoopPass(ID) {}
  46.  
  47. void getAnalysisUsage(AnalysisUsage &AU) const override {
  48. AU.setPreservesAll();
  49. AU.addRequired<LoopInfoWrapperPass>();
  50. AU.addPreserved<LoopInfoWrapperPass>();
  51. getLoopAnalysisUsage(AU);
  52. }
  53.  
  54. int getInitial(Instruction *index, Instruction *arrayptr) {
  55. for (auto tmpL = index,
  56. tmpA = arrayptr;
  57. tmpL != tmpA;
  58. tmpL = tmpL->user_back()
  59. ) {
  60. if(auto bi = dyn_cast<BinaryOperator>(tmpL)) {
  61. for (auto l = bi->op_begin(); l != bi->op_end(); ++l) {
  62. if (auto c = dyn_cast<ConstantInt>(l)) {
  63. return c->getSExtValue();
  64. }
  65. }
  66. }
  67. }
  68. return 0;
  69. }
  70.  
  71. void setUsersMetadata(Instruction *start, Instruction *end, MDNode *node) {
  72. for (auto tmpL = start,
  73. tmpA = end;
  74. tmpL != tmpA;
  75. tmpL = tmpL->user_back()
  76. ) {
  77. tmpL->setMetadata(IE_DEAD, node);
  78. }
  79. }
  80.  
  81. Constant* makeConstant(LLVMContext& context, int value) {
  82. return ConstantInt::get(Type::getInt32Ty(context), value);
  83. }
  84.  
  85. void setAllMetadata(BasicBlock *BB) {
  86. OrderedBasicBlock oBB(BB);
  87. LLVMContext& context = BB->getContext();
  88. MDBuilder builder(context);
  89. // int count = 0;
  90. for (auto gepIT = BB->begin(); gepIT != BB->end(); ++gepIT) {
  91. if (auto *gep = dyn_cast<GetElementPtrInst>(gepIT)) {
  92. auto *arr = MDNode::get(context, builder.createString("array"));
  93. cast<Instruction>(gep->getOperand(0))->setMetadata(IE_ARRAY, arr);
  94. auto *RHS = MDNode::get(context, builder.createString("RHS"));
  95. auto *LHS = MDNode::get(context, builder.createString("LHS"));
  96. Instruction *dominant = nullptr;
  97. for (auto it = BB->begin(); it != BB->end(); ++it) {
  98. if (oBB.dominates(&*it, gep) && isa<LoadInst>(it)) {
  99. dominant = &*it;
  100. }
  101. }
  102. auto *dominated = gep->user_back();
  103. bool lhs = isa<LoadInst>(dominated);
  104. if (dominant != nullptr && dominated != nullptr) {
  105. if (lhs) {
  106. auto deadInst = MDNode::get(context, builder.createString("dead"));
  107. gepIT->setMetadata(IE_TYPE, RHS);
  108. gepIT->setMetadata(IE_DEAD, deadInst);
  109. setUsersMetadata(dominant, &*gepIT, deadInst);
  110. auto cnt = getInitial(dominant, &*gepIT);
  111. auto iterV = makeConstant(context, cnt);
  112. auto iterMD = MDNode::get(context, builder.createConstant(iterV));
  113. dominated->setMetadata(IE_ITER, iterMD);
  114. dominated->setMetadata(IE_DEAD, deadInst);
  115. dominant->setMetadata(IE_DEAD, deadInst);
  116. } else {
  117. gepIT->setMetadata(IE_TYPE, LHS);
  118. }
  119. }
  120. }
  121. }
  122. }
  123.  
  124. bool runOnLoop(Loop *loop, LPPassManager &LPM) override {
  125. LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  126. for (auto block : loop->getBlocks()) {
  127. if (LI->isLoopHeader(block)) {
  128. continue;
  129. }
  130. if (loop->isLoopExiting(block)) {
  131. continue;
  132. }
  133. SmallString<10> blockName(block->getName());
  134. StringRef requiredPrefix("for.body", 8);
  135. if (!blockName.startswith(requiredPrefix)) {
  136. continue;
  137. }
  138. setAllMetadata(block);
  139. }
  140. return true;
  141. }
  142. };
  143. }
  144.  
  145. char Pass1::ID = 0;
  146. static RegisterPass<Pass1> X(
  147. "pass1",
  148. "Predictive Commoning"
  149. );
  150.  
  151. namespace {
  152. struct Pass2 : public FunctionPass, public InstVisitor<Pass2> {
  153. public:
  154. static char ID;
  155.  
  156. Pass2() : FunctionPass(ID) {}
  157.  
  158. void getAnalysisUsage(AnalysisUsage &AU) const override {
  159. //AU.setPreservesCFG();
  160. AU.addRequired<LoopInfoWrapperPass>();
  161. AU.addPreserved<LoopInfoWrapperPass>();
  162. }
  163.  
  164. Constant* makeConstant(LLVMContext& context, int value) {
  165. return ConstantInt::get(Type::getInt32Ty(context), value);
  166. }
  167.  
  168. bool runOnFunction(Function &F) override {
  169. SmallString<10> fname(F.getName());
  170. if (fname.startswith("_")) {
  171. return false;
  172. }
  173.  
  174. LoopInfo& LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  175. auto entry = &F.getEntryBlock();
  176. AllocaInst *arr = nullptr;
  177. for (auto i = entry->begin(); i != entry->end(); i++) {
  178. if(i->getMetadata(IE_ARRAY)) {
  179. arr = cast<AllocaInst>(i);
  180. }
  181. }
  182.  
  183. LLVMContext& ctx = F.getContext();
  184. BasicBlock* block = BasicBlock::Create(ctx, "predict", &F);
  185. IRBuilder<> builder(block);
  186.  
  187. for (auto loops = LI.begin(); loops != LI.end(); loops++) {
  188. Loop *L = *loops;
  189. BasicBlock *lbody;
  190. SmallVector<AllocaInst*, 4> newAllocs;
  191. SmallVector<Instruction*, 4> oldLoads;
  192. for (auto it = L->block_begin(); it != L->block_end(); ++it) {
  193. BasicBlock *BB = *it;
  194. SmallString<10> bname(BB->getName());
  195. if (bname.startswith("for.body")) {
  196. lbody = BB;
  197. break;
  198. }
  199. }
  200. for (auto inst = lbody->begin(); inst != lbody->end(); inst++) {
  201. if (isa<GetElementPtrInst>(inst)) {
  202. if (auto md = inst->getMetadata(IE_TYPE)) {
  203. auto type = cast<MDString>(md->getOperand(0))->getString();
  204. if (type == "LHS") {
  205. continue;
  206. }
  207. auto arrMD = inst->user_back()->getMetadata(IE_ITER);
  208. int index = mdconst::extract<ConstantInt>(arrMD->getOperand(0))->getSExtValue();
  209. auto p = builder.CreateAlloca(Type::getInt32Ty(ctx), nullptr, "p");
  210. p->setAlignment(4);
  211. newAllocs.push_back(p);
  212. oldLoads.push_back(inst->user_back());
  213. auto cnt = makeConstant(ctx, 0);
  214. auto arrindex = makeConstant(ctx, index);
  215. auto gepI = builder.CreateInBoundsGEP(arr, { cnt, arrindex });
  216. auto lI = builder.CreateAlignedLoad(gepI, 4);
  217. builder.CreateAlignedStore(lI, p, 4);
  218. } else {
  219. continue;
  220. }
  221. }
  222. }
  223. IRBuilder<> loopBuilder(lbody, lbody->getFirstInsertionPt());
  224. int i = 0;
  225. for (auto it : newAllocs) {
  226. auto newLoad = loopBuilder.CreateAlignedLoad(it, 4);
  227. oldLoads[i]->replaceAllUsesWith(newLoad);
  228. i++;
  229. }
  230. }
  231. builder.CreateBr(entry->getSingleSuccessor());
  232. entry->getTerminator()->setSuccessor(0, block);
  233. return true;
  234. }
  235. };
  236. }
  237.  
  238. char Pass2::ID = 0;
  239. static RegisterPass<Pass2> Y(
  240. "pass2",
  241. "PredictiveCommoning pass 2"
  242. );
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement