Files
rproxy/rproxy.go
2025-10-31 16:01:57 +01:00

498 lines
11 KiB
Go

package main
import (
	"bufio"
	"context"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"time"
)
// Log verbosity levels for Config.LogLevel; the log_level directive selects
// one by its digit. Proxy.logf prints a message logged at level L only when
// the configured LogLevel is >= L, so larger values are noisier.
const (
	LogSilent = 0 // logf prints nothing
	LogSystem = 1 // startup information and hard failures
	LogWarn   = 2 // additionally, recoverable problems
	LogFull   = 3 // additionally, per-request tracing
)
type (
	// ACL is a named set of host patterns collected from "acl" directives.
	// A host belongs to the ACL when any one of its patterns matches it.
	ACL struct {
		Name     string
		Patterns []*regexp.Regexp
	}

	// Config is the proxy configuration, normally created by NewConfig and
	// then populated from a file by Load.
	Config struct {
		NetOutIPv4 string // default outgoing IPv4 source address ("net_out IP")
		NetOutIPv6 string // default outgoing IPv6 source address ("net_out IP")
		Addr       string // listen host; empty listens on every interface
		Port       string // listen port
		LogLevel   int    // one of LogSilent..LogFull
		ACLs       map[string]*ACL
		NetOutACLs map[string]NetOutIPs // ACL name -> outgoing source addresses
		ReqDeny    []string             // ACL names whose hosts are refused
		ReqAllow   []string             // ACL names whose hosts are permitted
		AllowAll   bool                 // set by "req_allow all": permit unless denied
	}

	// NetOutIPs holds the outgoing source addresses for one scope. An empty
	// string means no address is configured for that family.
	NetOutIPs struct {
		IPv4 string
		IPv6 string
	}
)
// NewConfig returns a Config holding the built-in defaults: port 3128 with an
// empty listen host, log level LogSystem, no ACLs, no net_out addresses and no
// allow/deny rules. The maps and slices are allocated (empty, non-nil) so Load
// can add to them directly.
func NewConfig() *Config {
	cfg := new(Config) // Addr "" and AllowAll false come from the zero value
	cfg.Port = "3128"
	cfg.LogLevel = LogSystem
	cfg.ACLs = map[string]*ACL{}
	cfg.NetOutACLs = map[string]NetOutIPs{}
	cfg.ReqDeny = make([]string, 0)
	cfg.ReqAllow = make([]string, 0)
	return cfg
}
// Load reads directives from the config file at path and applies them to c,
// on top of whatever values c already holds (NewConfig supplies defaults).
//
// The file is line oriented. Blank lines and lines starting with '#' are
// ignored; every other line is a directive followed by whitespace-separated
// arguments:
//
//	net_out IP [ACL]      outgoing source IP, globally or for one ACL
//	addr ADDR             listen address
//	port PORT             listen port
//	log_level 0|1|2|3     LogSilent, LogSystem, LogWarn or LogFull
//	acl NAME PATTERN...   host patterns; '*' matches any run of characters
//	req_deny NAME         deny hosts matching ACL NAME
//	req_allow NAME|all    allow hosts matching ACL NAME (or everything)
//
// ACL patterns are anchored at both ends and matched case-insensitively,
// because host names are case-insensitive. Several patterns on one acl line,
// or several acl lines with the same NAME, all add to the same ACL.
//
// Load returns the first directive error (prefixed with its line number), or
// any error from opening or reading the file. On error c may be partially
// updated.
func (c *Config) Load(path string) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()

	// A Config not built by NewConfig has nil maps; writing to them would panic.
	if c.ACLs == nil {
		c.ACLs = make(map[string]*ACL)
	}
	if c.NetOutACLs == nil {
		c.NetOutACLs = make(map[string]NetOutIPs)
	}

	scanner := bufio.NewScanner(f)
	lineNum := 0
	for scanner.Scan() {
		lineNum++
		line := strings.TrimSpace(scanner.Text())
		// Skip empty lines and comments. A non-empty trimmed line always
		// yields at least one field below.
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		parts := strings.Fields(line)
		switch parts[0] {
		case "net_out":
			if len(parts) < 2 {
				return fmt.Errorf("line %d: net_out requires an IP address", lineNum)
			}
			if len(parts) > 3 {
				// Previously such a line was silently ignored as a whole.
				return fmt.Errorf("line %d: net_out takes an IP address and at most one ACL name", lineNum)
			}
			ip := parts[1]
			parsedIP := net.ParseIP(ip)
			if parsedIP == nil {
				return fmt.Errorf("line %d: invalid IP address %s", lineNum, ip)
			}
			isIPv4 := parsedIP.To4() != nil
			if len(parts) == 2 {
				// Default source address, used when no per-ACL net_out applies.
				if isIPv4 {
					c.NetOutIPv4 = ip
				} else {
					c.NetOutIPv6 = ip
				}
			} else {
				// Per-ACL source address; a missing entry starts as the zero value.
				aclName := parts[2]
				netOut := c.NetOutACLs[aclName]
				if isIPv4 {
					netOut.IPv4 = ip
				} else {
					netOut.IPv6 = ip
				}
				c.NetOutACLs[aclName] = netOut
			}
		case "addr":
			if len(parts) < 2 {
				return fmt.Errorf("line %d: addr requires an address", lineNum)
			}
			c.Addr = parts[1]
		case "port":
			if len(parts) < 2 {
				return fmt.Errorf("line %d: port requires a port number", lineNum)
			}
			c.Port = parts[1]
		case "log_level":
			if len(parts) < 2 {
				return fmt.Errorf("line %d: log_level requires a level", lineNum)
			}
			switch parts[1] {
			case "0":
				c.LogLevel = LogSilent
			case "1":
				c.LogLevel = LogSystem
			case "2":
				c.LogLevel = LogWarn
			case "3":
				c.LogLevel = LogFull
			default:
				return fmt.Errorf("line %d: invalid log level %s", lineNum, parts[1])
			}
		case "acl":
			if len(parts) < 3 {
				return fmt.Errorf("line %d: acl requires name and pattern", lineNum)
			}
			aclName := parts[1]
			acl := c.ACLs[aclName]
			if acl == nil {
				acl = &ACL{Name: aclName, Patterns: []*regexp.Regexp{}}
				c.ACLs[aclName] = acl
			}
			// Every pattern on the line is added, exactly as if each had its
			// own acl line (previously everything after the first was dropped).
			for _, pattern := range parts[2:] {
				// Escape the whole pattern, then turn each escaped '*' back into
				// ".*". (?i) because host names are case-insensitive (RFC 4343);
				// otherwise "EVIL.com" would slip past a deny for "evil.com".
				regexPattern := "(?i)^" + strings.ReplaceAll(regexp.QuoteMeta(pattern), `\*`, ".*") + "$"
				re, err := regexp.Compile(regexPattern)
				if err != nil {
					return fmt.Errorf("line %d: invalid pattern %s: %v", lineNum, pattern, err)
				}
				acl.Patterns = append(acl.Patterns, re)
			}
		case "req_deny":
			if len(parts) < 2 {
				return fmt.Errorf("line %d: req_deny requires ACL name", lineNum)
			}
			c.ReqDeny = append(c.ReqDeny, parts[1])
		case "req_allow":
			if len(parts) < 2 {
				return fmt.Errorf("line %d: req_allow requires ACL name", lineNum)
			}
			if parts[1] == "all" {
				c.AllowAll = true
			} else {
				c.ReqAllow = append(c.ReqAllow, parts[1])
			}
		default:
			return fmt.Errorf("line %d: unknown directive %s", lineNum, parts[0])
		}
	}
	return scanner.Err()
}
// MatchACL returns the names of all ACLs that have at least one pattern
// matching host, sorted by name.
//
// Sorting makes the result deterministic. Map iteration order is random, and
// callers act on the first matching name (getDialer picks the net_out of the
// first match that has one). A single trailing dot (the fully qualified form,
// "example.com.") is ignored so that it cannot be used to slip past anchored
// patterns.
func (c *Config) MatchACL(host string) []string {
	host = strings.TrimSuffix(host, ".")
	matched := []string{}
	for name, acl := range c.ACLs {
		for _, pattern := range acl.Patterns {
			if pattern.MatchString(host) {
				matched = append(matched, name)
				break // one matching pattern is enough for this ACL
			}
		}
	}
	sort.Strings(matched)
	return matched
}
// Proxy is the http.Handler implementing the forward proxy. All of its
// behaviour is driven by the Config it holds, which its methods only read.
type Proxy struct {
	config *Config
}

// NewProxy returns a Proxy that serves requests according to config. The
// Config is shared with the caller, not copied.
func NewProxy(config *Config) *Proxy {
	p := new(Proxy)
	p.config = config
	return p
}
// logf writes a log.Printf-formatted message when the configured LogLevel is
// at least level, and discards it otherwise.
func (p *Proxy) logf(level int, format string, args ...interface{}) {
	if level > p.config.LogLevel {
		return
	}
	log.Printf(format, args...)
}
// getDialer returns a dialer for connecting to host, bound to the outgoing
// (source) address chosen from the net_out configuration:
//
//  1. For each ACL matching host (in MatchACL order) that has a net_out
//     entry: its IPv4 address, else its IPv6 address. The first address that
//     resolves wins.
//  2. Otherwise the default net_out: IPv4 if configured, else IPv6.
//  3. Otherwise LocalAddr stays nil and the OS chooses the source address.
//
// Connection setup times out after 30s; TCP keep-alive is 30s.
//
// NOTE(review): net.Dialer requires LocalAddr to be compatible with the
// address being dialed, so binding an IPv4 source restricts the dial to the
// target's IPv4 addresses. The IPv6 default is also never used while an IPv4
// default is set. Confirm this trade-off is intended.
func (p *Proxy) getDialer(host string) *net.Dialer {
	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}

	// localTCPAddr turns a configured source IP into a bindable address with
	// an ephemeral port. net.JoinHostPort brackets any address containing
	// ':', so this also handles IPv4-mapped IPv6 notation such as
	// "::ffff:192.0.2.1". Load files that under IPv4, and the old "ip:0"
	// concatenation made it unparsable.
	localTCPAddr := func(network, ip string) (*net.TCPAddr, error) {
		return net.ResolveTCPAddr(network, net.JoinHostPort(ip, "0"))
	}

	// Per-ACL net_out takes precedence over the defaults.
	for _, aclName := range p.config.MatchACL(host) {
		netOut, ok := p.config.NetOutACLs[aclName]
		if !ok {
			continue
		}
		// Try IPv4 first if available.
		if netOut.IPv4 != "" {
			localAddr, err := localTCPAddr("tcp4", netOut.IPv4)
			if err == nil {
				dialer.LocalAddr = localAddr
				p.logf(LogFull, "Using outgoing IPv4 %s for %s (ACL: %s)", netOut.IPv4, host, aclName)
				return dialer
			}
			p.logf(LogWarn, "Failed to resolve local IPv4 address %s: %v", netOut.IPv4, err)
		}
		// Fall back to IPv6 if IPv4 is not configured or failed.
		if netOut.IPv6 != "" {
			localAddr, err := localTCPAddr("tcp6", netOut.IPv6)
			if err == nil {
				dialer.LocalAddr = localAddr
				p.logf(LogFull, "Using outgoing IPv6 %s for %s (ACL: %s)", netOut.IPv6, host, aclName)
				return dialer
			}
			p.logf(LogWarn, "Failed to resolve local IPv6 address %s: %v", netOut.IPv6, err)
		}
	}

	// Default net_out: IPv4 when set, otherwise IPv6.
	switch {
	case p.config.NetOutIPv4 != "":
		localAddr, err := localTCPAddr("tcp4", p.config.NetOutIPv4)
		if err != nil {
			p.logf(LogWarn, "Failed to resolve default local IPv4 address %s: %v", p.config.NetOutIPv4, err)
		} else {
			dialer.LocalAddr = localAddr
		}
	case p.config.NetOutIPv6 != "":
		localAddr, err := localTCPAddr("tcp6", p.config.NetOutIPv6)
		if err != nil {
			p.logf(LogWarn, "Failed to resolve default local IPv6 address %s: %v", p.config.NetOutIPv6, err)
		} else {
			dialer.LocalAddr = localAddr
		}
	}
	return dialer
}
// isRequestAllowed reports whether a request for host may be proxied. Rules
// are applied in this order:
//
//  1. host matches an ACL listed in ReqDeny: denied;
//  2. AllowAll ("req_allow all") is set: allowed;
//  3. host matches an ACL listed in ReqAllow: allowed;
//  4. req_allow rules exist but none matched: denied;
//  5. there are no req_allow rules at all: allowed.
//
// Denials are logged at LogFull.
func (p *Proxy) isRequestAllowed(host string) bool {
	cfg := p.config
	hits := cfg.MatchACL(host)

	listed := func(names []string, want string) bool {
		for _, n := range names {
			if n == want {
				return true
			}
		}
		return false
	}

	// Deny rules take priority over everything else.
	for _, acl := range hits {
		if listed(cfg.ReqDeny, acl) {
			p.logf(LogFull, "Request to %s denied (ACL: %s)", host, acl)
			return false
		}
	}
	if cfg.AllowAll {
		return true
	}
	for _, acl := range hits {
		if listed(cfg.ReqAllow, acl) {
			return true
		}
	}
	// With no allow list the proxy is open by default.
	if len(cfg.ReqAllow) == 0 {
		return true
	}
	p.logf(LogFull, "Request to %s denied (no matching allow ACL)", host)
	return false
}
// handleConnect serves a CONNECT request in four steps:
//
//  1. Check the ACLs.
//  2. Dial the requested host:port, bound to the source address chosen by
//     getDialer.
//  3. Answer "200 Connection Established".
//  4. Relay bytes in both directions until either side stops.
//
// r.Host may omit the port (443 is assumed) and may be an IPv6 literal such
// as "[2001:db8::1]:443". The ACL check and net_out selection use the bare
// host name without brackets or port.
func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
	hostname, port, err := net.SplitHostPort(r.Host)
	if err != nil {
		// No port given: the whole authority is the host (brackets removed).
		hostname = strings.TrimSuffix(strings.TrimPrefix(r.Host, "["), "]")
		port = ""
	}
	if port == "" {
		port = "443"
	}
	target := net.JoinHostPort(hostname, port)
	p.logf(LogFull, "CONNECT %s from %s", target, r.RemoteAddr)

	// An empty host would make the dialer connect to the local machine.
	if hostname == "" {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}
	if !p.isRequestAllowed(hostname) {
		http.Error(w, "Forbidden", http.StatusForbidden)
		return
	}

	// Make sure the client connection can be taken over before dialing out.
	hijacker, ok := w.(http.Hijacker)
	if !ok {
		p.logf(LogSystem, "Hijacking not supported")
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	dialer := p.getDialer(hostname)
	// The context bounds connection setup only (and aborts it if the client
	// goes away); it does not affect the connection once established.
	ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
	defer cancel()
	targetConn, err := dialer.DialContext(ctx, "tcp", target)
	if err != nil {
		p.logf(LogSystem, "Failed to connect to %s: %v", target, err)
		http.Error(w, "Bad Gateway", http.StatusBadGateway)
		return
	}
	defer targetConn.Close()
	if localAddr := targetConn.LocalAddr(); localAddr != nil {
		p.logf(LogFull, "Connected from local address: %s", localAddr.String())
	}

	clientConn, clientBuf, err := hijacker.Hijack()
	if err != nil {
		p.logf(LogSystem, "Failed to hijack connection: %v", err)
		return
	}
	defer clientConn.Close()
	// Per the http.Hijacker contract the conn may still carry the server's
	// read/write deadlines; a long-lived tunnel must not be cut off by them.
	// Best effort: a failure here only means the old deadlines stay.
	_ = clientConn.SetDeadline(time.Time{})

	if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
		p.logf(LogWarn, "Failed to send CONNECT response to %s: %v", r.RemoteAddr, err)
		return
	}

	// Bytes the client sent right after the CONNECT headers (e.g. a TLS
	// ClientHello) may already sit in the server's read buffer. Forward them
	// before relaying from the raw connection, or they are lost.
	if clientBuf != nil {
		if n := clientBuf.Reader.Buffered(); n > 0 {
			pending, _ := clientBuf.Reader.Peek(n) // cannot fail: n bytes are buffered
			if _, err := targetConn.Write(pending); err != nil {
				p.logf(LogWarn, "Failed to forward buffered data to %s: %v", target, err)
				return
			}
		}
	}

	// Relay both directions. When either copy ends, returning closes both
	// connections (deferred above), which unblocks the other copy. The
	// channel is buffered so the second goroutine never blocks on send.
	errChan := make(chan error, 2)
	go func() {
		_, err := io.Copy(targetConn, clientConn)
		errChan <- err
	}()
	go func() {
		_, err := io.Copy(clientConn, targetConn)
		errChan <- err
	}()
	<-errChan
}
// handleHTTP forwards a non-CONNECT request to its origin server and relays
// the response back to the client.
//
// The target host comes from the request URL (absolute-form, as normally sent
// to a proxy) or, when that is empty, from the Host header (origin-form). The
// ACL check and the outgoing-address choice use that same host. Checking only
// r.URL would let an origin-form request reach a host blocked by req_deny.
// Hop-by-hop headers are stripped in both directions.
func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request) {
	p.logf(LogFull, "%s %s from %s", r.Method, r.URL, r.RemoteAddr)

	// Build the outgoing request first so the ACL sees the real destination.
	outReq := r.Clone(r.Context())
	if outReq.URL.Scheme == "" {
		outReq.URL.Scheme = "http"
	}
	if outReq.URL.Host == "" {
		outReq.URL.Host = r.Host
	}
	outReq.RequestURI = "" // must be empty on client requests

	host := outReq.URL.Hostname()
	if host == "" {
		http.Error(w, "Bad Request", http.StatusBadRequest)
		return
	}
	if !p.isRequestAllowed(host) {
		http.Error(w, "Forbidden", http.StatusForbidden)
		return
	}

	dialer := p.getDialer(host)
	transport := &http.Transport{
		DialContext:           dialer.DialContext,
		MaxIdleConns:          100,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}
	// This Transport serves a single request. Close its pooled connection
	// afterwards instead of leaving it idle for IdleConnTimeout. Deferred
	// before resp.Body.Close, so it runs after the body has been closed.
	defer transport.CloseIdleConnections()

	removeHopByHopHeaders(outReq.Header)
	resp, err := transport.RoundTrip(outReq)
	if err != nil {
		p.logf(LogSystem, "Failed to forward request to %s: %v", outReq.URL, err)
		http.Error(w, "Bad Gateway", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	removeHopByHopHeaders(resp.Header)
	for k, vv := range resp.Header {
		for _, v := range vv {
			w.Header().Add(k, v)
		}
	}
	w.WriteHeader(resp.StatusCode)
	if _, err := io.Copy(w, resp.Body); err != nil {
		// Headers are already sent; all we can do is record the truncation.
		p.logf(LogWarn, "Failed to relay response from %s: %v", outReq.URL, err)
	}
}

// removeHopByHopHeaders deletes headers that apply only to a single
// connection (RFC 9110 section 7.6.1). This covers any header named in the
// Connection field as well as the fixed hop-by-hop set, so that the proxy
// does not forward them.
func removeHopByHopHeaders(h http.Header) {
	for _, field := range h["Connection"] {
		for _, name := range strings.Split(field, ",") {
			if name = strings.TrimSpace(name); name != "" {
				h.Del(name)
			}
		}
	}
	for _, name := range []string{
		"Connection",
		"Proxy-Connection",
		"Keep-Alive",
		"Proxy-Authenticate",
		"Proxy-Authorization",
		"Te",
		"Trailer",
		"Transfer-Encoding",
		"Upgrade",
	} {
		h.Del(name)
	}
}
// ServeHTTP implements http.Handler. CONNECT requests become TCP tunnels;
// every other method is forwarded as a plain HTTP request.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodConnect:
		p.handleConnect(w, r)
	default:
		p.handleHTTP(w, r)
	}
}
// Start logs the listen address and any default outgoing addresses, then
// serves proxy requests until the server fails. Like
// http.Server.ListenAndServe, it returns only with a non-nil error.
//
// NOTE(review): WriteTimeout bounds the time to write a whole response, so a
// plain-HTTP response that takes longer than 30s to stream through
// handleHTTP is cut off. Confirm this limit is intended for a forward proxy.
func (p *Proxy) Start() error {
	addr := listenAddress(p.config.Addr, p.config.Port)
	p.logf(LogSystem, "Starting rproxy on %s", addr)
	if p.config.NetOutIPv4 != "" {
		p.logf(LogSystem, "Default outgoing IPv4: %s", p.config.NetOutIPv4)
	}
	if p.config.NetOutIPv6 != "" {
		p.logf(LogSystem, "Default outgoing IPv6: %s", p.config.NetOutIPv6)
	}
	server := &http.Server{
		Addr:         addr,
		Handler:      p,
		ReadTimeout:  30 * time.Second,
		WriteTimeout: 30 * time.Second,
		IdleTimeout:  120 * time.Second,
	}
	return server.ListenAndServe()
}

// listenAddress joins the configured listen host and port. Unlike plain
// "addr:port" concatenation, it brackets IPv6 literals ("::1" becomes
// "[::1]:3128"). A host already written in brackets is accepted as well. An
// empty host yields ":port", i.e. every interface.
func listenAddress(addr, port string) string {
	host := strings.TrimSuffix(strings.TrimPrefix(addr, "["), "]")
	return net.JoinHostPort(host, port)
}
// main loads the configuration and runs the proxy. The config path is the
// first command-line argument (default /etc/rproxy/rproxy.conf), and a leading
// "~/" is expanded to the user's home directory. Failing to load the config
// or to serve is fatal.
func main() {
	path := "/etc/rproxy/rproxy.conf"
	if len(os.Args) >= 2 {
		path = os.Args[1]
	}

	// Expand a leading "~/" ourselves. If the home directory cannot be
	// determined, the path is used as given.
	if rest := strings.TrimPrefix(path, "~/"); rest != path {
		if home, err := os.UserHomeDir(); err == nil {
			path = filepath.Join(home, rest)
		}
	}

	cfg := NewConfig()
	if err := cfg.Load(path); err != nil {
		log.Fatalf("Failed to load config from %s: %v", path, err)
	}
	if err := NewProxy(cfg).Start(); err != nil {
		log.Fatalf("Proxy error: %v", err)
	}
}