diff --git a/redis/conn.go b/redis/conn.go index 753644b..634300b 100644 --- a/redis/conn.go +++ b/redis/conn.go @@ -324,71 +324,90 @@ func DialURL(rawurl string, options ...DialOption) (Conn, error) { } // DialURLContext connects to a Redis server at the given URL using the Redis -// URI scheme. URLs should follow the draft IANA specification for the -// scheme (https://www.iana.org/assignments/uri-schemes/prov/redis). +// URI scheme. It supports: +// redis - unencrypted tcp connection +// rediss - TLS encrypted tcp connection +// redis+unix - UNIX socket connection func DialURLContext(ctx context.Context, rawurl string, options ...DialOption) (Conn, error) { u, err := url.Parse(rawurl) if err != nil { return nil, err } - - if u.Scheme != "redis" && u.Scheme != "rediss" { - return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme) - } - if u.Opaque != "" { return nil, fmt.Errorf("invalid redis URL, url is opaque: %s", rawurl) } - // As per the IANA draft spec, the host defaults to localhost and - // the port defaults to 6379. - host, port, err := net.SplitHostPort(u.Host) - if err != nil { - // assume port is missing - host = u.Host - port = "6379" - } - if host == "" { - host = "localhost" - } - address := net.JoinHostPort(host, port) - - if u.User != nil { - password, isSet := u.User.Password() - username := u.User.Username() - if isSet { - if username != "" { - // ACL - options = append(options, DialUsername(username), DialPassword(password)) - } else { - // requirepass - user-info username:password with blank username - options = append(options, DialPassword(password)) + var ( + network string + address string + db = 0 + ) + switch u.Scheme { + case "redis", "rediss": + network = "tcp" + + // As per the IANA draft spec, the host defaults to localhost and + // the port defaults to 6379. + host, port, err := net.SplitHostPort(u.Host) + if err != nil { + // assume port is missing + host = u.Host + port = "6379" + } + if host == "" { + host = "localhost" + } + address = net.JoinHostPort(host, port) + if u.User != nil { + password, isSet := u.User.Password() + username := u.User.Username() + if isSet { + if username != "" { + // ACL + options = append(options, DialUsername(username), DialPassword(password)) + } else { + // requirepass - user-info username:password with blank username + options = append(options, DialPassword(password)) + } + } else if username != "" { + // requirepass - redis-cli compatibility which treats as single arg in user-info as a password + options = append(options, DialPassword(username)) } - } else if username != "" { - // requirepass - redis-cli compatibility which treats as single arg in user-info as a password - options = append(options, DialPassword(username)) } - } - - match := pathDBRegexp.FindStringSubmatch(u.Path) - if len(match) == 2 { - db := 0 - if len(match[1]) > 0 { - db, err = strconv.Atoi(match[1]) + match := pathDBRegexp.FindStringSubmatch(u.Path) + if len(match) == 2 { + if len(match[1]) > 0 { + db, err = strconv.Atoi(match[1]) + if err != nil { + return nil, fmt.Errorf("invalid database: %s", u.Path[1:]) + } + } + if db != 0 { + options = append(options, DialDatabase(db)) + } + } else if u.Path != "" { + return nil, fmt.Errorf("invalid database: %s", u.Path[1:]) + } + options = append(options, DialUseTLS(u.Scheme == "rediss")) + case "redis+unix": + network = "unix" + address = u.Path + dbParameter := u.Query().Get("db") + if dbParameter != "" { + db, err = strconv.Atoi(dbParameter) if err != nil { return nil, fmt.Errorf("invalid database: %s", u.Path[1:]) } + if db != 0 { + options = append(options, DialDatabase(db)) + } } - if db != 0 { - options = append(options, DialDatabase(db)) - } - } else if u.Path != "" { - return nil, fmt.Errorf("invalid database: %s", u.Path[1:]) + options = append(options, DialUseTLS(false)) + default: + return nil, fmt.Errorf("invalid redis URL scheme: %s", u.Scheme) } - options = append(options, DialUseTLS(u.Scheme == "rediss")) - - return DialContext(ctx, "tcp", address, options...) + return DialContext(ctx, network, address, options...) } // NewConn returns a new Redigo connection for the given net connection. diff --git a/redis/conn_test.go b/redis/conn_test.go index b94f0f4..6392d86 100644 --- a/redis/conn_test.go +++ b/redis/conn_test.go @@ -23,10 +23,11 @@ import ( "io" "math" "net" - "sync" "os" + "path/filepath" "reflect" "strings" + "sync" "testing" "time" @@ -661,6 +662,8 @@ func TestDialURLHost(t *testing.T) { } } +var workingDirectory, _ = os.Getwd() + var dialURLTests = []struct { description string url string @@ -680,6 +683,9 @@ var dialURLTests = []struct { {"database 3", "redis://localhost/3", "+OK\r\n", "*2\r\n$6\r\nSELECT\r\n$1\r\n3\r\n"}, {"database 99", "redis://localhost/99", "+OK\r\n", "*2\r\n$6\r\nSELECT\r\n$2\r\n99\r\n"}, {"no database", "redis://localhost/", "+OK\r\n", ""}, + {"absolute socket path", "redis+unix://" + filepath.Join(workingDirectory, "server.sock"), "+OK\r\n", ""}, + {"relative socket path", "redis+unix://./server.sock", "+OK\r\n", ""}, + {"unix socket path database 99", "redis+unix://./server.sock?db=99", "+OK\r\n", "*2\r\n$6\r\nSELECT\r\n$2\r\n99\r\n"}, } func TestDialURL(t *testing.T) { diff --git a/redis/test_test.go b/redis/test_test.go index f759868..94ec0ea 100644 --- a/redis/test_test.go +++ b/redis/test_test.go @@ -21,7 +21,6 @@ import ( "flag" "fmt" "io" - "io/ioutil" "os" "os/exec" "regexp" @@ -40,10 +39,11 @@ var ( ErrNegativeInt = errNegativeInt serverPath = flag.String("redis-server", "redis-server", "Path to redis server binary") - serverAddress = flag.String("redis-address", "127.0.0.1", "The address of the server") + serverAddress = flag.String("redis-address", "127.0.0.1", "The TCP address of the server") serverBasePort = flag.Int("redis-port", 16379, "Beginning of port range for test servers") + serverSocket = flag.String("redis-socket", "./server.sock", "The UNIX socket of the server") serverLogName = flag.String("redis-log", "", "Write Redis server logs to `filename`") - serverLog = ioutil.Discard + serverLog = io.Discard defaultServerMu sync.Mutex defaultServer *Server @@ -190,6 +190,7 @@ func DefaultServerAddr() (string, error) { "default", "--port", strconv.Itoa(*serverBasePort), "--bind", *serverAddress, + "--unixsocket", *serverSocket, "--save", "", "--appendonly", "no") return addr, defaultServerErr