83 lines
2.3 KiB
Go
83 lines
2.3 KiB
Go
|
|
package middleware
|
||
|
|
|
||
|
|
import (
	"net/http"
	"strconv"
	"time"

	"github.com/didip/tollbooth/v7"
	"github.com/didip/tollbooth/v7/limiter"
	"github.com/gin-gonic/gin"
)
|
||
|
|
|
||
|
|
// RateLimitMiddleware creates a rate limiting middleware for API protection
|
||
|
|
func RateLimitMiddleware() gin.HandlerFunc {
|
||
|
|
// Create a rate limiter: 100 requests per minute per IP
|
||
|
|
lmt := tollbooth.NewLimiter(100, &limiter.ExpirableOptions{
|
||
|
|
DefaultExpirationTTL: time.Hour,
|
||
|
|
})
|
||
|
|
|
||
|
|
// Configure the limiter
|
||
|
|
lmt.SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"})
|
||
|
|
lmt.SetMethods([]string{"GET", "POST", "PUT", "DELETE", "PATCH"})
|
||
|
|
|
||
|
|
// Configure message for rate limit exceeded
|
||
|
|
lmt.SetMessage("Rate limit exceeded. Please try again later.")
|
||
|
|
|
||
|
|
return func(c *gin.Context) {
|
||
|
|
// Add rate limit headers
|
||
|
|
c.Header("X-Rate-Limit-Limit", "100")
|
||
|
|
c.Header("X-Rate-Limit-Window", "60s")
|
||
|
|
|
||
|
|
// Check rate limit
|
||
|
|
httpError := tollbooth.LimitByRequest(lmt, c.Writer, c.Request)
|
||
|
|
if httpError != nil {
|
||
|
|
c.Header("Retry-After", "60")
|
||
|
|
c.JSON(http.StatusTooManyRequests, gin.H{
|
||
|
|
"error": "Rate limit exceeded",
|
||
|
|
"message": "Too many requests. Please try again later.",
|
||
|
|
"retry_after": "60s",
|
||
|
|
})
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
c.Next()
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
// StrictRateLimitMiddleware creates a stricter rate limiting middleware for sensitive endpoints
|
||
|
|
func StrictRateLimitMiddleware() gin.HandlerFunc {
|
||
|
|
// Create a stricter rate limiter: 10 requests per minute per IP
|
||
|
|
lmt := tollbooth.NewLimiter(10, &limiter.ExpirableOptions{
|
||
|
|
DefaultExpirationTTL: time.Hour,
|
||
|
|
})
|
||
|
|
|
||
|
|
// Configure the limiter
|
||
|
|
lmt.SetIPLookups([]string{"RemoteAddr", "X-Forwarded-For", "X-Real-IP"})
|
||
|
|
lmt.SetMethods([]string{"POST", "PUT", "DELETE"})
|
||
|
|
|
||
|
|
// Configure message for rate limit exceeded
|
||
|
|
lmt.SetMessage("Rate limit exceeded for sensitive endpoint. Please try again later.")
|
||
|
|
|
||
|
|
return func(c *gin.Context) {
|
||
|
|
// Add rate limit headers
|
||
|
|
c.Header("X-Rate-Limit-Limit", "10")
|
||
|
|
c.Header("X-Rate-Limit-Window", "60s")
|
||
|
|
|
||
|
|
// Check rate limit
|
||
|
|
httpError := tollbooth.LimitByRequest(lmt, c.Writer, c.Request)
|
||
|
|
if httpError != nil {
|
||
|
|
c.Header("Retry-After", "60")
|
||
|
|
c.JSON(http.StatusTooManyRequests, gin.H{
|
||
|
|
"error": "Rate limit exceeded",
|
||
|
|
"message": "Too many requests for this sensitive endpoint. Please try again later.",
|
||
|
|
"retry_after": "60s",
|
||
|
|
})
|
||
|
|
c.Abort()
|
||
|
|
return
|
||
|
|
}
|
||
|
|
|
||
|
|
c.Next()
|
||
|
|
}
|
||
|
|
}
|