Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- // C++11 application to calculate wireguard AllowedIPs by PG.
- // License: Public Domain
- // To compile: g++ -std=c++11 ./allowedips.cpp -o allowedips
- //
- // Usage: allowedips [+-]spec [+-]spec [+-]spec ...
- // where spec can be
- // 1.2.3.4 // ipv4 address
- // 1.2.3.4/32 // ipv4 address and mask
- // 1.2.3.4-2.3.4.5 // ipv4 address range
- // 2345:0425:2CA1:0000:0000:0567:5673:23b5 // ipv6 address
- // ::/0 // ipv6 address and mask
- // 2345:0425:2CA1:0000:0000:0567:5673:23b5-2345:0425:2CA1:0000:0000:0567:5673:23b6 // ipv6 address range
- #include <stdio.h>
- #include <stdlib.h>
- #include <stdint.h>
- #include <string.h>
- #include <netinet/in.h>
- #include <arpa/inet.h>
- #include <string>
- #include <map>
- #include <vector>
- #define IS_LITTLE_ENDIAN (htonl(47)!=47) // true when htonl() changes the value, i.e. host byte order differs from network (big-endian) order
/// A single IPv4 or IPv6 address stored inside a sockaddr_storage.
///
/// The address bytes are kept exactly as inet_pton() writes them: network
/// byte order (most significant byte first), on every host. The byte/bit
/// accessors below index from the LEAST significant end, so byte 0 is the
/// last address byte in memory and bit 0 is the lowest bit of the address.
struct ip {
    sockaddr_storage sa={}; // all-zero (family 0) until set() succeeds

    /// Address family recorded in sa (AF_INET, AF_INET6, or 0 when unset).
    int family() const { return sa.ss_family; }

    /// True only for AF_INET; every other family is laid out as IPv6.
    bool is_ip4() const { return family()==AF_INET; }

    /// View of the storage as an IPv4 / IPv6 socket address.
    const sockaddr_in &ip4() const { return *reinterpret_cast<const sockaddr_in*>(&sa); }
    const sockaddr_in6 &ip6() const { return *reinterpret_cast<const sockaddr_in6*>(&sa); }

    /// Pointer to the first (most significant) address byte.
    /// The pointer is writable even though the method is const; the
    /// mutating helpers (set, and_mask, or_mask, plus1, minus1) rely on that.
    uint8_t *s_addr() const {
        const void *first = is_ip4()
            ? static_cast<const void*>(&ip4().sin_addr.s_addr)
            : static_cast<const void*>(&ip6().sin6_addr.s6_addr[0]);
        return static_cast<uint8_t*>(const_cast<void*>(first));
    }

    /// Address length in bytes (4 for IPv4, 16 otherwise).
    size_t s_addr_len() const { return is_ip4() ? sizeof(ip4().sin_addr.s_addr) : sizeof(ip6().sin6_addr.s6_addr); }

    /// Address length in bits (32 for IPv4, 128 otherwise).
    int s_addr_bits() const { return 8*static_cast<int>(s_addr_len()); }

    /// Pointer to byte idx, where idx 0 is the least significant byte.
    /// The stored bytes are big-endian regardless of host endianness, so the
    /// least significant byte is always the last one in memory.
    uint8_t *s_addr_byte_ptr(const int idx) const { return s_addr() + (static_cast<int>(s_addr_len())-1-idx); }

    /// Value of byte idx (0 = least significant).
    uint8_t s_addr_byte(const int idx) const { return *s_addr_byte_ptr(idx); }

    /// Non-zero when bit idx (0 = least significant) is set.
    uint8_t s_addr_bit(const int idx) const { return (1<<(idx&7))&*s_addr_byte_ptr(idx>>3); }

    /// Three-way compare: -1, 0 or 1.
    /// Addresses of different lengths never compare equal; the shorter
    /// (IPv4) one orders first. Same-length addresses are compared
    /// numerically, which for big-endian bytes is a plain memcmp.
    int s_addr_cmp(const ip &o) const {
        const size_t len=s_addr_len();
        const size_t olen=o.s_addr_len();
        if (len!=olen)
            return len<olen ? -1 : 1;
        const int c=memcmp(s_addr(),o.s_addr(),len);
        return c<0 ? -1 : (c>0 ? 1 : 0);
    }

    bool operator==(const ip &o) const { return s_addr_cmp(o)==0; }
    bool operator!=(const ip &o) const { return s_addr_cmp(o)!=0; }
    bool operator<(const ip &o) const { return s_addr_cmp(o)<0; }
    bool operator>(const ip &o) const { return s_addr_cmp(o)>0; }
    bool operator<=(const ip &o) const { return s_addr_cmp(o)<=0; }
    bool operator>=(const ip &o) const { return s_addr_cmp(o)>=0; }

    /// Parse a textual address, trying IPv6 first and then IPv4.
    /// Returns false when s is neither; the family is then left as AF_INET.
    bool set(const std::string &s) {
        sa.ss_family=AF_INET6;
        if (inet_pton(sa.ss_family,s.c_str(),s_addr())==1)
            return true;
        sa.ss_family=AF_INET;
        return inet_pton(sa.ss_family,s.c_str(),s_addr())==1;
    }

    /// Clear the host bits: 192.168.1.x/24 => 192.168.1.0.
    /// Returns false (address untouched) when mask_width is out of range.
    bool and_mask(const int mask_width) {
        if (mask_width<0 || mask_width>s_addr_bits())
            return false;
        const int host_bits=s_addr_bits()-mask_width;
        for (int bit=0;bit<host_bits;bit++)
            *s_addr_byte_ptr(bit>>3)&=~(1<<(bit&7));
        return true;
    }

    /// Set the host bits: 192.168.1.x/24 => 192.168.1.255.
    /// Returns false (address untouched) when mask_width is out of range.
    bool or_mask(const int mask_width) {
        if (mask_width<0 || mask_width>s_addr_bits())
            return false;
        const int host_bits=s_addr_bits()-mask_width;
        for (int bit=0;bit<host_bits;bit++)
            *s_addr_byte_ptr(bit>>3)|=1<<(bit&7);
        return true;
    }

    /// Address + 1; the all-ones address wraps around to all-zeros.
    ip plus1() const {
        ip next(*this);
        const int len=static_cast<int>(s_addr_len());
        for (int i=0;i<len;i++) {
            if (++*next.s_addr_byte_ptr(i)) break; // stop once a byte did not carry
        }
        return next;
    }

    /// Address - 1; the all-zeros address wraps around to all-ones.
    ip minus1() const {
        ip prev(*this);
        const int len=static_cast<int>(s_addr_len());
        for (int i=0;i<len;i++) {
            if (--*prev.s_addr_byte_ptr(i)!=0xff) break; // stop once a byte did not borrow
        }
        return prev;
    }

    /// Textual form via inet_ntop(), or "" when it fails.
    std::string to_str() const {
        char str[INET6_ADDRSTRLEN];
        return inet_ntop(family(),s_addr(),str,INET6_ADDRSTRLEN) ? str : "";
    }
};
- struct range {
- ip start,end;
- bool set(const std::string &spec) {
- size_t msep=spec.find('/'); // 1.2.3.4/32
- if (msep!=std::string::npos) {
- int mask_width=atoi(spec.substr(msep+1).c_str());
- if (!start.set(spec.substr(0,msep)))
- return false;
- if (!start.and_mask(mask_width))
- return false;
- end=start;
- if (!end.or_mask(mask_width))
- return false;
- return true;
- }
- size_t rsep=spec.find('-'); // 1.2.3.4-2.3.4.5
- if (rsep!=std::string::npos) {
- if (!start.set(spec.substr(0,rsep)))
- return false;
- if (!end.set(spec.substr(rsep+1)))
- return false;
- if (end<start || start.family()!=end.family())
- return false;
- return true;
- }
- if (!start.set(spec)) // 1.2.3.4
- return false;
- end=start;
- return true;
- }
- bool is_ip4() const { return start.is_ip4(); }
- int family() const { return start.family(); }
- std::string to_str() const { return start.to_str()+'-'+end.to_str(); }
- };
- struct RangeSet {
- int family;
- std::map<ip,range> ranges;
- RangeSet(const int family=AF_INET) : family(family) { }
- bool remove_range(const range &b) {
- for (auto i=ranges.begin();i!=ranges.end();) {
- range &a=i->second;
- if (b.start>a.end) { i++; continue; }
- if (b.end<a.start) break;
- if (b.start<=a.start && b.end>=a.end) { // remove a completely
- auto x=i++;
- ranges.erase(x);
- } else if (b.start>a.start && b.end<a.end) { // remove range in middle of a
- range n;
- n.start=b.end.plus1();
- n.end=a.end;
- ranges[n.start]=n;
- a.end=b.start.minus1();
- break;
- } else if (b.start<=a.start) { // trim left
- a.start=b.end.plus1();
- break;
- } else { // trim right
- a.end=b.start.minus1();
- i++;
- }
- }
- return true;
- }
- bool remove_range(const std::string &spec) {
- range b;
- return b.set(spec) && b.family()==family ? remove_range(b) : false;
- }
- bool add_range(const std::string &spec) {
- range b;
- if (!b.set(spec) || b.family()!=family)
- return false;
- // remove any overlaps with existing ranges
- if (!remove_range(b))
- return false;
- // add the new range
- ranges[b.start]=b;
- // merge any adjacent ranges
- if (ranges.size()>1) {
- for (auto i=ranges.begin();;) {
- auto next=i;
- if (++next==ranges.end())
- break;
- range &a=i->second,&b=next->second;
- if (a.end.plus1()==b.start) {
- range n=b;
- n.start=a.start;
- ranges.erase(i);
- ranges.erase(next);
- next=ranges.insert({n.start,n}).first;
- }
- i=next;
- }
- }
- return true;
- }
- std::multimap<int,ip> get_cidrs() {
- std::multimap<int,ip> cidrs; // width,ip
- const auto ranges_copy=ranges;
- // Find CIDR ranges by iteratively finding and removing
- // the shortest CIDR mask possible from each range
- while (!ranges.empty()) {
- range &r=ranges.begin()->second;
- // find shortest CIDR mask (largest IP range) that fits within range r
- range x;
- int mask_width=0;
- for (;;) {
- x.start=x.end=r.start;
- x.start.and_mask(mask_width);
- x.end.or_mask(mask_width);
- if (x.start>=r.start && x.end<=r.end)
- break;
- mask_width++;
- }
- cidrs.insert({mask_width,x.start});
- remove_range(x);
- }
- ranges=ranges_copy; // restore range data
- return cidrs;
- }
- };
- static RangeSet ip4ranges(AF_INET),ip6ranges(AF_INET6); // file-scope range sets, one per address family; filled by main(), read by output_cidrs()
- static void output_cidrs() {
- printf("AllowedIPs =");
- auto ip4cidrs=ip4ranges.get_cidrs();
- int count=0;
- for (auto &p : ip4cidrs)
- printf("%s%s/%d",count++?", ":" ",p.second.to_str().c_str(),p.first);
- auto ip6cidrs=ip6ranges.get_cidrs();
- for (auto &p : ip6cidrs)
- printf("%s%s/%d",count++?", ":" ",p.second.to_str().c_str(),p.first);
- printf("\n");
- }
- int main(int argc,char **argv) {
- bool err=argc<2;
- for (int i=1;!err && i<argc;i++) {
- bool ok=false;
- if (argv[i][0]=='+') {
- ok=ip6ranges.add_range(argv[i]+1);
- if (!ok)
- ok=ip4ranges.add_range(argv[i]+1);
- } else if (argv[i][0]=='-') {
- ok=ip6ranges.remove_range(argv[i]+1);
- if (!ok)
- ok=ip4ranges.remove_range(argv[i]+1);
- }
- if (!ok) {
- fprintf(stderr,"error - bad arg: %s\n",argv[i]);
- err=true;
- }
- }
- if (err) {
- fprintf(stderr,"Usage: %s [+-]spec [+-]spec [+-]spec ...\n",argv[0]);
- fprintf(stderr," where spec can be\n");
- fprintf(stderr," 1.2.3.4 // ipv4 address\n");
- fprintf(stderr," 1.2.3.4/32 // ipv4 address and mask\n");
- fprintf(stderr," 1.2.3.4-2.3.4.5 // ipv4 address range\n");
- fprintf(stderr," 2345:0425:2CA1:0000:0000:0567:5673:23b5 // ipv6 address\n");
- fprintf(stderr," ::/0 // ipv6 address and mask\n");
- fprintf(stderr," 2345:0425:2CA1:0000:0000:0567:5673:23b5-2345:0425:2CA1:0000:0000:0567:5673:23b6 // ipv6 address range\n");
- exit(EXIT_FAILURE);
- }
- output_cidrs();
- return 0;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement