Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
- // main.cpp
- #include "libstat.h"
- #include <iostream>
- using namespace std;
- int main() {
- statistics::tests::AggregSum();
- statistics::tests::AggregMax();
- statistics::tests::AggregMean();
- statistics::tests::AggregStandardDeviation();
- statistics::tests::AggregMode();
- statistics::tests::AggregPrinter();
- cout << "Test passed!"sv << endl;
- }
- // libstat.h
- #pragma once
- #include <algorithm>
- #include <cmath>
- #include <iostream>
- #include <optional>
- #include <string>
- #include <string_view>
- #include <unordered_map>
- using namespace std::literals;
- namespace statistics {
- namespace aggregations {
// Aggregator that keeps the running sum of every value passed to PutValue().
class Sum {
    // Accumulated total; starts at zero.
    double sum_ = 0.0;

public:
    // Adds `value` to the running total.
    void PutValue(double value);

    // Returns the current running total.
    std::optional<double> Get() const;

    // Label used when this aggregation is printed.
    static std::string_view GetValueName() {
        return std::string_view{"sum"};
    }
};
// Aggregator that tracks the largest value passed to PutValue().
class Max {
    // Current maximum; empty until the first value arrives.
    std::optional<double> cur_max_;

public:
    // Offers `value` as a candidate for the maximum.
    void PutValue(double value);

    // Returns the maximum seen so far, or an empty optional if none.
    std::optional<double> Get() const;

    // Label used when this aggregation is printed.
    static std::string_view GetValueName() {
        return std::string_view{"max"};
    }
};
- class Mean { // среднее арифметическое
- public:
- void PutValue(double value);
- std::optional<double> Get() const;
- static std::string_view GetValueName() {
- return "mean"sv;
- }
- private:
- aggregations::Sum sum_;
- size_t count_ = 0;
- };
- class StandardDeviation { // стандартное отклонение
- public:
- void PutValue(double value);
- std::optional<double> Get() const;
- static std::string_view GetValueName() {
- return "standard deviation"sv;
- }
- private:
- aggregations::Sum sum_;
- aggregations::Sum sum_sq_;
- size_t count_ = 0;
- };
// Aggregator for the most frequent value (mode).
class Mode {
    std::unordered_map<double, size_t> counts_;  // occurrences per key (keys are rounded values)
    std::optional<double> cur_max_;              // current mode; empty until the first value arrives
    size_t cur_count_{};                         // occurrence count belonging to cur_max_

public:
    // Records one more value.
    void PutValue(double value);

    // Returns the current mode, or an empty optional if no values were recorded.
    std::optional<double> Get() const;

    // Label used when this aggregation is printed.
    static std::string_view GetValueName() {
        return std::string_view{"mode"};
    }
};
- } // конец namespace statistics::aggregations
// Self-tests of the library. Each function exercises one component and checks
// its expectations with assert; the definitions are in libstat_test.cpp.
namespace tests {
// Aggregators, in alphabetical order.
void AggregMax();
void AggregMean();
void AggregMode();
void AggregStandardDeviation();
void AggregSum();

// Printing wrapper.
void AggregPrinter();
}  // namespace tests
// Wrapper that feeds values to an aggregator `Aggreg` and prints its result.
// Each Print() call writes one line: "<name> is <value>" or
// "<name> is undefined". Aggreg must provide PutValue(double), a Get() that
// returns an optional-like value, and GetValueName().
template <typename Aggreg>
class AggregPrinter {
    Aggreg inner_;  // the wrapped aggregator

public:
    // Forwards `value` to the wrapped aggregator.
    void PutValue(double value) {
        inner_.PutValue(value);
    }

    // Writes the aggregator's name and current result to `out`. The line is
    // ended with std::endl, which also flushes `out`.
    void Print(std::ostream& out) const {
        const auto result = inner_.Get();
        out << inner_.GetValueName() << " is "sv;
        if (!result) {
            out << "undefined"sv;
        } else {
            out << *result;
        }
        out << std::endl;
    }
};
- } // конец namespace statistics
- // libstat.cpp
- #include "libstat.h"
- void statistics::aggregations::Sum::PutValue(double value) {
- sum_ += value;
- }
- std::optional<double> statistics::aggregations::Sum::Get() const {
- return sum_;
- }
- void statistics::aggregations::Max::PutValue(double value) {
- cur_max_ = std::max(value, cur_max_.value_or(value));
- }
- std::optional<double> statistics::aggregations::Max::Get() const {
- return cur_max_;
- }
- void statistics::aggregations::Mean::PutValue(double value) {
- sum_.PutValue(value);
- ++count_;
- }
- std::optional<double> statistics::aggregations::Mean::Get() const {
- auto val = sum_.Get();
- if (!val || count_ == 0) {
- return std::nullopt;
- }
- return *val / count_;
- }
- void statistics::aggregations::StandardDeviation::PutValue(double value) {
- sum_.PutValue(value);
- sum_sq_.PutValue(value * value);
- ++count_;
- }
- std::optional<double> statistics::aggregations::StandardDeviation::Get() const {
- auto val = sum_.Get();
- auto val2 = sum_sq_.Get();
- if (!val || !val2 || count_ < 2) {
- return std::nullopt;
- }
- return ::std::sqrt((*val2 - *val * *val / count_) / count_);
- }
- void statistics::aggregations::Mode::PutValue(double value) {
- const size_t new_count = ++counts_[round(value)];
- if (new_count > cur_count_) {
- cur_max_ = value;
- cur_count_ = new_count;
- }
- }
- std::optional<double> statistics::aggregations::Mode::Get() const {
- return cur_max_;
- }
- // libstat_test.cpp
- #include "libstat.h"
- #include <cassert>
- #include <cmath>
- #include <sstream>
- namespace statistics::tests::detail {
- template <typename T>
- std::string GetPrinterValue(statistics::AggregPrinter<T>& printer) {
- std::ostringstream out;
- printer.Print(out);
- return std::move(out).str();
- }
- } // конец namespace detail
- void statistics::tests::AggregSum() {
- statistics::aggregations::Sum aggreg;
- assert(*aggreg.Get() == 0);
- aggreg.PutValue(10.);
- aggreg.PutValue(20.);
- aggreg.PutValue(-40.);
- assert(*aggreg.Get() == -10.);
- }
- void statistics::tests::AggregMax() {
- statistics::aggregations::Max aggreg;
- assert(!aggreg.Get());
- aggreg.PutValue(10.);
- aggreg.PutValue(20.);
- aggreg.PutValue(-40.);
- assert(*aggreg.Get() == 20.);
- }
- void statistics::tests::AggregMean() {
- statistics::aggregations::Mean aggreg;
- assert(!aggreg.Get());
- aggreg.PutValue(10.);
- aggreg.PutValue(20.);
- aggreg.PutValue(-40.);
- aggreg.PutValue(30.);
- assert(*aggreg.Get() == 5.);
- }
- void statistics::tests::AggregStandardDeviation() {
- statistics::aggregations::StandardDeviation aggreg;
- assert(!aggreg.Get());
- aggreg.PutValue(10.);
- aggreg.PutValue(10.);
- aggreg.PutValue(10.);
- aggreg.PutValue(10.);
- assert(std::abs(*aggreg.Get()) < 1e-5);
- aggreg.PutValue(20.);
- aggreg.PutValue(20.);
- aggreg.PutValue(20.);
- aggreg.PutValue(20.);
- assert(std::abs(*aggreg.Get() - 5.) < 1e-5);
- }
- void statistics::tests::AggregMode() {
- statistics::aggregations::Mode aggreg;
- assert(!aggreg.Get());
- aggreg.PutValue(1.1);
- aggreg.PutValue(0.9);
- aggreg.PutValue(2.1);
- aggreg.PutValue(2.2);
- aggreg.PutValue(2.1);
- aggreg.PutValue(-1.0);
- aggreg.PutValue(3.0);
- aggreg.PutValue(3.0);
- aggreg.PutValue(1000.);
- assert(std::round(*aggreg.Get()) == 2.);
- }
- void statistics::tests::AggregPrinter() {
- statistics::AggregPrinter<statistics::aggregations::Max> printer;
- assert(statistics::tests::detail::GetPrinterValue(printer) == "max is undefined\n"s);
- printer.PutValue(10.);
- printer.PutValue(20.);
- printer.PutValue(-40.);
- std::ostringstream out;
- out << 20.;
- assert(statistics::tests::detail::GetPrinterValue(printer) == "max is "s + out.str() + "\n"s);
- }
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement