Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

THRIFT-5774: Add remote client's IP address to ServerContext in TServ… #2959

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 32 additions & 3 deletions lib/java/src/crossTest/java/org/apache/thrift/test/TestServer.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.thrift.test;

import java.net.SocketAddress;
import org.apache.thrift.TMultiplexedProcessor;
import org.apache.thrift.protocol.TBinaryProtocol;
import org.apache.thrift.protocol.TCompactProtocol;
Expand Down Expand Up @@ -69,8 +70,15 @@ static class TestServerContext implements ServerContext {

int connectionId;

public TestServerContext(int connectionId) {
SocketAddress remoteSocketAddress;

SocketAddress localSocketAddress;

public TestServerContext(
int connectionId, SocketAddress remoteSocketAddress, SocketAddress localSocketAddress) {
this.connectionId = connectionId;
this.remoteSocketAddress = remoteSocketAddress;
this.localSocketAddress = localSocketAddress;
}

public int getConnectionId() {
Expand All @@ -81,6 +89,22 @@ public void setConnectionId(int connectionId) {
this.connectionId = connectionId;
}

public SocketAddress getRemoteSocketAddress() {
return remoteSocketAddress;
}

public void setRemoteSocketAddress(SocketAddress remoteSocketAddress) {
this.remoteSocketAddress = remoteSocketAddress;
}

public SocketAddress getLocalSocketAddress() {
return localSocketAddress;
}

public void setLocalSocketAddress(SocketAddress localSocketAddress) {
this.localSocketAddress = localSocketAddress;
}

@Override
public <T> T unwrap(Class<T> iface) {
try {
Expand Down Expand Up @@ -110,9 +134,14 @@ public void preServe() {
"TServerEventHandler.preServe - called only once before server starts accepting connections");
}

public ServerContext createContext(TProtocol input, TProtocol output) {
public ServerContext createContext(
TProtocol input,
TProtocol output,
SocketAddress remoteSocketAddress,
SocketAddress localSocketAddress) {
// we can create some connection level data which is stored while connection is alive & served
TestServerContext ctx = new TestServerContext(nextConnectionId++);
TestServerContext ctx =
new TestServerContext(nextConnectionId++, remoteSocketAddress, localSocketAddress);
System.out.println(
"TServerEventHandler.createContext - connection #"
+ ctx.getConnectionId()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package org.apache.thrift.server;

import java.io.IOException;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.Selector;
Expand All @@ -31,6 +32,7 @@
import org.apache.thrift.TByteArrayOutputStream;
import org.apache.thrift.TException;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.SocketAddressProvider;
import org.apache.thrift.transport.TIOStreamTransport;
import org.apache.thrift.transport.TMemoryInputTransport;
import org.apache.thrift.transport.TNonblockingServerTransport;
Expand Down Expand Up @@ -296,7 +298,15 @@ public FrameBuffer(
outProt_ = outputProtocolFactory_.getProtocol(outTrans_);

if (eventHandler_ != null) {
context_ = eventHandler_.createContext(inProt_, outProt_);
SocketAddress remoteSocketAddress = null;
SocketAddress localSocketAddress = null;
if (trans_ instanceof SocketAddressProvider) {
SocketAddressProvider socketAddressProvider = (SocketAddressProvider) trans_;
localSocketAddress = socketAddressProvider.getLocalSocketAddress();
remoteSocketAddress = socketAddressProvider.getRemoteSocketAddress();
}
context_ =
eventHandler_.createContext(inProt_, outProt_, remoteSocketAddress, localSocketAddress);
} else {
context_ = null;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.thrift.server;

import java.net.SocketAddress;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.TTransport;

Expand All @@ -37,8 +38,15 @@ public interface TServerEventHandler {
/** Called before the server begins. */
void preServe();

/** Called when a new client has connected and is about to being processing. */
ServerContext createContext(TProtocol input, TProtocol output);
/**
* Called when a new client has connected and is about to being processing. The
* remoteSocketAddress and localSocketAddress are null when transport is not socket based
*/
ServerContext createContext(
TProtocol input,
TProtocol output,
SocketAddress remoteSocketAddress,
SocketAddress localSocketAddress);

/** Called when a client has finished request-handling to delete server context. */
void deleteContext(ServerContext serverContext, TProtocol input, TProtocol output);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@

package org.apache.thrift.server;

import java.net.SocketAddress;
import org.apache.thrift.TException;
import org.apache.thrift.TProcessor;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.SocketAddressProvider;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.slf4j.Logger;
Expand Down Expand Up @@ -69,7 +71,16 @@ public void serve() {
inputProtocol = inputProtocolFactory_.getProtocol(inputTransport);
outputProtocol = outputProtocolFactory_.getProtocol(outputTransport);
if (eventHandler_ != null) {
connectionContext = eventHandler_.createContext(inputProtocol, outputProtocol);
SocketAddress remoteSocketAddress = null;
SocketAddress localSocketAddress = null;
if (client instanceof SocketAddressProvider) {
SocketAddressProvider socketAddressProvider = (SocketAddressProvider) client;
remoteSocketAddress = socketAddressProvider.getRemoteSocketAddress();
localSocketAddress = socketAddressProvider.getLocalSocketAddress();
}
connectionContext =
eventHandler_.createContext(
inputProtocol, outputProtocol, remoteSocketAddress, localSocketAddress);
}
while (true) {
if (eventHandler_ != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.thrift.server;

import java.net.SocketAddress;
import java.net.SocketException;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
Expand All @@ -31,6 +32,7 @@
import org.apache.thrift.TException;
import org.apache.thrift.TProcessor;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.transport.SocketAddressProvider;
import org.apache.thrift.transport.TServerTransport;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
Expand Down Expand Up @@ -239,7 +241,16 @@ public void run() {
eventHandler = Optional.ofNullable(getEventHandler());

if (eventHandler.isPresent()) {
connectionContext = eventHandler.get().createContext(inputProtocol, outputProtocol);
SocketAddress remoteSocketAddress = null;
SocketAddress localSocketAddress = null;
if (client_ instanceof SocketAddressProvider) {
SocketAddressProvider socketAddressProvider = (SocketAddressProvider) client_;
remoteSocketAddress = socketAddressProvider.getRemoteSocketAddress();
localSocketAddress = socketAddressProvider.getLocalSocketAddress();
}
connectionContext =
eventHandler_.createContext(
inputProtocol, outputProtocol, remoteSocketAddress, localSocketAddress);
}

while (true) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.thrift.transport;

import java.net.SocketAddress;

/** Interface that can retrieve the socket address. */
public interface SocketAddressProvider {

SocketAddress getRemoteSocketAddress();

SocketAddress getLocalSocketAddress();
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import org.slf4j.LoggerFactory;

/** Transport for use with async client. */
public class TNonblockingSocket extends TNonblockingTransport {
public class TNonblockingSocket extends TNonblockingTransport implements SocketAddressProvider {

private static final Logger LOGGER = LoggerFactory.getLogger(TNonblockingSocket.class.getName());

Expand Down Expand Up @@ -205,4 +205,14 @@ public String toString() {
+ socketChannel_.socket().getLocalAddress()
+ "]";
}

@Override
public SocketAddress getRemoteSocketAddress() {
return socketChannel_.socket().getRemoteSocketAddress();
}

@Override
public SocketAddress getLocalSocketAddress() {
return socketChannel_.socket().getLocalSocketAddress();
}
}
13 changes: 12 additions & 1 deletion lib/java/src/main/java/org/apache/thrift/transport/TSocket.java
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import org.apache.thrift.TConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Socket implementation of the TTransport interface. To be commented soon! */
public class TSocket extends TIOStreamTransport {
public class TSocket extends TIOStreamTransport implements SocketAddressProvider {

private static final Logger LOGGER = LoggerFactory.getLogger(TSocket.class.getName());

Expand Down Expand Up @@ -247,4 +248,14 @@ public void close() {
socket_ = null;
}
}

@Override
public SocketAddress getRemoteSocketAddress() {
return socket_.getRemoteSocketAddress();
}

@Override
public SocketAddress getLocalSocketAddress() {
return socket_.getLocalSocketAddress();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import static org.apache.thrift.transport.sasl.NegotiationStatus.COMPLETE;
import static org.apache.thrift.transport.sasl.NegotiationStatus.OK;

import java.net.SocketAddress;
import java.nio.channels.SelectionKey;
import java.nio.charset.StandardCharsets;
import javax.security.sasl.SaslServer;
Expand All @@ -31,6 +32,7 @@
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.server.ServerContext;
import org.apache.thrift.server.TServerEventHandler;
import org.apache.thrift.transport.SocketAddressProvider;
import org.apache.thrift.transport.TMemoryTransport;
import org.apache.thrift.transport.TNonblockingTransport;
import org.apache.thrift.transport.TTransportException;
Expand Down Expand Up @@ -324,7 +326,17 @@ private void executeProcessing() {

if (eventHandler != null) {
if (!serverContextCreated) {
serverContext = eventHandler.createContext(requestProtocol, responseProtocol);
SocketAddress remoteSocketAddress = null;
SocketAddress localSocketAddress = null;
if (underlyingTransport instanceof SocketAddressProvider) {
SocketAddressProvider socketAddressProvider =
(SocketAddressProvider) underlyingTransport;
remoteSocketAddress = socketAddressProvider.getRemoteSocketAddress();
localSocketAddress = socketAddressProvider.getLocalSocketAddress();
}
serverContext =
eventHandler.createContext(
requestProtocol, responseProtocol, remoteSocketAddress, localSocketAddress);
serverContextCreated = true;
}
eventHandler.processContext(serverContext, memoryTransport, memoryTransport);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import com.github.ajalt.clikt.parameters.options.option
import com.github.ajalt.clikt.parameters.types.enum
import com.github.ajalt.clikt.parameters.types.int
import com.github.ajalt.clikt.parameters.types.long
import java.net.SocketAddress
import kotlinx.coroutines.GlobalScope
import org.apache.thrift.TException
import org.apache.thrift.TMultiplexedProcessor
Expand Down Expand Up @@ -73,7 +74,11 @@ object TestServer {
}
}

internal class TestServerContext(var connectionId: Int) : ServerContext {
internal class TestServerContext(
var connectionId: Int,
var remoteSocketAddress: SocketAddress,
var localSocketAddress: SocketAddress
) : ServerContext {

override fun <T> unwrap(iface: Class<T>): T {
try {
Expand Down Expand Up @@ -102,10 +107,15 @@ object TestServer {
)
}

override fun createContext(input: TProtocol, output: TProtocol): ServerContext {
override fun createContext(
input: TProtocol,
output: TProtocol,
remoteSocketAddress: SocketAddress,
localSocketAddress: SocketAddress
): ServerContext {
// we can create some connection level data which is stored while connection is alive &
// served
val ctx = TestServerContext(nextConnectionId++)
val ctx = TestServerContext(nextConnectionId++, remoteSocketAddress, localSocketAddress)
println(
"TServerEventHandler.createContext - connection #" +
ctx.connectionId +
Expand Down