Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- // database.h
- #pragma once
- #ifndef DATABASE_H
- #define DATABASE_H
- #include <boost/asio.hpp>
- #include <boost/unordered_map.hpp>
- #include <boost/shared_ptr.hpp>
- #include <boost/scoped_ptr.hpp>
- #include <boost/function.hpp>
- #include <boost/thread/mutex.hpp>
- #include <boost/thread/locks.hpp>
- #include <libpq-fe.h>
- #include <string>
- #include <queue>
- #include "query.h"
- #include "result.h"
- #include "notification_handler.h"
- void null_handler(result&);
- class database
- {
- public:
- database(boost::asio::io_service& io_service, std::string host, unsigned short port, std::string name, std::string user, std::string password);
- ~database();
- result exec(const query& query);
- typedef boost::function<void(result&)> result_handler;
- static boost::arg<1> result_placeholder;
- void async_exec(const query& query, result_handler handler = null_handler);
- bool connected() const;
- std::string error_message() const;
- void add_notification_handler(const std::string& channel, boost::shared_ptr<notification_handler> handler);
- private:
- database(const database& other);
- database& operator=(const database& other);
- std::string escape_connection_param(const std::string& escape_string);
- std::string escape_string(const std::string& escape_string);
- void on_data_available(const boost::system::error_code& error);
- void start_waiting_for_data();
- void handle_notifications();
- void handle_results();
- void send_next_command();
- PGconn* connection_;
- boost::asio::ip::tcp::socket socket_;
- boost::unordered_map<std::string, boost::shared_ptr<notification_handler> > notification_handlers_;
- boost::mutex exec_mutex_;
- result_handler result_handler_;
- std::queue<std::pair<query, result_handler> > exec_queue_;
- };
- #endif
- // database.cpp
#include "database.h"
#include <algorithm>
#include <cassert>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
#include <boost/bind.hpp>
#include <boost/format.hpp>
#include "log.h"
- boost::arg<1> database::result_placeholder;
- void null_handler(result&)
- {
- }
- void log_error_result(const PGresult* result)
- {
- log::error("Database Error: %1%", PQresultErrorField(result, PG_DIAG_MESSAGE_PRIMARY));
- const char* detail = PQresultErrorField(result, PG_DIAG_MESSAGE_DETAIL);
- if (detail)
- {
- log::error("%1%", detail);
- }
- }
- void notice_receiver(void* arg, const PGresult* result)
- {
- log_error_result(result);
- }
- database::database(boost::asio::io_service& io_service, std::string host, unsigned short port, std::string name, std::string user, std::string password)
- : socket_(io_service), connection_(0)
- {
- // Assuming TCP is used to connect to Postgres, this code would have to be changed for Unix socket descriptors
- std::stringstream ss;
- ss << "host='" << escape_connection_param(host) << "' port=" << port << " dbname='" << escape_connection_param(name) << "' user='" << escape_connection_param(user) << "' password='" << escape_connection_param(password) << "'";
- connection_ = PQconnectdb(ss.str().c_str());
- if (PQstatus(connection_) == CONNECTION_OK)
- {
- PQsetNoticeReceiver(connection_, notice_receiver, 0);
- if (PQsetnonblocking(connection_, 1) != 0)
- {
- throw std::runtime_error("PQsetnonblocking failed");
- }
- socket_.assign(boost::asio::ip::tcp::v4(), PQsocket(connection_));
- start_waiting_for_data();
- log::success("Connected to database '%1%' as user '%2%' on %3%:%4%", name, user, host, port);
- result res = exec(query::create("SELECT version()"));
- log::info("Database Version: %1%", res.get(0, 0));
- }
- else
- {
- throw std::runtime_error(boost::str(boost::format("Database connection failed! %1%") % error_message()));
- }
- }
- database::~database()
- {
- if (connection_)
- {
- PQfinish(connection_);
- }
- }
- bool database::connected() const
- {
- return socket_.is_open();
- }
- std::string database::error_message() const
- {
- return PQerrorMessage(connection_);
- }
- std::string database::escape_connection_param(const std::string& escape_string)
- {
- std::string new_string;
- size_t needed = escape_string.size() + std::count(escape_string.begin(), escape_string.end(), '\'') + std::count(escape_string.begin(), escape_string.end(), '\\');
- new_string.reserve(needed);
- for (std::string::const_iterator iter = escape_string.begin(); iter != escape_string.end(); ++iter)
- {
- switch(*iter)
- {
- case '\'':
- new_string.append("\\'");
- break;
- case '\\':
- new_string.append("\\\\");
- break;
- default:
- new_string.append(1, *iter);
- }
- }
- return new_string;
- }
- void database::start_waiting_for_data()
- {
- socket_.async_receive(boost::asio::null_buffers(), boost::bind(&database::on_data_available, this, boost::asio::placeholders::error));
- }
- void database::on_data_available(const boost::system::error_code& error)
- {
- if (!error)
- {
- if (PQconsumeInput(connection_) == 1)
- {
- if (PQisBusy(connection_) == 0)
- {
- handle_notifications();
- handle_results();
- }
- start_waiting_for_data();
- }
- else
- {
- log::error("Database Error: %1%", error_message());
- }
- }
- }
- void database::async_exec(const query& query, result_handler handler)
- {
- if (query.valid())
- {
- boost::lock_guard<boost::mutex> lock(exec_mutex_);
- if (!result_handler_)
- {
- if (1 == PQsendQueryParams(connection_, query.command(), query.num_params(), query.param_types(), query.param_values(), query.param_lengths(), query.param_formats(), query.binary_results() ? 1 : 0))
- {
- result_handler_ = handler;
- }
- else
- {
- log::error("Database Error: %1%", error_message());
- }
- }
- else
- {
- exec_queue_.push(std::make_pair(query, handler));
- }
- }
- else
- {
- throw std::runtime_error("Invalid query object");
- }
- }
- result database::exec(const query& query)
- {
- if (query.valid())
- {
- {
- // Make sure asynchronous executions will get queued during the synchronous execution
- boost::lock_guard<boost::mutex> lock(exec_mutex_);
- result_handler_ = null_handler;
- }
- PGresult *result = PQexecParams(connection_, query.command(), query.num_params(), query.param_types(), query.param_values(), query.param_lengths(), query.param_formats(), query.binary_results() ? 1 : 0);
- if (!result)
- {
- throw std::runtime_error(boost::str(boost::format("Fatal database error! %1%") % error_message()));
- }
- ExecStatusType status = PQresultStatus(result);
- if (status != PGRES_COMMAND_OK && status != PGRES_TUPLES_OK)
- {
- log_error_result(result);
- }
- handle_notifications();
- boost::lock_guard<boost::mutex> lock(exec_mutex_);
- result_handler_.clear();
- send_next_command();
- return result;
- }
- else
- {
- throw std::runtime_error("Invalid query object");
- }
- }
- void database::handle_notifications()
- {
- while (PGnotify *notify = PQnotifies(connection_))
- {
- boost::unordered_map<std::string, boost::shared_ptr<notification_handler> >::iterator handler = notification_handlers_.find(notify->relname);
- if (handler != notification_handlers_.end())
- {
- (*handler->second)(notify->extra);
- }
- else
- {
- log::warning("Unhandled database notification: Channel = %1%, Message = %2%", notify->relname, notify->extra);
- }
- PQfreemem(notify);
- }
- }
- void database::handle_results()
- {
- if (PGresult *result = PQgetResult(connection_))
- {
- std::vector<::result> results;
- do
- {
- ExecStatusType status = PQresultStatus(result);
- if (status != PGRES_COMMAND_OK && status != PGRES_TUPLES_OK)
- {
- log_error_result(result);
- }
- else
- {
- results.push_back(result);
- }
- } while (result = PQgetResult(connection_));
- result_handler handler;
- {
- boost::lock_guard<boost::mutex> lock(exec_mutex_);
- assert (result_handler_ && "result_handler_ must be bound when a query result is ready");
- std::swap(handler, result_handler_);
- send_next_command();
- }
- for (std::vector<::result>::iterator iter = results.begin(); iter != results.end(); ++iter)
- {
- handler(*iter);
- }
- }
- }
- void database::send_next_command()
- {
- if (!exec_queue_.empty())
- {
- std::pair<query, result_handler>& queued_command = exec_queue_.front();
- query& query = queued_command.first;
- result_handler& handler = queued_command.second;
- if (1 == PQsendQueryParams(connection_, query.command(), query.num_params(), query.param_types(), query.param_values(), query.param_lengths(), query.param_formats(), query.binary_results() ? 1 : 0))
- {
- result_handler_ = handler;
- }
- else
- {
- log::error("Database Error: %1%", error_message());
- }
- exec_queue_.pop();
- }
- }
- std::string database::escape_string(const std::string& escape_string)
- {
- int error = 0;
- boost::scoped_ptr<char> buf(new char[escape_string.size() * 2 + 1]);
- size_t new_size = PQescapeStringConn(connection_, buf.get(), escape_string.c_str(), escape_string.size(), &error);
- if (error)
- {
- log::error("Database Error: %1%", error_message());
- }
- return std::string(buf.get(), buf.get() + new_size);
- }
- void database::add_notification_handler(const std::string& channel, boost::shared_ptr<notification_handler> handler)
- {
- notification_handlers_[channel] = handler;
- exec(query::create("LISTEN " + escape_string(channel)));
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement