Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/commands/cmd_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ class CommandClient : public Commander {

Status Execute([[maybe_unused]] engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override {
if (subcommand_ == "list") {
*output = conn->VerbatimString("txt", srv->GetClientsStr());
*output = conn->VerbatimString("txt", srv->GetClientsStr(conn));
return Status::OK();
} else if (subcommand_ == "info") {
*output = conn->VerbatimString("txt", conn->ToString());
Expand Down
23 changes: 16 additions & 7 deletions src/server/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1805,16 +1805,19 @@ void Server::SlowlogPushEntryIfNeeded(const std::vector<std::string> *args, uint
slow_log_.PushEntry(std::move(entry));
}

std::string Server::GetClientsStr() {
std::string Server::GetClientsStr(const redis::Connection *conn) {
std::string clients;
for (const auto &t : worker_threads_) {
clients.append(t->GetWorker()->GetClientsStr());
clients.append(t->GetWorker()->GetClientsStr(conn));
}

std::shared_lock<std::shared_mutex> guard(slave_threads_mu_);

for (const auto &st : slave_threads_) {
clients.append(st->GetConn()->ToString());
// Slave (replication) connections live outside any tenant namespace, so
// only admin (default-namespace) callers may enumerate them.
if (conn->IsAdmin()) {
std::shared_lock<std::shared_mutex> guard(slave_threads_mu_);
for (const auto &st : slave_threads_) {
clients.append(st->GetConn()->ToString());
}
}

return clients;
Expand All @@ -1824,13 +1827,19 @@ void Server::KillClient(int64_t *killed, const std::string &addr, uint64_t id, u
redis::Connection *conn) {
*killed = 0;

// Normal clients and pubsub clients
// Normal clients and pubsub clients (per-worker filtering applies the
// namespace check for non-admin callers).
for (const auto &t : worker_threads_) {
int64_t killed_in_worker = 0;
t->GetWorker()->KillClient(conn, id, addr, type, skipme, &killed_in_worker);
*killed += killed_in_worker;
}

// Replication links (master / slave) are not tenant-owned; only admin
// callers may terminate them, otherwise a non-admin tenant could
// disrupt replication.
if (!conn->IsAdmin()) return;

// Slave clients
{
std::unique_lock<std::shared_mutex> guard(slave_threads_mu_);
Expand Down
2 changes: 1 addition & 1 deletion src/server/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ class Server {
int DecrMonitorClientNum();
int IncrBlockedClientNum();
int DecrBlockedClientNum();
std::string GetClientsStr();
std::string GetClientsStr(const redis::Connection *conn);
uint64_t GetClientID();
void KillClient(int64_t *killed, const std::string &addr, uint64_t id, uint64_t type, bool skipme,
redis::Connection *conn);
Expand Down
12 changes: 9 additions & 3 deletions src/server/worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -536,13 +536,16 @@ void Worker::FeedMonitorConns(redis::Connection *conn, const std::string &respon
}
}

std::string Worker::GetClientsStr() {
std::string Worker::GetClientsStr(const redis::Connection *conn) {
std::unique_lock<std::mutex> lock(conns_mu_);

std::string output;
for (const auto &iter : conns_) {
redis::Connection *conn = iter.second;
output.append(conn->ToString());
// Non-admin callers must only see clients in their own namespace. Admin
// (default-namespace) callers see every client. Mirrors the namespace
// filtering in Worker::FeedMonitorConns.
if (!conn->IsAdmin() && iter.second->GetNamespace() != conn->GetNamespace()) continue;
output.append(iter.second->ToString());
}

return output;
Expand All @@ -555,6 +558,9 @@ void Worker::KillClient(redis::Connection *self, uint64_t id, const std::string
for (const auto &iter : conns_) {
redis::Connection *conn = iter.second;
if (skipme && self == conn) continue;
// Non-admin callers may only target clients in their own namespace, to
// prevent cross-tenant denial of service via CLIENT KILL.
if (!self->IsAdmin() && conn->GetNamespace() != self->GetNamespace()) continue;

// no need to kill the client again if the kCloseAfterReply flag is set
if (conn->IsFlagEnabled(redis::Connection::kCloseAfterReply)) {
Expand Down
2 changes: 1 addition & 1 deletion src/server/worker.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class Worker : EventCallbackBase<Worker>, EvconnlistenerBase<Worker> {
void QuitMonitorConn(redis::Connection *conn);
void FeedMonitorConns(redis::Connection *conn, const std::string &response);

std::string GetClientsStr();
std::string GetClientsStr(const redis::Connection *conn);
void KillClient(redis::Connection *self, uint64_t id, const std::string &addr, uint64_t type, bool skipme,
int64_t *killed);
void KickoutIdleClients(int timeout);
Expand Down
226 changes: 226 additions & 0 deletions tests/gocase/unit/introspection/client_namespace_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package introspection

import (
"context"
"strings"
"testing"
"time"

"github.com/apache/kvrocks/tests/gocase/util"
"github.com/redis/go-redis/v9"
"github.com/stretchr/testify/require"
)

// defaultNS is kvrocks' internal name for the default (admin) namespace.
// The tests below match it against the `namespace=` field of CLIENT LIST /
// CLIENT INFO rows to recognise connections authenticated via requirepass.
const defaultNS = "__namespace"

// tenantConn is a TCP-level connection authenticated against a namespace
// token (or requirepass, for admin). TCP rather than go-redis is used so
// that a server-side kill is directly observable — go-redis transparently
// reconnects, masking the close.
//
// It embeds *util.TCPClient, so the raw WriteArgs / ReadLine / MustRead /
// Close helpers are available directly on a tenantConn value.
type tenantConn struct {
*util.TCPClient
}

// dial opens a raw TCP connection to srv, authenticates it with password via
// AUTH (expecting "+OK"), and returns it wrapped in a tenantConn. The
// connection is closed automatically when the calling test (or subtest)
// finishes.
func dial(t *testing.T, srv *util.KvrocksServer, password string) *tenantConn {
	t.Helper()

	client := srv.NewTCPClient()
	t.Cleanup(func() {
		// Best-effort close: the server may already have killed the link.
		_ = client.Close()
	})

	authErr := client.WriteArgs("AUTH", password)
	require.NoError(t, authErr)
	client.MustRead(t, "+OK")

	return &tenantConn{TCPClient: client}
}

// info issues CLIENT INFO on this connection and returns the value of the
// `key=` field (e.g. "id", "addr"). The test fails immediately if the field
// is absent.
func (c *tenantConn) info(t *testing.T, key string) string {
	t.Helper()
	require.NoError(t, c.WriteArgs("CLIENT", "INFO"))

	// The reply is consumed as three lines: the bulk-string header, the body
	// (which itself ends in a newline), and the trailing terminator. All three
	// must be read so the next command starts on a clean buffer.
	var frame [3]string
	for i := range frame {
		line, err := c.ReadLine()
		require.NoError(t, err)
		frame[i] = line
	}
	body := frame[1]

	prefix := key + "="
	for _, field := range strings.Fields(body) {
		if strings.HasPrefix(field, prefix) {
			return field[len(prefix):]
		}
	}

	t.Fatalf("no %s= field in CLIENT INFO: %q", key, body)
	return ""
}

// requireAlive checks that the connection is still usable: a PING must be
// written successfully and answered with "+PONG".
func (c *tenantConn) requireAlive(t *testing.T) {
	t.Helper()

	pingErr := c.WriteArgs("PING")
	require.NoError(t, pingErr)
	c.MustRead(t, "+PONG")
}

// requireKilled waits (up to 5s, polling every 100ms) for the connection to
// stop working: either writing a PING or reading its reply must fail.
func (c *tenantConn) requireKilled(t *testing.T) {
	t.Helper()

	const (
		deadline = 5 * time.Second
		interval = 100 * time.Millisecond
	)

	// connectionGone reports true as soon as a PING round-trip fails.
	connectionGone := func() bool {
		if c.WriteArgs("PING") != nil {
			return true
		}
		if _, readErr := c.ReadLine(); readErr != nil {
			return true
		}
		return false
	}

	require.Eventually(t, connectionGone, deadline, interval, "connection was expected to be killed")
}

// countNamespaceLines counts CLIENT LIST rows whose `namespace=` field equals ns.
func countNamespaceLines(list, ns string) int {
count := 0
for line := range strings.SplitSeq(list, "\n") {
if strings.Contains(line, " namespace="+ns+" ") {
count++
}
}
return count
}

// TestClientCommandNamespaceIsolation verifies that CLIENT LIST / INFO / KILL
// are scoped to the caller's namespace for non-admin (tenant) connections,
// while admin connections (authenticated via requirepass / default namespace)
// retain server-wide visibility and control.
//
// These tests cover the cross-namespace isolation bypass on CLIENT LIST /
// INFO / KILL: without filtering, a tenant authenticated against a
// non-default namespace can both enumerate and terminate connections that
// belong to other namespaces (including the admin namespace).
//
// Command results are fetched with Result() and the error is asserted, so a
// failed command can never masquerade as "0 clients killed" or as an empty
// client list.
func TestClientCommandNamespaceIsolation(t *testing.T) {
	const adminPass = "adminpass"
	srv := util.StartServer(t, map[string]string{"requirepass": adminPass})
	defer srv.Close()

	ctx := context.Background()

	admin := srv.NewClientWithOption(&redis.Options{Password: adminPass})
	defer func() { require.NoError(t, admin.Close()) }()
	require.NoError(t, admin.Do(ctx, "NAMESPACE", "ADD", "ns1", "token1").Err())
	require.NoError(t, admin.Do(ctx, "NAMESPACE", "ADD", "ns2", "token2").Err())

	t.Run("CLIENT LIST: tenant only sees its own namespace", func(t *testing.T) {
		// Background connections: two in ns1, one in ns2.
		_ = dial(t, srv, "token1")
		_ = dial(t, srv, "token1")
		_ = dial(t, srv, "token2")

		ns1 := srv.NewClientWithOption(&redis.Options{Password: "token1"})
		defer func() { require.NoError(t, ns1.Close()) }()

		list, err := ns1.ClientList(ctx).Result()
		require.NoError(t, err)
		require.NotEmpty(t, list)
		require.GreaterOrEqual(t, countNamespaceLines(list, "ns1"), 2,
			"ns1 tenant should see at least its own connections, got:\n%s", list)
		require.Equal(t, 0, countNamespaceLines(list, "ns2"),
			"ns1 tenant must not see ns2 connections, got:\n%s", list)
		require.Equal(t, 0, countNamespaceLines(list, defaultNS),
			"ns1 tenant must not see default-namespace (admin) connections, got:\n%s", list)
	})

	t.Run("CLIENT LIST: admin sees every namespace", func(t *testing.T) {
		_ = dial(t, srv, "token1")
		_ = dial(t, srv, "token2")

		list, err := admin.ClientList(ctx).Result()
		require.NoError(t, err)
		require.GreaterOrEqual(t, countNamespaceLines(list, "ns1"), 1, list)
		require.GreaterOrEqual(t, countNamespaceLines(list, "ns2"), 1, list)
		require.GreaterOrEqual(t, countNamespaceLines(list, defaultNS), 1, list)
	})

	t.Run("CLIENT INFO: only describes the caller's own connection", func(t *testing.T) {
		ns1 := srv.NewClientWithOption(&redis.Options{Password: "token1"})
		defer func() { require.NoError(t, ns1.Close()) }()

		info, err := ns1.Do(ctx, "CLIENT", "INFO").Text()
		require.NoError(t, err)
		require.Contains(t, info, " namespace=ns1 ")
		require.NotContains(t, info, " namespace="+defaultNS+" ")
		require.NotContains(t, info, " namespace=ns2 ")
	})

	t.Run("CLIENT KILL by ID: tenant cannot kill another namespace", func(t *testing.T) {
		conn2 := dial(t, srv, "token2")
		attacker := srv.NewClientWithOption(&redis.Options{Password: "token1"})
		defer func() { require.NoError(t, attacker.Close()) }()

		victimID := conn2.info(t, "id")
		killed, err := attacker.ClientKillByFilter(ctx, "id", victimID).Result()
		require.NoError(t, err)
		require.EqualValues(t, 0, killed,
			"ns1 tenant must not be able to kill a ns2 connection by ID")
		conn2.requireAlive(t)
	})

	t.Run("CLIENT KILL by ID: tenant cannot kill an admin connection", func(t *testing.T) {
		adminConn := dial(t, srv, adminPass)
		attacker := srv.NewClientWithOption(&redis.Options{Password: "token1"})
		defer func() { require.NoError(t, attacker.Close()) }()

		victimID := adminConn.info(t, "id")
		killed, err := attacker.ClientKillByFilter(ctx, "id", victimID).Result()
		require.NoError(t, err)
		require.EqualValues(t, 0, killed,
			"ns1 tenant must not be able to kill an admin/default-namespace connection")
		adminConn.requireAlive(t)
	})

	t.Run("CLIENT KILL by ADDR: tenant cannot kill another namespace", func(t *testing.T) {
		conn2 := dial(t, srv, "token2")
		attacker := srv.NewClientWithOption(&redis.Options{Password: "token1"})
		defer func() { require.NoError(t, attacker.Close()) }()

		// The legacy "CLIENT KILL <addr>" form should reply with an error
		// ("No such client") because, from ns1's perspective, the ns2
		// connection does not exist.
		err := attacker.ClientKill(ctx, conn2.info(t, "addr")).Err()
		require.Error(t, err, "ns1 tenant must not be able to kill a ns2 connection by ADDR")
		conn2.requireAlive(t)
	})

	t.Run("CLIENT KILL TYPE normal: tenant only affects its own namespace", func(t *testing.T) {
		conn1 := dial(t, srv, "token1")
		conn2 := dial(t, srv, "token2")
		adminConn := dial(t, srv, adminPass)

		attacker := srv.NewClientWithOption(&redis.Options{Password: "token1"})
		defer func() { require.NoError(t, attacker.Close()) }()

		killed, err := attacker.ClientKillByFilter(ctx, "skipme", "yes", "type", "normal").Result()
		require.NoError(t, err)
		require.GreaterOrEqual(t, killed, int64(1))

		conn1.requireKilled(t)
		conn2.requireAlive(t)
		adminConn.requireAlive(t)
	})

	t.Run("CLIENT KILL: admin retains full server-wide power", func(t *testing.T) {
		conn2 := dial(t, srv, "token2")

		victimID := conn2.info(t, "id")
		killed, err := admin.ClientKillByFilter(ctx, "id", victimID).Result()
		require.NoError(t, err)
		require.EqualValues(t, 1, killed,
			"admin must be able to kill a connection in any namespace by ID")
		conn2.requireKilled(t)
	})
}
Loading