Files
2024-07-31 00:48:41 +02:00

218 lines
4.8 KiB
Go

package main
import (
"flag"
"fmt"
"net"
"net/http"
"os"
"os/exec"
"os/signal"
"strings"
"syscall"
"time"
)
var (
	port   = flag.String("port", "8080", "Port to listen on")
	clean  = flag.Bool("clean", false, "Cleanup the banlist")
	dryrun = flag.Bool("dryrun", false, "Dry run mode")
	sev1   = flag.Int("sev1", 60, "Severity 1 ban time in seconds")
	sev2   = flag.Int("sev2", 300, "Severity 2 ban time in seconds")
	sev3   = flag.Int("sev3", 3600, "Severity 3 ban time in seconds")

	// banlist maps every IP Blocky has accepted a ban for to the Unix time
	// (in seconds) at which that ban expires. Only access it between
	// lockBanlist and unlockBanlist.
	banlist = make(map[string]int64)

	// banlistLock guards banlist. net/http runs each request on its own
	// goroutine, and unsynchronized concurrent map writes abort the program.
	// The one-slot channel works as a mutex: a send acquires it, a receive
	// releases it.
	banlistLock = make(chan struct{}, 1)
)

// lockBanlist blocks until the caller holds exclusive access to banlist.
func lockBanlist() { banlistLock <- struct{}{} }

// unlockBanlist releases the access acquired by lockBanlist.
func unlockBanlist() { <-banlistLock }

// banIP runs `ufw deny from <ip> comment Blocky`.
func banIP(ip string) error {
	return exec.Command("/usr/sbin/ufw", "deny", "from", ip, "comment", "Blocky").Run()
}

// unbanIP deletes the rule that banIP adds for ip.
func unbanIP(ip string) error {
	return exec.Command("/usr/sbin/ufw", "delete", "deny", "from", ip, "comment", "Blocky").Run()
}

// banDuration returns the ban length in seconds for a severity, taken from
// the -sev1/-sev2/-sev3 flags. ok is false for any severity other than 1-3.
func banDuration(severity int) (seconds int64, ok bool) {
	switch severity {
	case 1:
		return int64(*sev1), true
	case 2:
		return int64(*sev2), true
	case 3:
		return int64(*sev3), true
	}
	return 0, false
}

// main parses the flags and either performs a one-shot cleanup (-clean) or
// serves the Blocky HTTP API until SIGINT/SIGTERM (or until the server fails),
// after which it lifts every ban still recorded in banlist.
func main() {
	flag.Parse()
	if *clean {
		if err := cleanup(); err != nil {
			fmt.Printf("Failed to clean up banlist: %s\n", err)
			os.Exit(1)
		}
		os.Exit(0)
	}
	if *dryrun {
		fmt.Printf("Dry run mode enabled\n")
	}
	fmt.Printf("Starting Blocky on port: %s\n", *port)
	http.HandleFunc("/", handler)
	// Timeouts keep slow or stalled clients from holding connections open.
	srv := &http.Server{
		Addr:              ":" + *port,
		ReadHeaderTimeout: 10 * time.Second,
		ReadTimeout:       30 * time.Second,
	}

	// signal.Notify never blocks when delivering, so the channel needs a
	// buffer or a signal arriving at the wrong moment is silently dropped.
	sig := make(chan os.Signal, 1)
	signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)

	serveErr := make(chan error, 1)
	go func() { serveErr <- srv.ListenAndServe() }()

	serverFailed := false
	select {
	case <-sig:
	case err := <-serveErr:
		// Without a server nothing can add or sweep bans, so shut down now
		// instead of waiting for a signal that may never come.
		fmt.Printf("Failed to start server: %s\n", err)
		serverFailed = true
	}

	// Stop accepting connections; the error is irrelevant since we are exiting.
	_ = srv.Close()
	// Take the lock and never release it: a request already holding it
	// finishes first (so its ban is lifted below), and any later request
	// blocks until the process exits, so no new ban is left behind.
	lockBanlist()
	if !*dryrun {
		for ip := range banlist {
			if err := unbanIP(ip); err != nil {
				fmt.Printf("Failed to unban IP: %s\n", ip)
				fmt.Printf("Error: %s\n", err)
				continue
			}
			fmt.Printf("IP: %s has been unbanned\n", ip)
		}
	}
	fmt.Printf("\033[2K\rTime to sleep, goodbye\n")
	if serverFailed {
		os.Exit(1)
	}
}

// handler serves every request to "/".
//
// GET sweeps the banlist, lifting each ban whose expiry time has passed.
// POST bans the IP embedded in the "message" form field for the duration
// selected by the "priority" form field (severity 1, 2 or 3).
// Any other method is answered with 405 Method Not Allowed.
func handler(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodGet:
		sweepExpiredBans()
		w.WriteHeader(http.StatusOK)
	case http.MethodPost:
		handleBanRequest(w, r)
	default:
		http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
	}
}

// sweepExpiredBans lifts every ban whose expiry time is in the past. The entry
// is removed from banlist even if the ufw delete fails; the failure is logged.
func sweepExpiredBans() {
	fmt.Printf("Cleaning up banlist...\n")
	now := time.Now().Unix()
	lockBanlist()
	defer unlockBanlist()
	for ip, expiry := range banlist {
		if expiry >= now {
			continue
		}
		// Deleting the current key while ranging over a map is allowed in Go.
		delete(banlist, ip)
		if !*dryrun {
			if err := unbanIP(ip); err != nil {
				fmt.Printf("Failed to unban IP: %s\n", ip)
				fmt.Printf("Error: %s\n", err)
				continue
			}
		}
		fmt.Printf("IP: %s has been unbanned\n", ip)
	}
	fmt.Printf("Banlist has been cleaned up\n")
}

// handleBanRequest records and applies the ban described by a POST form.
// It answers 400 for an unparsable form, a message without a valid IP, or an
// unknown severity; 500 if ufw fails to add the rule; and 200 otherwise.
func handleBanRequest(w http.ResponseWriter, r *http.Request) {
	if err := r.ParseForm(); err != nil {
		http.Error(w, "Invalid form data", http.StatusBadRequest)
		return
	}
	message := r.Form.Get("message")
	severity := atoi(r.Form.Get("priority"))

	ip := parseIP(message)
	if ip == "" {
		fmt.Printf("Invalid IP in message: %q\n", message)
		http.Error(w, "Invalid IP", http.StatusBadRequest)
		return
	}
	fmt.Printf("Received IP: %s, Severity: %d\n", ip, severity)

	seconds, ok := banDuration(severity)
	if !ok {
		// Reject before touching ufw: only IPs present in banlist are ever
		// unbanned, so a rule without an entry would never be lifted.
		fmt.Printf("Invalid severity: %d\n", severity)
		http.Error(w, "Invalid severity", http.StatusBadRequest)
		return
	}

	lockBanlist()
	defer unlockBanlist()
	banlist[ip] = time.Now().Unix() + seconds
	if !*dryrun {
		if err := banIP(ip); err != nil {
			fmt.Printf("Failed to ban IP: %s\n", ip)
			fmt.Printf("Error: %s\n", err)
			http.Error(w, "Failed to ban IP", http.StatusInternalServerError)
			return
		}
	}
	w.WriteHeader(http.StatusOK)
}
// cleanup implements the -clean flag.
//
// If something answers HTTP on localhost:<port>, cleanup treats it as a
// running Blocky and sends it a GET, which makes that server lift its expired
// bans. Otherwise it lists the ufw rules and deletes every deny rule whose
// line mentions "Blocky", expired or not. Deletion continues past individual
// failures, and the first failure is returned.
func cleanup() error {
	resp, err := http.Get("http://localhost:" + *port)
	if err == nil {
		resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			fmt.Printf("Blocky answered with status: %s\n", resp.Status)
			return fmt.Errorf("cleanup request returned %s", resp.Status)
		}
		fmt.Printf("Banlist has been cleaned up\n")
		return nil
	}

	fmt.Printf("Blocky is not running, cleaning up manually...\n")
	output, err := exec.Command("/usr/sbin/ufw", "status").Output()
	if err != nil {
		fmt.Printf("Failed to get firewall status: %s\n", err)
		return err
	}

	var firstErr error
	unbanned := 0
	for _, line := range strings.Split(string(output), "\n") {
		if !strings.Contains(line, "Blocky") {
			continue
		}
		ip := blockyRuleIP(line)
		if ip == "" {
			fmt.Printf("No IP found in rule: %q\n", line)
			continue
		}
		if err := exec.Command("/usr/sbin/ufw", "delete", "deny", "from", ip, "comment", "Blocky").Run(); err != nil {
			fmt.Printf("Failed to unban IP: %s\n", ip)
			fmt.Printf("Error: %s\n", err)
			if firstErr == nil {
				firstErr = err
			}
			continue
		}
		unbanned++
		fmt.Printf("IP: %s has been unbanned\n", ip)
	}
	if firstErr != nil {
		return firstErr
	}
	if unbanned == 0 {
		fmt.Printf("No IPs to unban\n")
		return nil
	}
	fmt.Printf("Banlist has been cleaned up\n")
	return nil
}

// blockyRuleIP returns the first whitespace-separated field of a `ufw status`
// line that parses as an IP address, or "" if there is none. It searches for
// the address instead of taking a fixed column. IPv6 rows are expected to show
// "Anywhere (v6)" in the To column, which shifts every later field by one
// (NOTE(review): confirm against the deployed ufw version). A fixed index also
// panics on short lines.
func blockyRuleIP(line string) string {
	for _, field := range strings.Fields(line) {
		if net.ParseIP(field) != nil {
			return field
		}
	}
	return ""
}
// parseIP extracts the IP address from an alert message shaped like
// "<label>: <ip>|<rest>". It takes the text between the first ':' and the
// first '|' after it, with surrounding whitespace trimmed.
//
// It returns "" when either separator is missing, or when the extracted text
// is not an IPv4/IPv6 address accepted by net.ParseIP. It never panics on
// malformed input.
func parseIP(message string) string {
	colon := strings.Index(message, ":")
	if colon < 0 {
		return ""
	}
	rest := message[colon+1:]
	pipe := strings.Index(rest, "|")
	if pipe < 0 {
		return ""
	}
	candidate := strings.TrimSpace(rest[:pipe])
	if net.ParseIP(candidate) == nil {
		return ""
	}
	return candidate
}
// atoi converts s, a plain base-10 string of ASCII digits, to an int.
//
// It returns 0 when s is empty, contains any byte other than '0'-'9'
// (including signs and whitespace), or encodes a value too large for an int.
// A blind digit fold would silently turn such input into an arbitrary number.
func atoi(s string) int {
	const maxInt = int(^uint(0) >> 1)
	n := 0
	for i := 0; i < len(s); i++ {
		c := s[i]
		if c < '0' || c > '9' {
			return 0
		}
		d := int(c - '0')
		// n*10 + d must not exceed maxInt.
		if n > (maxInt-d)/10 {
			return 0
		}
		n = n*10 + d
	}
	return n
}