Advertisement
homer512

segfault handler

Oct 27th, 2014
189
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 9.84 KB | None | 0 0
  1. /**
  2.  * Demonstrates the use of a signal handler to create data lazily
  3.  *
  4.  * Compile with -std=gnu++11
  5.  *
  6.  *
  7.  * Copyright 2014 Florian Philipp
  8.  *
  9.  * Licensed under the Apache License, Version 2.0 (the "License");
  10.  * you may not use this file except in compliance with the License.
  11.  * You may obtain a copy of the License at
  12.  *
  13.  *  http://www.apache.org/licenses/LICENSE-2.0
  14.  *
  15.  * Unless required by applicable law or agreed to in writing, software
  16.  * distributed under the License is distributed on an "AS IS" BASIS,
  17.  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18.  * See the License for the specific language governing permissions and
  19.  * limitations under the License.
  20.  */
  21.  
#include <set>
// using std::set
#include <algorithm>
// using std::fill
#include <exception>
// using std::terminate, std::exception
#include <stdexcept>
// using std::logic_error
#include <new>
// using std::bad_alloc
#include <system_error>
// using std::system_error, std::system_category
#include <cerrno>
// using errno
#include <cassert>
// using assert
#include <cstdio>
// using std::fprintf
#include <cstddef>
// using std::size_t
#include <cstdint>
// using std::uintptr_t
#include <signal.h>
// using sigaction
#include <sys/mman.h>
// using mprotect, mmap
#include <unistd.h>
// using sysconf
  46.  
  47.  
  48. /**
  49.  * Private C++ implementation follows
  50.  */
  51. namespace {
  52.  
  /**
   * Queries the size of one virtual memory page.
   *
   * \return sysconf(_SC_PAGESIZE), converted to std::size_t
   */
  std::size_t pagelen()
  {
    const long bytes_per_page = sysconf(_SC_PAGESIZE);
    return static_cast<std::size_t>(bytes_per_page);
  }
  60.  
  61.   /**
  62.    * Segmentation fault handler
  63.    *
  64.    * Uses a dedicated memory area that is filled with data on demand.
  65.    * A singleton in production code, but not necessarily in testing code
  66.    *
  67.    * TODO: Make thread-safe
  68.    *
  69.    * BUG: Not async-signal-safe
  70.    */
  71.   class SegvHandler
  72.   {
  73.     typedef std::set<void*>::iterator iterator;
  74.  
  75.     /**
  76.      * pointer to the start of the dedicated memory area or nullptr
  77.      */
  78.     void* base_addr;
  79.     /**
  80.      * pointer past the end of the dedicated memory area or nullptr
  81.      */
  82.     void* mapped_end;
  83.     /**
  84.      * Addresses of pages with read or read/write access permissions
  85.      */
  86.     std::set<void*> mapped_ro, mapped_rw;
  87.  
  88.     /**
  89.      * Fills page with meaningful data
  90.      *
  91.      * TODO: Stub
  92.      */
  93.     void populate_page(void* page)
  94.     {
  95.       unsigned* typed = static_cast<unsigned*>(page);
  96.       unsigned* page_end = typed + pagelen() / sizeof(*typed);
  97.       std::fill(typed, page_end, 0xDEADBEEF);
  98.     }
  99.  
  100.     /**
  101.      * Flushes changed page to shared storage or whatever
  102.      *
  103.      * TODO: Stub
  104.      */
  105.     void commit_page(void* page)
  106.     {
  107.       char* byte_addr = static_cast<char*>(page);
  108.       char* byte_base = static_cast<char*>(this->base_addr);
  109.       std::size_t offset = byte_addr - byte_base;
  110.       std::printf("Range [%zu, %zu) changed\n", offset, offset + pagelen());
  111.     }
  112.  
  113.     /**
  114.      * Returns page address of segfault
  115.      */
  116.     static void* get_page(const siginfo_t* siginfo)
  117.     {
  118.       std::size_t addr = reinterpret_cast<std::size_t>(siginfo->si_addr);
  119.       addr &= ~(pagelen() - 1);
  120.       return reinterpret_cast<void*>(addr);
  121.     }
  122.  
  123.     /**
  124.      * Invokes handler for segmentation faults that cannot be handled otherwise
  125.      */
  126.     static void real_segfault(const siginfo_t* siginfo)
  127.     {
  128.       // TODO: Replace with original segfault handler
  129.       std::fprintf(stderr, "SEGMENTATION FAULT %p\n", siginfo->si_addr);
  130.       std::terminate();
  131.     }
  132.  
  133.     /**
  134.      * Throws an std::system_error constructed from errno
  135.      */
  136.     static void throw_sys_err(const char* what)
  137.     {
  138.       throw std::system_error(errno, std::system_category(), what);
  139.     }
  140.  
  141.     /**
  142.      * Invokes mprotect. Wraps errors in std::system_error
  143.      */
  144.     static void change_protection(void* page, int permissions,
  145.                   std::size_t length = pagelen())
  146.     {
  147.       if(mprotect(page, length, permissions))
  148.     throw_sys_err("mprotect");
  149.     }
  150.  
  151.     /**
  152.      * Permits write access to page. Marks page as dirty
  153.      *
  154.      * Precondition: page is removed from this->mapped_ro
  155.      */
  156.     void map_rw(void* page)
  157.     {
  158.       change_protection(page, PROT_READ | PROT_WRITE);
  159.       this->mapped_rw.insert(page);
  160.     }
  161.  
  162.     /**
  163.      * Populates page and permits read access
  164.      */
  165.     void map_ro(void* page)
  166.     {
  167.       change_protection(page, PROT_READ | PROT_WRITE);
  168.       this->populate_page(page);
  169.       change_protection(page, PROT_READ);
  170.       this->mapped_ro.insert(page);
  171.     }
  172.  
  173.     /**
  174.      * Returns true if this segfault can be avoided
  175.      */
  176.     bool is_magic_segfault(const siginfo_t* siginfo) const
  177.     {
  178.       void* addr = siginfo->si_addr;
  179.       return siginfo->si_code == SEGV_ACCERR
  180.     && addr >= this->base_addr && addr < this->mapped_end
  181.     && ! this->mapped_rw.count(get_page(siginfo));
  182.     }
  183.  
  184.     /**
  185.      * Signal handler compatible with sigaction
  186.      */
  187.     static void signal_handler(int signum, siginfo_t* siginfo, void* ucontext)
  188.     {
  189.       try {
  190.     assert(signum == SIGSEGV);
  191.     SegvHandler* self = global_self();
  192.     void* page = get_page(siginfo);
  193.     if(! self->is_magic_segfault(siginfo))
  194.       real_segfault(siginfo);
  195.     if(self->mapped_ro.erase(page))
  196.       self->map_rw(page);
  197.     else
  198.       self->map_ro(page);
  199.       } catch(std::exception& err) {
  200.     std::fprintf(stderr, "segfault handler: %s\n", err.what());
  201.     std::terminate();
  202.       }
  203.     }
  204.  
  205.   public:
  206.     /**
  207.      * Initializes empty, unmapped SegvHandler
  208.      */
  209.     SegvHandler()
  210.       : base_addr(nullptr),
  211.     mapped_end(nullptr)
  212.     {}
  213.  
  214.     /**
  215.      * Returns singleton
  216.      */
  217.     static SegvHandler* global_self()
  218.     {
  219.       static SegvHandler* self = new SegvHandler();
  220.       return self;
  221.     }
  222.  
  223.     /**
  224.      * Installs the global signal handler
  225.      */
  226.     static void install()
  227.     {
  228.       struct sigaction action;
  229.       action.sa_sigaction = &SegvHandler::signal_handler;
  230.       action.sa_mask = sigset_t();
  231.       action.sa_flags = SA_SIGINFO;
  232.       if(sigaction(SIGSEGV, &action, nullptr))
  233.     throw_sys_err("sigaction");
  234.     }
  235.  
  236.     /**
  237.      * Allocates the dedicated memory area
  238.      */
  239.     void init_mapping()
  240.     {
  241.       if(this->base_addr)
  242.     throw std::logic_error("segfault handler double initialization");
  243.       const std::size_t mapping_len = pagelen() * 16; // TODO: placeholder
  244.       void* mapped = mmap(nullptr, mapping_len, PROT_READ | PROT_WRITE,
  245.               MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
  246.       if(mapped == MAP_FAILED)
  247.     throw_sys_err("mmap");
  248.       change_protection(mapped, PROT_NONE, mapping_len);
  249.       this->base_addr = mapped;
  250.       this->mapped_end = static_cast<char*>(mapped) + mapping_len;
  251.     }
  252.  
  253.     /**
  254.      * Commits all dirty pages and marks them as clean
  255.      */
  256.     void commit_changes()
  257.     {
  258.       iterator first = this->mapped_rw.begin();
  259.       iterator last = this->mapped_rw.end();
  260.       while(first != last) {
  261.     void* page = *first;
  262.     this->commit_page(page);
  263.     change_protection(page, PROT_READ);
  264.     this->mapped_rw.erase(first++);
  265.     this->mapped_ro.insert(page);
  266.       }
  267.     }
  268.  
  269.     /**
  270.      * Discards all dirty pages. Will re-populate them on demand
  271.      */
  272.     void discard_changes()
  273.     {
  274.       iterator first = this->mapped_rw.begin();
  275.       iterator last = this->mapped_rw.end();
  276.       while(first != last) {
  277.     void* page = *first;
  278.     change_protection(page, PROT_NONE);
  279.     this->mapped_rw.erase(first++);
  280.       }      
  281.     }
  282.  
  283.     /**
  284.      * Returns starting address of the dedicated memory area
  285.      */
  286.     void* get_base_addr() const
  287.     { return this->base_addr; }
  288.   };
  289.  
  /**
   * Runs a void functor and converts its outcome into a C-style result, so
   * that no C++ exception crosses the extern "C" boundary.
   *
   * Exceptions are converted to errno values:
   *  - std::system_error: errno = the error code's value
   *  - std::bad_alloc: errno = ENOMEM
   *  - any other exception indicates a programming error: a message is
   *    printed on stderr and errno = EINVAL
   *
   * \param function callable invoked with no arguments; result ignored
   * \return 0 on success, -1 if an exception was caught (errno set)
   */
  template<class Callable>
  int c_style_call(Callable&& function)
  {
    try {
      function();
      return 0;
    } catch(const std::system_error& err) {
      errno = err.code().value();
    } catch(const std::bad_alloc&) {
      errno = ENOMEM;
    } catch(const std::exception& err) {
      std::fprintf(stderr, "segfault handler: %s\n", err.what());
      errno = EINVAL;
    } catch(...) {
      // Non-standard exception types must not escape into C callers either
      std::fprintf(stderr, "segfault handler: unknown exception\n");
      errno = EINVAL;
    }
    return -1;
  }
  315.  
  316. } // namespace
  317.  
  318.  
  319. /**
  320.  * Public C interface follows
  321.  */
  322. extern "C" {
  323.  
  324.   /**
  325.    * Installs and initializes the segmentation fault handler
  326.    *
  327.    * \return 0 on success, -1 otherwise. Sets errno
  328.    */
  329.   int sigsegv_install()
  330.   {
  331.     auto lambda = []() {
  332.       SegvHandler::global_self()->init_mapping();
  333.       SegvHandler::install();
  334.     };
  335.     return c_style_call(lambda);
  336.   }
  337.  
  338.   /**
  339.    * Commits all changes
  340.    *
  341.    * \return 0 on success, -1 otherwise. Sets errno
  342.    */
  343.   int sigsegv_commit()
  344.   {
  345.     auto lambda = []() {
  346.       SegvHandler::global_self()->commit_changes();
  347.     };
  348.     return c_style_call(lambda);
  349.   }
  350.  
  351.   /**
  352.    * Discards all changes
  353.    *
  354.    * \return 0 on success, -1 otherwise. Sets errno
  355.    */
  356.   int sigsegv_discard()
  357.   {
  358.     auto lambda = []() {
  359.       SegvHandler::global_self()->discard_changes();
  360.     };
  361.     return c_style_call(lambda);
  362.   }
  363.  
  364.   /**
  365.    * Returns base address of the dedicated memory area
  366.    */
  367.   void* sigsegv_baseptr()
  368.   {
  369.     void* base = nullptr;
  370.     auto lambda = [&base]() {
  371.       base = SegvHandler::global_self()->get_base_addr();
  372.     };
  373.     c_style_call(lambda);
  374.     return base;
  375.   }
  376.  
  377. } // extern "C"
  378.  
  379.  
  380. /**
  381.  * Some simple testing code
  382.  *
  383.  * Observe it with strace
  384.  */
  385. int main()
  386. {
  387.   sigsegv_install();
  388.   unsigned* base = static_cast<unsigned*>(sigsegv_baseptr());
  389.   std::printf("Accessing RO %p = 0x%x\n", base + 3, base[3]);
  390.   base[3] = 0;
  391.   std::puts("Committing");
  392.   sigsegv_commit();
  393.   std::size_t otherpage = pagelen()/sizeof(unsigned) + 3;
  394.   std::puts("Making direct RW access");
  395.   base[otherpage] = 0;
  396.   std::puts("Discarding");
  397.   sigsegv_discard();
  398.   std::printf("Accessing RO %p = 0x%x\n", base + otherpage, base[otherpage]);
  399.   return 0;
  400. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement