Advertisement
Guest User

Untitled

a guest
Aug 20th, 2018
146
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Go 5.46 KB | None | 0 0
  1. package httd
  2.  
  3. import (
  4.     "net/http"
  5.     "strconv"
  6.     "strings"
  7.     "sync"
  8.     "time"
  9.  
  10.     "github.com/andersfylling/snowflake"
  11.     "encoding/json"
  12. )
  13.  
  14. const (
  15.     XRateLimitLimit     = "X-RateLimit-Limit"
  16.     XRateLimitRemaining = "X-RateLimit-Remaining"
  17.     XRateLimitReset     = "X-RateLimit-Reset"
  18.     XRateLimitGlobal    = "X-RateLimit-Global"
  19.     RateLimitRetryAfter = "Retry-After"
  20. )
  21.  
  22. // const
  23. var majorEndpointPrefixes = []string{
  24.     "/channels/",
  25.     "/guilds/",
  26.     "/webhooks/",
  27. }
  28.  
  29. // TODO: fix ratelimiting logic
  30. func RatelimitChannel(id snowflake.ID) string {
  31.     return "c:" + id.String()
  32. }
  33.  
  34. func RatelimitGuild(id snowflake.ID) string {
  35.     return "g:" + id.String()
  36. }
  37.  
  38. func RatelimitWebhook() string {
  39.     return "wh"
  40. }
  41.  
  42. func RatelimitUsers() string {
  43.     return "u"
  44. }
  45.  
  46. type RateLimiter interface {
  47.     Bucket(key string) *Bucket
  48.     RateLimitTimeout(key string) int64
  49.     RateLimited(key string) bool
  50.     HandleResponse(key string, res *http.Response, responseBody []byte)
  51. }
  52.  
  53. func NewRateLimit() *RateLimit {
  54.     return &RateLimit{
  55.         buckets: make(map[string]*Bucket),
  56.         global: &Bucket{
  57.             global: true,
  58.         },
  59.     }
  60. }
  61.  
  62. // RateLimit
  63. // TODO: a bucket is created for every request. Might want to delete them after a while. seriously.
  64. // `/users/1` has the same ratelimiter as `/users/2`
  65. // but any major endpoint prefix does not: `/channels/1` != `/channels/2`
  66. type RateLimit struct {
  67.     buckets map[string]*Bucket
  68.     global  *Bucket
  69.     mu      sync.RWMutex
  70. }
  71.  
  72. func (r *RateLimit) Bucket(key string) *Bucket {
  73.     var bucket *Bucket
  74.     var exists bool
  75.  
  76.     // check for major endpoints
  77.     // TODO: this feels frail
  78.     var endpoint string
  79.     for _, major := range majorEndpointPrefixes {
  80.         if !strings.HasPrefix(key, major) {
  81.             continue
  82.         }
  83.         pathAfterMajor := strings.TrimPrefix(key, major)
  84.         endpoint = major
  85.         for _, r := range pathAfterMajor {
  86.             if r == '/' {
  87.                 break
  88.             }
  89.             endpoint += string(r)
  90.         }
  91.     }
  92.     if endpoint == "" {
  93.         endpoint = key
  94.     }
  95.  
  96.     r.mu.Lock()
  97.     if bucket, exists = r.buckets[key]; !exists {
  98.         r.buckets[key] = &Bucket{
  99.             endpoint: key,
  100.             reset:    time.Now().UnixNano() / 1000,
  101.         }
  102.         bucket = r.buckets[key]
  103.     }
  104.     r.mu.Unlock()
  105.  
  106.     return bucket
  107. }
  108.  
  109. func (r *RateLimit) RateLimitTimeout(key string) int64 {
  110.     bucket := r.Bucket(key)
  111.     return bucket.timeout()
  112. }
  113.  
  114. func (r *RateLimit) RateLimited(key string) bool {
  115.     bucket := r.Bucket(key)
  116.     return bucket.limited()
  117. }
  118.  
  119. type ratelimitBody struct {
  120.     Message    string `json:"message"`
  121.     RetryAfter int64  `json:"retry_after"`
  122.     Global     bool   `json:"global"`
  123. }
  124.  
  125. // TODO: rewrite
  126. func (r *RateLimit) HandleResponse(key string, res *http.Response, content []byte) {
  127.     var err error
  128.     var global bool
  129.     var limit uint64
  130.     var remaining uint64
  131.     var reset int64
  132.     var body *ratelimitBody
  133.     var noBody bool
  134.  
  135.     // read body as well
  136.     if len(content) == 0 {
  137.         noBody = true
  138.     } else {
  139.         err = json.Unmarshal(content, body)
  140.         if err != nil {
  141.             return
  142.         }
  143.     }
  144.  
  145.     // global?
  146.     if res.Header.Get(XRateLimitGlobal) == "true" || (!noBody && body.Global) {
  147.         global = true
  148.     }
  149.  
  150.     // max number of request before reset
  151.     if res.Header.Get(XRateLimitLimit) != "" || (!noBody && body.Global) {
  152.         limit, err = strconv.ParseUint(res.Header.Get(XRateLimitLimit), 10, 64)
  153.         if err != nil {
  154.             // TODO: logging
  155.         }
  156.     }
  157.  
  158.     // remaining requests before reset
  159.     remainingStr := res.Header.Get(XRateLimitRemaining)
  160.     if remainingStr != "" {
  161.         remaining, err = strconv.ParseUint(remainingStr, 10, 64)
  162.         if err != nil {
  163.             // TODO: logging
  164.         }
  165.     }
  166.  
  167.     // reset unix timestamp
  168.     resetStr := res.Header.Get(XRateLimitReset)
  169.     if resetStr != "" {
  170.         // here we get a unix timestamp in seconds, which we convert to milliseconds
  171.         reset, err = strconv.ParseInt(remainingStr, 10, 64)
  172.         if err == nil {
  173.             reset *= 1000 // => milliseconds
  174.         } else {
  175.             // TODO: logging
  176.         }
  177.     } else if res.Header.Get(RateLimitRetryAfter) != "" || (!noBody && body.RetryAfter > 0) {
  178.         // here we are given a delay in millisecond, which we need to convert into a timestamp
  179.         if res.Header.Get(RateLimitRetryAfter) != "" {
  180.             reset, err = strconv.ParseInt(res.Header.Get(RateLimitRetryAfter), 10, 64)
  181.             if err != nil {
  182.                 reset = 0
  183.             }
  184.         } else if !noBody && body.RetryAfter > 0 {
  185.             reset = body.RetryAfter
  186.         }
  187.  
  188.         // convert diff to timestamp
  189.         reset += time.Now().UnixNano() / 1000
  190.     }
  191.  
  192.     if global {
  193.         r.global.mu.Lock()
  194.         defer r.global.mu.Unlock()
  195.  
  196.         if limit != 0 {
  197.             r.global.limit = limit
  198.         }
  199.         if remaining != 0 {
  200.             r.global.remaining = remaining
  201.         }
  202.         if reset != 0 {
  203.             r.global.reset = reset
  204.         }
  205.     } else {
  206.         bucket := r.Bucket(key)
  207.         bucket.mu.Lock()
  208.         defer bucket.mu.Unlock()
  209.  
  210.         if limit != 0 {
  211.             bucket.limit = limit
  212.         }
  213.         if remaining != 0 {
  214.             bucket.remaining = remaining
  215.         }
  216.         if reset != 0 {
  217.             bucket.reset = reset
  218.         }
  219.     }
  220. }
  221.  
  222. // ---------------------
  223.  
  224. type Bucket struct {
  225.     endpoint  string // endpoint where rate limit is applied. endpoint = key
  226.     limit     uint64 // total allowed requests before rate limit
  227.     remaining uint64 // remaining requests
  228.     reset     int64  // unix milliseconds, even tho discord prefers seconds. global uses milliseconds however.
  229.     global    bool   // global rate limiter
  230.  
  231.     mu sync.RWMutex
  232. }
  233.  
  234. func (b *Bucket) limited() bool {
  235.     b.mu.RLock()
  236.     defer b.mu.RUnlock()
  237.  
  238.     return b.reset > (time.Now().UnixNano() / 1000)
  239. }
  240.  
  241. func (b *Bucket) timeout() int64 {
  242.     b.mu.RLock()
  243.     defer b.mu.RUnlock()
  244.  
  245.     now := time.Now().UnixNano() / 1000
  246.     if b.reset > now {
  247.         return b.reset - now
  248.     }
  249.     return 0
  250. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement