498 lines
11 KiB
Go
498 lines
11 KiB
Go
package main
|
|
|
|
import (
	"bufio"
	"context"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strings"
	"time"
)
|
|
|
|
// Log verbosity levels for Config.LogLevel. Proxy.logf prints a message only
// when the configured level is greater than or equal to the message's level,
// so each level also includes every message of the levels below it.
const (
	LogSilent = 0 // suppress all logf output
	LogSystem = 1 // startup information and connection/forwarding failures
	LogWarn   = 2 // additionally, warnings such as unresolvable net_out addresses
	LogFull   = 3 // additionally, per-request tracing
)
|
|
|
|
type ACL struct {
|
|
Name string
|
|
Patterns []*regexp.Regexp
|
|
}
|
|
|
|
type Config struct {
|
|
NetOutIPv4 string
|
|
NetOutIPv6 string
|
|
Addr string
|
|
Port string
|
|
LogLevel int
|
|
ACLs map[string]*ACL
|
|
NetOutACLs map[string]NetOutIPs // ACL name -> IPs
|
|
ReqDeny []string // ACL names to deny
|
|
ReqAllow []string // ACL names to allow
|
|
AllowAll bool
|
|
}
|
|
|
|
// NetOutIPs holds one outgoing source address per address family. An empty
// string means that family has no configured address.
type NetOutIPs struct {
	IPv4, IPv6 string
}
|
|
|
|
func NewConfig() *Config {
|
|
return &Config{
|
|
Addr: "",
|
|
Port: "3128",
|
|
LogLevel: LogSystem,
|
|
ACLs: make(map[string]*ACL),
|
|
NetOutACLs: make(map[string]NetOutIPs),
|
|
ReqDeny: []string{},
|
|
ReqAllow: []string{},
|
|
AllowAll: false,
|
|
}
|
|
}
|
|
|
|
// Load reads the line-oriented configuration file at path into c.
//
// Blank lines and lines whose first non-space character is '#' are skipped.
// Every other line is a directive followed by whitespace-separated arguments:
//
//	net_out IP [ACL]     outgoing source address, default or for hosts in ACL
//	addr ADDR            listen address
//	port PORT            listen port
//	log_level 0|1|2|3    LogSilent, LogSystem, LogWarn or LogFull
//	acl NAME PATTERN     add a host pattern to ACL NAME ('*' matches any run)
//	req_deny NAME        deny requests for hosts in ACL NAME
//	req_allow NAME|all   allow requests for hosts in ACL NAME, or all requests
//
// Directives apply in file order: scalar settings are overwritten by later
// lines, list settings accumulate. Load stops at the first invalid line and
// returns an error prefixed with its 1-based line number. Otherwise it returns
// any error from opening or reading the file.
func (c *Config) Load(path string) error {
	f, err := os.Open(path)
	if err != nil {
		return err
	}
	defer f.Close()

	scanner := bufio.NewScanner(f)
	// The post statement also runs on `continue`, so lineNum tracks every line.
	for lineNum := 1; scanner.Scan(); lineNum++ {
		line := strings.TrimSpace(scanner.Text())
		if line == "" || strings.HasPrefix(line, "#") {
			continue
		}
		if err := c.applyDirective(strings.Fields(line)); err != nil {
			return fmt.Errorf("line %d: %w", lineNum, err)
		}
	}
	return scanner.Err()
}

// applyDirective applies one non-empty, whitespace-split config line to c.
func (c *Config) applyDirective(parts []string) error {
	directive, args := parts[0], parts[1:]

	switch directive {
	case "net_out":
		return c.applyNetOut(args)

	case "addr":
		if len(args) < 1 {
			return fmt.Errorf("addr requires an address")
		}
		c.Addr = args[0]

	case "port":
		if len(args) < 1 {
			return fmt.Errorf("port requires a port number")
		}
		c.Port = args[0]

	case "log_level":
		if len(args) < 1 {
			return fmt.Errorf("log_level requires a level")
		}
		switch args[0] {
		case "0":
			c.LogLevel = LogSilent
		case "1":
			c.LogLevel = LogSystem
		case "2":
			c.LogLevel = LogWarn
		case "3":
			c.LogLevel = LogFull
		default:
			return fmt.Errorf("invalid log level %s", args[0])
		}

	case "acl":
		if len(args) < 2 {
			return fmt.Errorf("acl requires name and pattern")
		}
		name, pattern := args[0], args[1]
		re, err := compileHostPattern(pattern)
		if err != nil {
			return fmt.Errorf("invalid pattern %s: %v", pattern, err)
		}
		acl := c.ACLs[name]
		if acl == nil {
			acl = &ACL{Name: name}
			c.ACLs[name] = acl
		}
		acl.Patterns = append(acl.Patterns, re)

	case "req_deny":
		if len(args) < 1 {
			return fmt.Errorf("req_deny requires ACL name")
		}
		c.ReqDeny = append(c.ReqDeny, args[0])

	case "req_allow":
		if len(args) < 1 {
			return fmt.Errorf("req_allow requires ACL name")
		}
		if args[0] == "all" {
			c.AllowAll = true
		} else {
			c.ReqAllow = append(c.ReqAllow, args[0])
		}

	default:
		return fmt.Errorf("unknown directive %s", directive)
	}
	return nil
}

// applyNetOut handles "net_out IP [ACL]". Without an ACL name the address
// becomes the default for its family; with one it applies to hosts in that
// ACL. The address is stored in canonical form (net.IP.String), so an
// IPv4-mapped literal such as "::ffff:192.0.2.1" is stored as "192.0.2.1".
func (c *Config) applyNetOut(args []string) error {
	if len(args) < 1 {
		return fmt.Errorf("net_out requires an IP address")
	}
	ip := net.ParseIP(args[0])
	if ip == nil {
		return fmt.Errorf("invalid IP address %s", args[0])
	}
	if len(args) > 2 {
		// Reject extra fields rather than accepting the line with no effect.
		return fmt.Errorf("net_out takes an IP address and an optional ACL name, got %d arguments", len(args))
	}

	addr := ip.String()
	isIPv4 := ip.To4() != nil

	if len(args) == 1 {
		if isIPv4 {
			c.NetOutIPv4 = addr
		} else {
			c.NetOutIPv6 = addr
		}
		return nil
	}

	// The zero NetOutIPs returned for a missing key is a valid starting point.
	aclName := args[1]
	netOut := c.NetOutACLs[aclName]
	if isIPv4 {
		netOut.IPv4 = addr
	} else {
		netOut.IPv6 = addr
	}
	c.NetOutACLs[aclName] = netOut
	return nil
}

// compileHostPattern turns an ACL host pattern into an anchored regexp. In
// the pattern, '*' matches any run of characters and everything else is
// literal. Matching is case-insensitive because host names are (RFC 4343).
// Without that, a request for "EXAMPLE.com" would slip past an ACL written
// as "example.com".
func compileHostPattern(pattern string) (*regexp.Regexp, error) {
	expr := strings.ReplaceAll(regexp.QuoteMeta(pattern), `\*`, ".*")
	return regexp.Compile("(?i)^" + expr + "$")
}
|
|
|
|
func (c *Config) MatchACL(host string) []string {
|
|
matched := []string{}
|
|
for name, acl := range c.ACLs {
|
|
for _, pattern := range acl.Patterns {
|
|
if pattern.MatchString(host) {
|
|
matched = append(matched, name)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return matched
|
|
}
|
|
|
|
type Proxy struct {
|
|
config *Config
|
|
}
|
|
|
|
func NewProxy(config *Config) *Proxy {
|
|
return &Proxy{config: config}
|
|
}
|
|
|
|
// logf formats and logs a message through the standard logger, but only when
// the configured LogLevel is at least level.
func (p *Proxy) logf(level int, format string, args ...interface{}) {
	if level > p.config.LogLevel {
		return // message is more verbose than configured
	}
	log.Printf(format, args...)
}
|
|
|
|
// getDialer returns a dialer for connecting to host, bound to the configured
// outgoing (net_out) source address when one applies.
//
// Per-ACL net_out entries take precedence. The first ACL returned by
// MatchACL(host) that has a usable net_out address decides the source
// address. Otherwise the default net_out addresses are used. In both cases
// the IPv4 address is preferred, and the IPv6 address is tried when no IPv4
// address is set or it cannot be resolved. With no usable address,
// LocalAddr stays unset and the operating system picks the source address.
//
// NOTE(review): the source address is chosen before the destination is
// resolved, and Go's dialer only tries destination addresses of the same
// family as LocalAddr. An IPv4-bound dialer therefore cannot reach an
// IPv6-only host (and vice versa), even when an address of the other family
// is configured.
func (p *Proxy) getDialer(host string) *net.Dialer {
	dialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}

	for _, aclName := range p.config.MatchACL(host) {
		netOut, ok := p.config.NetOutACLs[aclName]
		if !ok {
			continue
		}
		if addr, desc := p.outgoingAddr(netOut.IPv4, netOut.IPv6, ""); addr != nil {
			dialer.LocalAddr = addr
			p.logf(LogFull, "Using outgoing %s for %s (ACL: %s)", desc, host, aclName)
			return dialer
		}
	}

	// No ACL-specific address applied: fall back to the default net_out.
	// The nil check matters. Assigning a nil *net.TCPAddr would make LocalAddr
	// a non-nil interface holding a nil pointer.
	if addr, _ := p.outgoingAddr(p.config.NetOutIPv4, p.config.NetOutIPv6, "default "); addr != nil {
		dialer.LocalAddr = addr
	}
	return dialer
}

// outgoingAddr resolves the local TCP address to bind for an (ipv4, ipv6)
// net_out pair, preferring ipv4. Empty strings are skipped. Resolution
// failures are logged at LogWarn, with scope ("" or "default ") inserted into
// the message, before trying the next family. It returns the address and a
// description such as "IPv4 192.0.2.1", or nil when neither address is usable.
func (p *Proxy) outgoingAddr(ipv4, ipv6, scope string) (*net.TCPAddr, string) {
	candidates := []struct{ network, family, ip string }{
		{"tcp4", "IPv4", ipv4},
		{"tcp6", "IPv6", ipv6},
	}
	for _, cand := range candidates {
		if cand.ip == "" {
			continue
		}
		// JoinHostPort brackets IPv6 literals, including IPv4-mapped forms
		// such as "::ffff:192.0.2.1", which plain concatenation would turn
		// into an unparsable "::ffff:192.0.2.1:0".
		addr, err := net.ResolveTCPAddr(cand.network, net.JoinHostPort(cand.ip, "0"))
		if err != nil {
			p.logf(LogWarn, "Failed to resolve %slocal %s address %s: %v", scope, cand.family, cand.ip, err)
			continue
		}
		return addr, cand.family + " " + cand.ip
	}
	return nil, ""
}
|
|
|
|
func (p *Proxy) isRequestAllowed(host string) bool {
|
|
matchedACLs := p.config.MatchACL(host)
|
|
|
|
// Check deny ACLs first
|
|
for _, aclName := range matchedACLs {
|
|
for _, denyACL := range p.config.ReqDeny {
|
|
if aclName == denyACL {
|
|
p.logf(LogFull, "Request to %s denied (ACL: %s)", host, aclName)
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check allow ACLs
|
|
if p.config.AllowAll {
|
|
return true
|
|
}
|
|
|
|
for _, aclName := range matchedACLs {
|
|
for _, allowACL := range p.config.ReqAllow {
|
|
if aclName == allowACL {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
// If no allow ACLs matched and not allow all, deny
|
|
if len(p.config.ReqAllow) > 0 {
|
|
p.logf(LogFull, "Request to %s denied (no matching allow ACL)", host)
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
// handleConnect serves a CONNECT request by tunnelling raw bytes between the
// client and the requested host:port. The port defaults to 443 when the
// request omits it.
//
// It responds 403 when isRequestAllowed rejects the hostname and 502 when the
// upstream dial fails. Otherwise it hijacks the client connection, writes
// "200 Connection Established", and copies bytes in both directions until
// either direction finishes. Both connections are then closed.
func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
	hostname, port := splitConnectAuthority(r.Host)
	host := net.JoinHostPort(hostname, port)

	p.logf(LogFull, "CONNECT %s from %s", host, r.RemoteAddr)

	if !p.isRequestAllowed(hostname) {
		http.Error(w, "Forbidden", http.StatusForbidden)
		return
	}

	dialer := p.getDialer(hostname)

	// Bound only the dial; the established tunnel is not tied to this context.
	ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
	defer cancel()

	targetConn, err := dialer.DialContext(ctx, "tcp", host)
	if err != nil {
		p.logf(LogSystem, "Failed to connect to %s: %v", host, err)
		http.Error(w, "Bad Gateway", http.StatusBadGateway)
		return
	}
	defer targetConn.Close()

	if localAddr := targetConn.LocalAddr(); localAddr != nil {
		p.logf(LogFull, "Connected from local address: %s", localAddr.String())
	}

	hijacker, ok := w.(http.Hijacker)
	if !ok {
		p.logf(LogSystem, "Hijacking not supported")
		http.Error(w, "Internal Server Error", http.StatusInternalServerError)
		return
	}

	clientConn, clientBuf, err := hijacker.Hijack()
	if err != nil {
		p.logf(LogSystem, "Failed to hijack connection: %v", err)
		return
	}
	defer clientConn.Close()

	// A hijacked connection keeps any deadlines the http.Server set from its
	// ReadTimeout/WriteTimeout. Clear them, otherwise the tunnel is cut off
	// when they expire.
	if err := clientConn.SetDeadline(time.Time{}); err != nil {
		p.logf(LogWarn, "Failed to clear deadlines on client connection: %v", err)
	}

	if _, err := clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n")); err != nil {
		p.logf(LogSystem, "Failed to send CONNECT response to %s: %v", r.RemoteAddr, err)
		return
	}

	// Bytes the client sent right after the CONNECT header (e.g. an eagerly
	// sent TLS ClientHello) may already be in the server's read buffer.
	// Forward them first, since the copy below reads the raw connection.
	if n := clientBuf.Reader.Buffered(); n > 0 {
		pending, _ := clientBuf.Reader.Peek(n) // cannot fail: n bytes are buffered
		if _, err := targetConn.Write(pending); err != nil {
			p.logf(LogSystem, "Failed to forward buffered client data to %s: %v", host, err)
			return
		}
	}

	// Copy in both directions. Returning after the first direction finishes
	// closes both connections (deferred above), which unblocks the other
	// goroutine. The buffered channel lets it exit without a receiver.
	errChan := make(chan error, 2)
	go func() {
		_, err := io.Copy(targetConn, clientConn)
		errChan <- err
	}()
	go func() {
		_, err := io.Copy(clientConn, targetConn)
		errChan <- err
	}()
	<-errChan
}

// splitConnectAuthority splits a CONNECT target into hostname and port. It
// accepts "host:port", "[v6addr]:port", or a bare host, and uses port "443"
// when no port can be parsed. IPv6 brackets are removed from the hostname.
func splitConnectAuthority(authority string) (hostname, port string) {
	if h, pt, err := net.SplitHostPort(authority); err == nil {
		return h, pt
	}
	return strings.TrimSuffix(strings.TrimPrefix(authority, "["), "]"), "443"
}
|
|
|
|
// handleHTTP forwards a plain (non-CONNECT) HTTP request upstream and relays
// the response.
//
// It responds 403 when isRequestAllowed rejects the target host and 502 when
// the upstream round trip fails. Hop-by-hop headers are stripped in both
// directions.
func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request) {
	p.logf(LogFull, "%s %s from %s", r.Method, r.URL, r.RemoteAddr)

	// Build the outgoing request first, so that the ACL check and dialer
	// selection use the host the request will actually be sent to. An
	// origin-form request ("GET /path") has no host in its URL and is sent to
	// r.Host. Checking r.URL.Hostname() (empty) would let it bypass req_deny.
	outReq := r.Clone(r.Context())
	if outReq.URL.Scheme == "" {
		outReq.URL.Scheme = "http"
	}
	if outReq.URL.Host == "" {
		outReq.URL.Host = r.Host
	}
	outReq.RequestURI = "" // must be empty on client requests
	removeHopByHopHeaders(outReq.Header)

	host := outReq.URL.Hostname()
	if !p.isRequestAllowed(host) {
		http.Error(w, "Forbidden", http.StatusForbidden)
		return
	}

	dialer := p.getDialer(host)
	transport := &http.Transport{
		DialContext: dialer.DialContext,
		// The Transport is built per request, so pooled connections could
		// never be reused. With keep-alives on, every request would leave an
		// idle upstream connection (and its goroutines) open until it timed out.
		DisableKeepAlives:     true,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}

	resp, err := transport.RoundTrip(outReq)
	if err != nil {
		p.logf(LogSystem, "Failed to forward request to %s: %v", r.URL, err)
		http.Error(w, "Bad Gateway", http.StatusBadGateway)
		return
	}
	defer resp.Body.Close()

	removeHopByHopHeaders(resp.Header)
	for name, values := range resp.Header {
		for _, v := range values {
			w.Header().Add(name, v)
		}
	}
	w.WriteHeader(resp.StatusCode)

	if _, err := io.Copy(w, resp.Body); err != nil {
		p.logf(LogWarn, "Failed to copy response body from %s: %v", r.URL, err)
	}
}

// hopByHopHeaders are the connection-specific header fields (RFC 9110
// §7.6.1, plus the non-standard Proxy-Connection) that apply to a single hop
// and must not be forwarded by a proxy.
var hopByHopHeaders = []string{
	"Connection",
	"Proxy-Connection",
	"Keep-Alive",
	"Proxy-Authenticate",
	"Proxy-Authorization",
	"Te",
	"Trailer",
	"Transfer-Encoding",
	"Upgrade",
}

// removeHopByHopHeaders deletes hop-by-hop fields from h. This includes any
// field named by the Connection header itself (e.g. "Connection: X-Foo").
func removeHopByHopHeaders(h http.Header) {
	// Read the Connection tokens before "Connection" itself is deleted.
	for _, value := range h["Connection"] {
		for _, field := range strings.Split(value, ",") {
			if field = strings.TrimSpace(field); field != "" {
				h.Del(field)
			}
		}
	}
	for _, field := range hopByHopHeaders {
		h.Del(field)
	}
}
|
|
|
|
// ServeHTTP implements http.Handler. CONNECT requests are tunnelled by
// handleConnect; every other method is forwarded by handleHTTP.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodConnect:
		p.handleConnect(w, r)
	default:
		p.handleHTTP(w, r)
	}
}
|
|
|
|
// Start logs the effective listen and outgoing addresses, then serves the
// proxy on Config.Addr:Config.Port until the listener fails. Like
// http.Server.ListenAndServe, it always returns a non-nil error.
func (p *Proxy) Start() error {
	// JoinHostPort brackets IPv6 listen addresses ("::1" -> "[::1]:3128"),
	// which "%s:%s" formatting got wrong. Existing brackets are stripped
	// first so an addr already written as "[::1]" still works. An empty
	// Addr still yields ":port" (all interfaces).
	listenHost := strings.TrimSuffix(strings.TrimPrefix(p.config.Addr, "["), "]")
	addr := net.JoinHostPort(listenHost, p.config.Port)
	p.logf(LogSystem, "Starting rproxy on %s", addr)

	if p.config.NetOutIPv4 != "" {
		p.logf(LogSystem, "Default outgoing IPv4: %s", p.config.NetOutIPv4)
	}
	if p.config.NetOutIPv6 != "" {
		p.logf(LogSystem, "Default outgoing IPv6: %s", p.config.NetOutIPv6)
	}

	server := &http.Server{
		Addr:    addr,
		Handler: p,
		// Bound only the time to read request headers (slowloris protection).
		// ReadTimeout/WriteTimeout cover the whole request/response, so for a
		// proxy they would truncate any upload or download lasting longer
		// than the timeout.
		ReadHeaderTimeout: 30 * time.Second,
		IdleTimeout:       120 * time.Second,
	}

	return server.ListenAndServe()
}
|
|
|
|
// main loads the configuration and runs the proxy. The config path is the
// first command-line argument (default /etc/rproxy/rproxy.conf). A leading
// "~/" is expanded to the user's home directory when it can be determined.
// Any load or serve error is fatal.
func main() {
	path := "/etc/rproxy/rproxy.conf"
	if args := os.Args[1:]; len(args) > 0 {
		path = args[0]
	}

	if rest := strings.TrimPrefix(path, "~/"); rest != path {
		if home, err := os.UserHomeDir(); err == nil {
			path = filepath.Join(home, rest)
		}
	}

	cfg := NewConfig()
	if err := cfg.Load(path); err != nil {
		log.Fatalf("Failed to load config from %s: %v", path, err)
	}

	if err := NewProxy(cfg).Start(); err != nil {
		log.Fatalf("Proxy error: %v", err)
	}
}
|