939 lines
22 KiB
Go
939 lines
22 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// Log verbosity levels, from quietest to most verbose. Proxy.logf emits a
// message tagged with level L only when the configured LogLevel is >= L.
const (
	LogSilent = 0 // log nothing
	LogSystem = 1 // startup information and connection failures
	LogWarn   = 2 // additionally, protocol and configuration warnings
	LogFull   = 3 // additionally, per-request tracing
)
|
|
|
|
// SOCKS5 protocol constants (RFC 1928), grouped by the message field they
// appear in.

// Socks5Version is the VER byte carried by every SOCKS5 message.
const Socks5Version = 0x05

// Authentication method codes used during method negotiation.
const (
	Socks5AuthNone     = 0x00 // no authentication required
	Socks5AuthNoAccept = 0xFF // none of the offered methods is acceptable
)

// Socks5CmdConnect is the CMD value for a CONNECT request.
const Socks5CmdConnect = 0x01

// Address type (ATYP) codes.
const (
	Socks5AddrIPv4   = 0x01 // 4-byte IPv4 address
	Socks5AddrDomain = 0x03 // length-prefixed domain name
	Socks5AddrIPv6   = 0x04 // 16-byte IPv6 address
)

// Reply (REP) status codes.
const (
	Socks5RepSuccess              = 0x00 // succeeded
	Socks5RepServerFailure        = 0x01 // general SOCKS server failure
	Socks5RepConnectionNotAllowed = 0x02 // connection not allowed by ruleset
	Socks5RepNetworkUnreachable   = 0x03 // network unreachable
	Socks5RepHostUnreachable      = 0x04 // host unreachable
	Socks5RepConnectionRefused    = 0x05 // connection refused
	Socks5RepTTLExpired           = 0x06 // TTL expired
	Socks5RepCmdNotSupported      = 0x07 // command not supported
	Socks5RepAddrNotSupported     = 0x08 // address type not supported
)
|
|
|
|
// ACL is a named set of host patterns built from the config file's "acl"
// directives and referenced by name from req_deny, req_allow and net_out.
type ACL struct {
	Name string // the ACL name as written in the config file

	// Compiled patterns; a host belongs to the ACL when any of them matches.
	Patterns []*regexp.Regexp
}
|
|
|
|
// Config holds the proxy settings. NewConfig supplies the defaults and Load
// applies the directives from a configuration file on top of them.
type Config struct {
	NetOutIPv4 string // default outgoing (source) IPv4 address; empty leaves it to the OS
	NetOutIPv6 string // default outgoing (source) IPv6 address; empty leaves it to the OS
	Addr       string // listen host shared by the HTTP and SOCKS5 listeners
	HTTPPort   string // HTTP proxy listen port
	SocksPort  string // SOCKS5 proxy listen port
	LogLevel   int    // one of LogSilent, LogSystem, LogWarn, LogFull

	ACLs map[string]*ACL // ACL name -> patterns

	NetOutACLs map[string]NetOutIPs // ACL name -> IPs

	ReqDeny []string // ACL names to deny

	ReqAllow []string // ACL names to allow

	AllowAll bool // set by "req_allow all": allow every host not denied

	DNSServers []string // Custom DNS servers

	SystemHosts bool // Use system hosts file

	Protocols map[string]bool // Enabled protocols (http, socks)
}
|
|
|
|
// NetOutIPs holds the outgoing source addresses configured for one ACL via
// "net_out IP ACL". Either field may be empty when only one family was set.
type NetOutIPs struct {
	IPv4 string // source address used when the IPv4 entry resolves

	IPv6 string // source address tried when no usable IPv4 entry exists
}
|
|
|
|
func NewConfig() *Config {
|
|
return &Config{
|
|
Addr: "",
|
|
HTTPPort: "3128",
|
|
SocksPort: "1080",
|
|
LogLevel: LogSystem,
|
|
ACLs: make(map[string]*ACL),
|
|
NetOutACLs: make(map[string]NetOutIPs),
|
|
ReqDeny: []string{},
|
|
ReqAllow: []string{},
|
|
AllowAll: false,
|
|
DNSServers: []string{},
|
|
SystemHosts: false,
|
|
Protocols: map[string]bool{"http": true}, // Default to HTTP only for backward compatibility
|
|
}
|
|
}
|
|
|
|
func (c *Config) Load(path string) error {
|
|
f, err := os.Open(path)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer f.Close()
|
|
|
|
scanner := bufio.NewScanner(f)
|
|
lineNum := 0
|
|
|
|
for scanner.Scan() {
|
|
lineNum++
|
|
line := strings.TrimSpace(scanner.Text())
|
|
|
|
// Skip empty lines and comments
|
|
if line == "" || strings.HasPrefix(line, "#") {
|
|
continue
|
|
}
|
|
|
|
// Strip inline comments
|
|
if idx := strings.Index(line, "#"); idx >= 0 {
|
|
line = strings.TrimSpace(line[:idx])
|
|
}
|
|
|
|
// Skip if line is now empty after stripping inline comment
|
|
if line == "" {
|
|
continue
|
|
}
|
|
|
|
parts := strings.Fields(line)
|
|
if len(parts) == 0 {
|
|
continue
|
|
}
|
|
|
|
switch parts[0] {
|
|
case "net_out":
|
|
if len(parts) < 2 {
|
|
return fmt.Errorf("line %d: net_out requires an IP address", lineNum)
|
|
}
|
|
|
|
ip := parts[1]
|
|
parsedIP := net.ParseIP(ip)
|
|
if parsedIP == nil {
|
|
return fmt.Errorf("line %d: invalid IP address %s", lineNum, ip)
|
|
}
|
|
|
|
if len(parts) == 2 {
|
|
// Default net_out
|
|
if parsedIP.To4() != nil {
|
|
c.NetOutIPv4 = ip
|
|
} else {
|
|
c.NetOutIPv6 = ip
|
|
}
|
|
} else if len(parts) == 3 {
|
|
// net_out IP ACL_NAME
|
|
aclName := parts[2]
|
|
if c.NetOutACLs[aclName].IPv4 == "" && c.NetOutACLs[aclName].IPv6 == "" {
|
|
c.NetOutACLs[aclName] = NetOutIPs{}
|
|
}
|
|
netOut := c.NetOutACLs[aclName]
|
|
if parsedIP.To4() != nil {
|
|
netOut.IPv4 = ip
|
|
} else {
|
|
netOut.IPv6 = ip
|
|
}
|
|
c.NetOutACLs[aclName] = netOut
|
|
}
|
|
|
|
case "addr":
|
|
if len(parts) < 2 {
|
|
return fmt.Errorf("line %d: addr requires an address", lineNum)
|
|
}
|
|
c.Addr = parts[1]
|
|
|
|
case "http_port":
|
|
if len(parts) < 2 {
|
|
return fmt.Errorf("line %d: http_port requires a port number", lineNum)
|
|
}
|
|
c.HTTPPort = parts[1]
|
|
|
|
case "socks_port":
|
|
if len(parts) < 2 {
|
|
return fmt.Errorf("line %d: socks_port requires a port number", lineNum)
|
|
}
|
|
c.SocksPort = parts[1]
|
|
|
|
case "log_level":
|
|
if len(parts) < 2 {
|
|
return fmt.Errorf("line %d: log_level requires a level", lineNum)
|
|
}
|
|
switch parts[1] {
|
|
case "0":
|
|
c.LogLevel = LogSilent
|
|
case "1":
|
|
c.LogLevel = LogSystem
|
|
case "2":
|
|
c.LogLevel = LogWarn
|
|
case "3":
|
|
c.LogLevel = LogFull
|
|
default:
|
|
return fmt.Errorf("line %d: invalid log level %s", lineNum, parts[1])
|
|
}
|
|
|
|
case "acl":
|
|
if len(parts) < 3 {
|
|
return fmt.Errorf("line %d: acl requires name and pattern", lineNum)
|
|
}
|
|
aclName := parts[1]
|
|
pattern := parts[2]
|
|
|
|
if c.ACLs[aclName] == nil {
|
|
c.ACLs[aclName] = &ACL{
|
|
Name: aclName,
|
|
Patterns: []*regexp.Regexp{},
|
|
}
|
|
}
|
|
|
|
// Convert wildcard pattern to regex
|
|
regexPattern := "^" + regexp.QuoteMeta(pattern) + "$"
|
|
regexPattern = strings.ReplaceAll(regexPattern, `\*`, ".*")
|
|
re, err := regexp.Compile(regexPattern)
|
|
if err != nil {
|
|
return fmt.Errorf("line %d: invalid pattern %s: %v", lineNum, pattern, err)
|
|
}
|
|
|
|
c.ACLs[aclName].Patterns = append(c.ACLs[aclName].Patterns, re)
|
|
|
|
case "req_deny":
|
|
if len(parts) < 2 {
|
|
return fmt.Errorf("line %d: req_deny requires ACL name", lineNum)
|
|
}
|
|
c.ReqDeny = append(c.ReqDeny, parts[1])
|
|
|
|
case "req_allow":
|
|
if len(parts) < 2 {
|
|
return fmt.Errorf("line %d: req_allow requires ACL name", lineNum)
|
|
}
|
|
if parts[1] == "all" {
|
|
c.AllowAll = true
|
|
} else {
|
|
c.ReqAllow = append(c.ReqAllow, parts[1])
|
|
}
|
|
|
|
case "dns_servers":
|
|
if len(parts) < 2 {
|
|
return fmt.Errorf("line %d: dns_servers requires at least one DNS server", lineNum)
|
|
}
|
|
// Add all DNS servers (parts[1:])
|
|
for _, dnsServer := range parts[1:] {
|
|
// Validate DNS server format (should be IP or IP:port)
|
|
if !strings.Contains(dnsServer, ":") {
|
|
dnsServer = dnsServer + ":53"
|
|
}
|
|
c.DNSServers = append(c.DNSServers, dnsServer)
|
|
}
|
|
|
|
case "system_hosts":
|
|
if len(parts) < 2 {
|
|
return fmt.Errorf("line %d: system_hosts requires true or false", lineNum)
|
|
}
|
|
switch parts[1] {
|
|
case "true":
|
|
c.SystemHosts = true
|
|
case "false":
|
|
c.SystemHosts = false
|
|
default:
|
|
return fmt.Errorf("line %d: system_hosts must be true or false", lineNum)
|
|
}
|
|
|
|
case "proto":
|
|
if len(parts) < 2 {
|
|
return fmt.Errorf("line %d: proto requires at least one protocol (http, socks)", lineNum)
|
|
}
|
|
// Clear existing protocols and add specified ones
|
|
c.Protocols = make(map[string]bool)
|
|
for _, proto := range parts[1:] {
|
|
switch proto {
|
|
case "http", "socks":
|
|
c.Protocols[proto] = true
|
|
default:
|
|
return fmt.Errorf("line %d: unknown protocol %s (supported: http, socks)", lineNum, proto)
|
|
}
|
|
}
|
|
|
|
default:
|
|
return fmt.Errorf("line %d: unknown directive %s", lineNum, parts[0])
|
|
}
|
|
}
|
|
|
|
return scanner.Err()
|
|
}
|
|
|
|
func (c *Config) MatchACL(host string) []string {
|
|
matched := []string{}
|
|
for name, acl := range c.ACLs {
|
|
for _, pattern := range acl.Patterns {
|
|
if pattern.MatchString(host) {
|
|
matched = append(matched, name)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
return matched
|
|
}
|
|
|
|
// Proxy serves the HTTP (including CONNECT) and SOCKS5 front ends. It applies
// the ACL, outgoing-address and DNS settings from its Config and should be
// constructed with NewProxy.
type Proxy struct {
	config *Config

	// hostname -> IP entries from the system hosts file. Populated by NewProxy
	// only when config.SystemHosts is set, and consulted first by resolveHost.
	hosts map[string]string

	// net.DefaultResolver unless config.DNSServers is non-empty, in which case
	// NewProxy installs a resolver that queries those servers.
	resolver *net.Resolver
}
|
|
|
|
// loadHostsFile parses the first hosts file that can be opened and returns a
// map from each hostname or alias to its IP address string. It tries
// /etc/hosts first, then the standard Windows location.
//
// It is best-effort: a missing or unreadable file yields an empty, non-nil
// map, and read errors part-way through keep whatever was already parsed.
//
// Text after '#' is ignored, including inline comments, so comment words are
// not mistaken for aliases. When a name appears on several lines the first
// entry wins, in line with resolvers that use the first matching line. For
// example, "127.0.0.1 localhost" is kept over a later "::1 localhost".
func loadHostsFile() map[string]string {
	hosts := make(map[string]string)

	hostsFiles := []string{"/etc/hosts", "C:\\Windows\\System32\\drivers\\etc\\hosts"}

	var file *os.File
	for _, path := range hostsFiles {
		f, err := os.Open(path)
		if err == nil {
			file = f
			break
		}
	}
	if file == nil {
		return hosts
	}
	defer file.Close()

	scanner := bufio.NewScanner(file)
	for scanner.Scan() {
		line := scanner.Text()
		if idx := strings.IndexByte(line, '#'); idx >= 0 {
			line = line[:idx]
		}

		// Format: IP hostname [aliases...]
		fields := strings.Fields(line)
		if len(fields) < 2 {
			continue
		}

		ip := fields[0]
		for _, hostname := range fields[1:] {
			if _, seen := hosts[hostname]; !seen {
				hosts[hostname] = ip
			}
		}
	}

	return hosts
}
|
|
|
|
// createCustomResolver returns a resolver that sends queries to dnsServers,
// given as dialable host:port strings. With an empty list it returns
// net.DefaultResolver.
//
// The pure-Go resolver (PreferGo) calls Dial with the network it needs. It
// normally asks for UDP and may ask for TCP, e.g. to retry a truncated
// answer. That network is honored here; forcing "udp" would break the TCP
// fallback.
//
// Servers are tried in order until one can be dialed. A UDP "dial" only fails
// on local errors (an unparsable address, for example), so with UDP the first
// well-formed server is normally the one used.
func createCustomResolver(dnsServers []string) *net.Resolver {
	if len(dnsServers) == 0 {
		return net.DefaultResolver
	}

	// Copy so later changes to the caller's slice cannot affect the resolver.
	servers := append([]string(nil), dnsServers...)

	return &net.Resolver{
		PreferGo: true,
		Dial: func(ctx context.Context, network, address string) (net.Conn, error) {
			d := net.Dialer{
				Timeout: 10 * time.Second,
			}

			var lastErr error
			for _, dnsServer := range servers {
				conn, err := d.DialContext(ctx, network, dnsServer)
				if err == nil {
					return conn, nil
				}
				lastErr = err
			}
			return nil, fmt.Errorf("failed to connect to any DNS server: %w", lastErr)
		},
	}
}
|
|
|
|
func NewProxy(config *Config) *Proxy {
|
|
proxy := &Proxy{
|
|
config: config,
|
|
hosts: make(map[string]string),
|
|
resolver: net.DefaultResolver,
|
|
}
|
|
|
|
// Load hosts file if enabled
|
|
if config.SystemHosts {
|
|
proxy.hosts = loadHostsFile()
|
|
}
|
|
|
|
// Create custom resolver if DNS servers are specified
|
|
if len(config.DNSServers) > 0 {
|
|
proxy.resolver = createCustomResolver(config.DNSServers)
|
|
}
|
|
|
|
return proxy
|
|
}
|
|
|
|
func (p *Proxy) logf(level int, format string, args ...interface{}) {
|
|
if p.config.LogLevel >= level {
|
|
log.Printf(format, args...)
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) resolveHost(ctx context.Context, hostname string) ([]string, error) {
|
|
// Check hosts file first if enabled
|
|
if p.config.SystemHosts {
|
|
if ip, ok := p.hosts[hostname]; ok {
|
|
p.logf(LogFull, "Resolved %s to %s from hosts file", hostname, ip)
|
|
return []string{ip}, nil
|
|
}
|
|
}
|
|
|
|
// Use custom resolver or default
|
|
ips, err := p.resolver.LookupHost(ctx, hostname)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(ips) > 0 && p.config.LogLevel >= LogFull {
|
|
p.logf(LogFull, "Resolved %s to %s via DNS", hostname, ips[0])
|
|
}
|
|
|
|
return ips, nil
|
|
}
|
|
|
|
func (p *Proxy) dialWithCustomDNS(ctx context.Context, dialer *net.Dialer, network, address string) (net.Conn, error) {
|
|
// Parse the address to extract hostname and port
|
|
host, port, err := net.SplitHostPort(address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Check if host is already an IP address
|
|
if net.ParseIP(host) != nil {
|
|
// Already an IP, dial directly
|
|
return dialer.DialContext(ctx, network, address)
|
|
}
|
|
|
|
// Use custom DNS resolution
|
|
ips, err := p.resolveHost(ctx, host)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to resolve %s: %w", host, err)
|
|
}
|
|
|
|
if len(ips) == 0 {
|
|
return nil, fmt.Errorf("no IP addresses found for %s", host)
|
|
}
|
|
|
|
// Try each resolved IP until one works
|
|
var lastErr error
|
|
for _, ip := range ips {
|
|
targetAddr := net.JoinHostPort(ip, port)
|
|
conn, err := dialer.DialContext(ctx, network, targetAddr)
|
|
if err == nil {
|
|
return conn, nil
|
|
}
|
|
lastErr = err
|
|
}
|
|
|
|
return nil, fmt.Errorf("failed to connect to any IP for %s: %w", host, lastErr)
|
|
}
|
|
|
|
func (p *Proxy) getDialer(host string) *net.Dialer {
|
|
dialer := &net.Dialer{
|
|
Timeout: 30 * time.Second,
|
|
KeepAlive: 30 * time.Second,
|
|
}
|
|
|
|
matchedACLs := p.config.MatchACL(host)
|
|
|
|
// Check for net_out ACL match
|
|
for _, aclName := range matchedACLs {
|
|
if netOut, ok := p.config.NetOutACLs[aclName]; ok {
|
|
// Try IPv4 first if available
|
|
if netOut.IPv4 != "" {
|
|
localAddr, err := net.ResolveTCPAddr("tcp4", netOut.IPv4+":0")
|
|
if err != nil {
|
|
p.logf(LogWarn, "Failed to resolve local IPv4 address %s: %v", netOut.IPv4, err)
|
|
} else {
|
|
dialer.LocalAddr = localAddr
|
|
p.logf(LogFull, "Using outgoing IPv4 %s for %s (ACL: %s)", netOut.IPv4, host, aclName)
|
|
return dialer
|
|
}
|
|
}
|
|
// Try IPv6 if IPv4 not available or failed
|
|
if netOut.IPv6 != "" {
|
|
localAddr, err := net.ResolveTCPAddr("tcp6", "["+netOut.IPv6+"]:0")
|
|
if err != nil {
|
|
p.logf(LogWarn, "Failed to resolve local IPv6 address %s: %v", netOut.IPv6, err)
|
|
} else {
|
|
dialer.LocalAddr = localAddr
|
|
p.logf(LogFull, "Using outgoing IPv6 %s for %s (ACL: %s)", netOut.IPv6, host, aclName)
|
|
return dialer
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Use default net_out if set (try IPv4 first, then IPv6)
|
|
if p.config.NetOutIPv4 != "" {
|
|
localAddr, err := net.ResolveTCPAddr("tcp4", p.config.NetOutIPv4+":0")
|
|
if err != nil {
|
|
p.logf(LogWarn, "Failed to resolve default local IPv4 address %s: %v", p.config.NetOutIPv4, err)
|
|
} else {
|
|
dialer.LocalAddr = localAddr
|
|
}
|
|
} else if p.config.NetOutIPv6 != "" {
|
|
localAddr, err := net.ResolveTCPAddr("tcp6", "["+p.config.NetOutIPv6+"]:0")
|
|
if err != nil {
|
|
p.logf(LogWarn, "Failed to resolve default local IPv6 address %s: %v", p.config.NetOutIPv6, err)
|
|
} else {
|
|
dialer.LocalAddr = localAddr
|
|
}
|
|
}
|
|
|
|
return dialer
|
|
}
|
|
|
|
func (p *Proxy) isRequestAllowed(host string) bool {
|
|
matchedACLs := p.config.MatchACL(host)
|
|
|
|
// Check deny ACLs first
|
|
for _, aclName := range matchedACLs {
|
|
for _, denyACL := range p.config.ReqDeny {
|
|
if aclName == denyACL {
|
|
p.logf(LogFull, "Request to %s denied (ACL: %s)", host, aclName)
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
|
|
// Check allow ACLs
|
|
if p.config.AllowAll {
|
|
return true
|
|
}
|
|
|
|
for _, aclName := range matchedACLs {
|
|
for _, allowACL := range p.config.ReqAllow {
|
|
if aclName == allowACL {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
// If no allow ACLs matched and not allow all, deny
|
|
if len(p.config.ReqAllow) > 0 {
|
|
p.logf(LogFull, "Request to %s denied (no matching allow ACL)", host)
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
func (p *Proxy) handleConnect(w http.ResponseWriter, r *http.Request) {
|
|
host := r.Host
|
|
if !strings.Contains(host, ":") {
|
|
host = host + ":443"
|
|
}
|
|
|
|
p.logf(LogFull, "CONNECT %s from %s", host, r.RemoteAddr)
|
|
|
|
// Check ACLs
|
|
hostname := strings.Split(r.Host, ":")[0]
|
|
if !p.isRequestAllowed(hostname) {
|
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Get appropriate dialer
|
|
dialer := p.getDialer(hostname)
|
|
|
|
// Connect to target using custom DNS resolution
|
|
ctx, cancel := context.WithTimeout(r.Context(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
targetConn, err := p.dialWithCustomDNS(ctx, dialer, "tcp", host)
|
|
if err != nil {
|
|
p.logf(LogSystem, "Failed to connect to %s: %v", host, err)
|
|
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
|
return
|
|
}
|
|
defer targetConn.Close()
|
|
|
|
// Log the actual local address being used
|
|
if localAddr := targetConn.LocalAddr(); localAddr != nil {
|
|
p.logf(LogFull, "Connected from local address: %s", localAddr.String())
|
|
}
|
|
|
|
// Hijack the connection
|
|
hijacker, ok := w.(http.Hijacker)
|
|
if !ok {
|
|
p.logf(LogSystem, "Hijacking not supported")
|
|
http.Error(w, "Internal Server Error", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
clientConn, _, err := hijacker.Hijack()
|
|
if err != nil {
|
|
p.logf(LogSystem, "Failed to hijack connection: %v", err)
|
|
return
|
|
}
|
|
defer clientConn.Close()
|
|
|
|
// Send 200 Connection Established
|
|
clientConn.Write([]byte("HTTP/1.1 200 Connection Established\r\n\r\n"))
|
|
|
|
// Bidirectional copy
|
|
errChan := make(chan error, 2)
|
|
|
|
go func() {
|
|
_, err := io.Copy(targetConn, clientConn)
|
|
errChan <- err
|
|
}()
|
|
|
|
go func() {
|
|
_, err := io.Copy(clientConn, targetConn)
|
|
errChan <- err
|
|
}()
|
|
|
|
// Wait for one direction to finish
|
|
<-errChan
|
|
}
|
|
|
|
func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request) {
|
|
p.logf(LogFull, "%s %s from %s", r.Method, r.URL, r.RemoteAddr)
|
|
|
|
// Check ACLs
|
|
host := r.URL.Hostname()
|
|
if !p.isRequestAllowed(host) {
|
|
http.Error(w, "Forbidden", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
// Get appropriate dialer
|
|
dialer := p.getDialer(host)
|
|
|
|
// Create transport with custom DNS resolution
|
|
transport := &http.Transport{
|
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
|
return p.dialWithCustomDNS(ctx, dialer, network, addr)
|
|
},
|
|
MaxIdleConns: 100,
|
|
IdleConnTimeout: 90 * time.Second,
|
|
TLSHandshakeTimeout: 10 * time.Second,
|
|
ExpectContinueTimeout: 1 * time.Second,
|
|
}
|
|
|
|
// Create new request
|
|
outReq := r.Clone(r.Context())
|
|
if outReq.URL.Scheme == "" {
|
|
outReq.URL.Scheme = "http"
|
|
}
|
|
if outReq.URL.Host == "" {
|
|
outReq.URL.Host = r.Host
|
|
}
|
|
|
|
// Remove hop-by-hop headers
|
|
outReq.RequestURI = ""
|
|
outReq.Header.Del("Proxy-Connection")
|
|
outReq.Header.Del("Proxy-Authenticate")
|
|
outReq.Header.Del("Proxy-Authorization")
|
|
outReq.Header.Del("Connection")
|
|
|
|
// Perform request
|
|
resp, err := transport.RoundTrip(outReq)
|
|
if err != nil {
|
|
p.logf(LogSystem, "Failed to forward request to %s: %v", r.URL, err)
|
|
http.Error(w, "Bad Gateway", http.StatusBadGateway)
|
|
return
|
|
}
|
|
defer resp.Body.Close()
|
|
|
|
// Copy response headers
|
|
for k, vv := range resp.Header {
|
|
for _, v := range vv {
|
|
w.Header().Add(k, v)
|
|
}
|
|
}
|
|
|
|
w.WriteHeader(resp.StatusCode)
|
|
io.Copy(w, resp.Body)
|
|
}
|
|
|
|
func (p *Proxy) handleSOCKS5(clientConn net.Conn) {
|
|
defer clientConn.Close()
|
|
|
|
// Read version and auth methods
|
|
buf := make([]byte, 256)
|
|
n, err := clientConn.Read(buf)
|
|
if err != nil || n < 2 {
|
|
p.logf(LogWarn, "SOCKS5: failed to read greeting: %v", err)
|
|
return
|
|
}
|
|
|
|
version := buf[0]
|
|
if version != Socks5Version {
|
|
p.logf(LogWarn, "SOCKS5: unsupported version: %d", version)
|
|
return
|
|
}
|
|
|
|
// Respond with no authentication required
|
|
_, err = clientConn.Write([]byte{Socks5Version, Socks5AuthNone})
|
|
if err != nil {
|
|
p.logf(LogWarn, "SOCKS5: failed to send auth response: %v", err)
|
|
return
|
|
}
|
|
|
|
// Read connection request
|
|
n, err = clientConn.Read(buf)
|
|
if err != nil || n < 7 {
|
|
p.logf(LogWarn, "SOCKS5: failed to read request: %v", err)
|
|
return
|
|
}
|
|
|
|
if buf[0] != Socks5Version {
|
|
p.logf(LogWarn, "SOCKS5: invalid version in request: %d", buf[0])
|
|
return
|
|
}
|
|
|
|
cmd := buf[1]
|
|
if cmd != Socks5CmdConnect {
|
|
// Only support CONNECT
|
|
reply := []byte{Socks5Version, Socks5RepCmdNotSupported, 0x00, 0x01, 0, 0, 0, 0, 0, 0}
|
|
clientConn.Write(reply)
|
|
p.logf(LogWarn, "SOCKS5: unsupported command: %d", cmd)
|
|
return
|
|
}
|
|
|
|
// Parse target address
|
|
addrType := buf[3]
|
|
var host string
|
|
var port uint16
|
|
var addrEnd int
|
|
|
|
switch addrType {
|
|
case Socks5AddrIPv4:
|
|
if n < 10 {
|
|
p.logf(LogWarn, "SOCKS5: incomplete IPv4 address")
|
|
return
|
|
}
|
|
host = net.IP(buf[4:8]).String()
|
|
port = uint16(buf[8])<<8 | uint16(buf[9])
|
|
addrEnd = 10
|
|
|
|
case Socks5AddrDomain:
|
|
if n < 5 {
|
|
p.logf(LogWarn, "SOCKS5: incomplete domain length")
|
|
return
|
|
}
|
|
domainLen := int(buf[4])
|
|
if n < 5+domainLen+2 {
|
|
p.logf(LogWarn, "SOCKS5: incomplete domain name")
|
|
return
|
|
}
|
|
host = string(buf[5 : 5+domainLen])
|
|
port = uint16(buf[5+domainLen])<<8 | uint16(buf[5+domainLen+1])
|
|
addrEnd = 5 + domainLen + 2
|
|
|
|
case Socks5AddrIPv6:
|
|
if n < 22 {
|
|
p.logf(LogWarn, "SOCKS5: incomplete IPv6 address")
|
|
return
|
|
}
|
|
host = net.IP(buf[4:20]).String()
|
|
port = uint16(buf[20])<<8 | uint16(buf[21])
|
|
addrEnd = 22
|
|
|
|
default:
|
|
reply := []byte{Socks5Version, Socks5RepAddrNotSupported, 0x00, 0x01, 0, 0, 0, 0, 0, 0}
|
|
clientConn.Write(reply)
|
|
p.logf(LogWarn, "SOCKS5: unsupported address type: %d", addrType)
|
|
return
|
|
}
|
|
|
|
target := fmt.Sprintf("%s:%d", host, port)
|
|
p.logf(LogFull, "SOCKS5 CONNECT %s from %s", target, clientConn.RemoteAddr())
|
|
|
|
// Check ACLs
|
|
if !p.isRequestAllowed(host) {
|
|
reply := []byte{Socks5Version, Socks5RepConnectionNotAllowed, 0x00, 0x01, 0, 0, 0, 0, 0, 0}
|
|
clientConn.Write(reply)
|
|
return
|
|
}
|
|
|
|
// Get appropriate dialer
|
|
dialer := p.getDialer(host)
|
|
|
|
// Connect to target using custom DNS resolution
|
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
|
defer cancel()
|
|
|
|
targetConn, err := p.dialWithCustomDNS(ctx, dialer, "tcp", target)
|
|
if err != nil {
|
|
p.logf(LogSystem, "SOCKS5: failed to connect to %s: %v", target, err)
|
|
reply := []byte{Socks5Version, Socks5RepHostUnreachable, 0x00, 0x01, 0, 0, 0, 0, 0, 0}
|
|
clientConn.Write(reply)
|
|
return
|
|
}
|
|
defer targetConn.Close()
|
|
|
|
// Log the actual local address being used
|
|
if localAddr := targetConn.LocalAddr(); localAddr != nil {
|
|
p.logf(LogFull, "SOCKS5: connected from local address: %s", localAddr.String())
|
|
}
|
|
|
|
// Send success response
|
|
// Use the original address from the request
|
|
reply := make([]byte, addrEnd)
|
|
reply[0] = Socks5Version
|
|
reply[1] = Socks5RepSuccess
|
|
reply[2] = 0x00
|
|
copy(reply[3:], buf[3:addrEnd])
|
|
_, err = clientConn.Write(reply)
|
|
if err != nil {
|
|
p.logf(LogWarn, "SOCKS5: failed to send success response: %v", err)
|
|
return
|
|
}
|
|
|
|
// Bidirectional copy
|
|
errChan := make(chan error, 2)
|
|
|
|
go func() {
|
|
_, err := io.Copy(targetConn, clientConn)
|
|
errChan <- err
|
|
}()
|
|
|
|
go func() {
|
|
_, err := io.Copy(clientConn, targetConn)
|
|
errChan <- err
|
|
}()
|
|
|
|
// Wait for one direction to finish
|
|
<-errChan
|
|
}
|
|
|
|
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
if r.Method == http.MethodConnect {
|
|
p.handleConnect(w, r)
|
|
} else {
|
|
p.handleHTTP(w, r)
|
|
}
|
|
}
|
|
|
|
func (p *Proxy) Start() error {
|
|
p.logf(LogSystem, "Starting rproxy")
|
|
|
|
// Log enabled protocols
|
|
protocols := []string{}
|
|
for proto := range p.config.Protocols {
|
|
protocols = append(protocols, proto)
|
|
}
|
|
p.logf(LogSystem, "Enabled protocols: %s", strings.Join(protocols, ", "))
|
|
|
|
if p.config.NetOutIPv4 != "" {
|
|
p.logf(LogSystem, "Default outgoing IPv4: %s", p.config.NetOutIPv4)
|
|
}
|
|
if p.config.NetOutIPv6 != "" {
|
|
p.logf(LogSystem, "Default outgoing IPv6: %s", p.config.NetOutIPv6)
|
|
}
|
|
|
|
if len(p.config.DNSServers) > 0 {
|
|
p.logf(LogSystem, "Custom DNS servers: %s", strings.Join(p.config.DNSServers, ", "))
|
|
}
|
|
|
|
if p.config.SystemHosts {
|
|
p.logf(LogSystem, "System hosts file enabled (%d entries loaded)", len(p.hosts))
|
|
}
|
|
|
|
httpEnabled := p.config.Protocols["http"]
|
|
socksEnabled := p.config.Protocols["socks"]
|
|
|
|
// Channel to collect errors from listeners
|
|
errChan := make(chan error, 2)
|
|
|
|
// Start HTTP listener if enabled
|
|
if httpEnabled {
|
|
httpAddr := fmt.Sprintf("%s:%s", p.config.Addr, p.config.HTTPPort)
|
|
p.logf(LogSystem, "Starting HTTP proxy on %s", httpAddr)
|
|
go func() {
|
|
server := &http.Server{
|
|
Addr: httpAddr,
|
|
Handler: p,
|
|
ReadTimeout: 30 * time.Second,
|
|
WriteTimeout: 30 * time.Second,
|
|
IdleTimeout: 120 * time.Second,
|
|
}
|
|
errChan <- server.ListenAndServe()
|
|
}()
|
|
}
|
|
|
|
// Start SOCKS listener if enabled
|
|
if socksEnabled {
|
|
socksAddr := fmt.Sprintf("%s:%s", p.config.Addr, p.config.SocksPort)
|
|
p.logf(LogSystem, "Starting SOCKS5 proxy on %s", socksAddr)
|
|
go func() {
|
|
listener, err := net.Listen("tcp", socksAddr)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
defer listener.Close()
|
|
|
|
for {
|
|
conn, err := listener.Accept()
|
|
if err != nil {
|
|
p.logf(LogWarn, "SOCKS5: Failed to accept connection: %v", err)
|
|
continue
|
|
}
|
|
go p.handleSOCKS5(conn)
|
|
}
|
|
}()
|
|
}
|
|
|
|
// Wait for first error (if any listener fails)
|
|
return <-errChan
|
|
}
|
|
|
|
func main() {
|
|
configPath := "/etc/rproxy/rproxy.conf"
|
|
if len(os.Args) > 1 {
|
|
configPath = os.Args[1]
|
|
}
|
|
|
|
// Expand path
|
|
if strings.HasPrefix(configPath, "~/") {
|
|
home, err := os.UserHomeDir()
|
|
if err == nil {
|
|
configPath = filepath.Join(home, configPath[2:])
|
|
}
|
|
}
|
|
|
|
config := NewConfig()
|
|
if err := config.Load(configPath); err != nil {
|
|
log.Fatalf("Failed to load config from %s: %v", configPath, err)
|
|
}
|
|
|
|
proxy := NewProxy(config)
|
|
if err := proxy.Start(); err != nil {
|
|
log.Fatalf("Proxy error: %v", err)
|
|
}
|
|
}
|