Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(plc4j/opcua): Remove additional hostname resolution within OPC-UA #2028

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,11 +136,11 @@ public class OpcuaConfiguration implements PlcConnectionConfiguration {
private Limits limits;

@ConfigurationParameter("endpoint-host")
@Description("Endpoint host used to establish secure channel.")
@Description("Endpoint host used to establish secure channel connection. Used when client made connection to server which advertises different hostname than one used for network connection.")
private String endpointHost;

@ConfigurationParameter("endpoint-port")
@Description("Endpoint port used to establish secure channel")
@Description("Endpoint port used to establish secure channel. Used when client made connection to server which advertises different port number than one used for network connection.")
private Integer endpointPort;

public String getProtocolCode() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,13 +139,6 @@ public void setConfiguration(OpcuaConfiguration configuration) {
port = matcher.group("transportPort");
transportEndpoint = matcher.group("transportEndpoint");

if (configuration.getEndpointHost() != null) {
host = configuration.getEndpointHost();
}
if (configuration.getEndpointPort() != null) {
port = String.valueOf(configuration.getEndpointPort());
}

String portAddition = port != null ? ":" + port : "";
endpoint = "opc." + code + "://" + host + portAddition + transportEndpoint;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
import java.security.cert.CertificateEncodingException;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Map.Entry;
import java.util.Set;
import java.util.concurrent.ScheduledExecutorService;
Expand All @@ -54,8 +52,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.net.InetAddress;
import java.net.UnknownHostException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.ArrayList;
Expand Down Expand Up @@ -95,7 +91,6 @@ public class SecureChannel {
private final OpcuaDriverContext driverContext;
private final Conversation conversation;
private ScheduledFuture<?> keepAlive;
private final Set<String> endpoints = new HashSet<>();
private double sessionTimeout;
private long revisedLifetime;

Expand All @@ -118,17 +113,6 @@ public SecureChannel(Conversation conversation, RequestTransactionManager tm, Op
this.password = configuration.getPassword();
}

// Generate a list of endpoints we can use.
try {
InetAddress address = InetAddress.getByName(driverContext.getHost());
this.endpoints.add("opc.tcp://" + address.getHostAddress() + ":" + driverContext.getPort() + driverContext.getTransportEndpoint());
this.endpoints.add("opc.tcp://" + address.getHostName() + ":" + driverContext.getPort() + driverContext.getTransportEndpoint());
this.endpoints.add("opc.tcp://" + address.getCanonicalHostName() + ":" + driverContext.getPort() + driverContext.getTransportEndpoint());
} catch (UnknownHostException e) {
LOGGER.warn("Unable to resolve host name. Using original host from connection string which may cause issues connecting to server");
this.endpoints.add(driverContext.getHost());
}

if (conversation.getSecurityPolicy() == SecurityPolicy.NONE) {
this.localCertificateString = NULL_BYTE_STRING;
this.remoteCertificateThumbprint = NULL_BYTE_STRING;
Expand Down Expand Up @@ -314,23 +298,10 @@ private CompletableFuture<ActivateSessionResponse> onConnectActivateSessionReque
conversation.setRemoteCertificate(getX509Certificate(sessionResponse.getServerCertificate().getStringValue()));
conversation.setRemoteNonce(sessionResponse.getServerNonce().getStringValue());

List<String> contactPoints = new ArrayList<>(3);
String port = driverContext.getPort() == null ? "" : ":" + driverContext.getPort();
try {
InetAddress address = InetAddress.getByName(driverContext.getHost());
contactPoints.add("opc.tcp://" + address.getHostAddress() + port + driverContext.getTransportEndpoint());
contactPoints.add("opc.tcp://" + address.getHostName() + port + driverContext.getTransportEndpoint());
contactPoints.add("opc.tcp://" + address.getCanonicalHostName() + port + driverContext.getTransportEndpoint());
} catch (UnknownHostException e) {
// fall back to declared host
contactPoints.add("opc.tcp://" + driverContext.getHost() + port + driverContext.getTransportEndpoint());
LOGGER.warn("Could not reach host {}, possible network failure", driverContext.getHost(), e);
}

Entry<EndpointDescription, UserTokenPolicy> selectedEndpoint = selectEndpoint(sessionResponse.getServerEndpoints(), contactPoints,
Entry<EndpointDescription, UserTokenPolicy> selectedEndpoint = selectEndpoint(sessionResponse.getServerEndpoints(),
configuration.getSecurityPolicy(), configuration.getMessageSecurity());
if (selectedEndpoint == null) {
throw new PlcRuntimeException("Unable to find endpoint matching " + contactPoints.get(0));
throw new PlcRuntimeException("Unable to find endpoint matching " + driverContext.getEndpoint());
}

PascalString policyId = selectedEndpoint.getValue().getPolicyId();
Expand Down Expand Up @@ -421,7 +392,8 @@ public CompletableFuture<EndpointDescription> onDiscoverGetEndpointsRequest() {
);

return conversation.submit(endpointsRequest, GetEndpointsResponse.class).thenApply(response -> {
Entry<EndpointDescription, UserTokenPolicy> entry = selectEndpoint(response.getEndpoints(), this.endpoints, this.configuration.getSecurityPolicy(), this.configuration.getMessageSecurity());
Entry<EndpointDescription, UserTokenPolicy> entry = selectEndpoint(response.getEndpoints(),
this.configuration.getSecurityPolicy(), this.configuration.getMessageSecurity());

if (entry == null) {
Set<String> endpointUris = response.getEndpoints().stream()
Expand Down Expand Up @@ -494,19 +466,18 @@ private static ReadBufferByteBased toBuffer(Supplier<Payload> supplier) {
* Selects the endpoint and authentication policy based on client settings.
*
* @param extensionObjects Endpoint descriptions returned by the server.
* @param contactPoints Contact points expected by client.
* @param securityPolicy Security policy searched in endpoints.
* @param messageSecurity Message security needed by client.
* @return Endpoint matching given.
*/
private Entry<EndpointDescription, UserTokenPolicy> selectEndpoint(List<EndpointDescription> extensionObjects, Collection<String> contactPoints,
private Entry<EndpointDescription, UserTokenPolicy> selectEndpoint(List<EndpointDescription> extensionObjects,
SecurityPolicy securityPolicy, MessageSecurity messageSecurity) throws PlcRuntimeException {
// Get a list of the endpoints which match ours.
MessageSecurityMode effectiveMessageSecurity = SecurityPolicy.NONE == securityPolicy ? MessageSecurityMode.messageSecurityModeNone : messageSecurity.getMode();
List<Entry<EndpointDescription, UserTokenPolicy>> serverEndpoints = new ArrayList<>();

for (EndpointDescription endpointDescription : extensionObjects) {
if (isMatchingEndpoint(endpointDescription, contactPoints)) {
if (isMatchingEndpointDescription(endpointDescription)) {
boolean policyMatch = endpointDescription.getSecurityPolicyUri().getStringValue().equals(securityPolicy.getSecurityPolicyUri());
boolean msgSecurityMatch = endpointDescription.getSecurityMode().equals(effectiveMessageSecurity);

Expand All @@ -530,22 +501,32 @@ private Entry<EndpointDescription, UserTokenPolicy> selectEndpoint(List<Endpoint
return serverEndpoints.get(0);
}

private boolean isMatchingEndpointDescription(EndpointDescription endpointDescription) {
if (isMatchingEndpoint(endpointDescription, driverContext.getHost(), driverContext.getPort(), driverContext.getTransportEndpoint())) {
return true;
}
if (configuration.getEndpointHost() != null) {
return isMatchingEndpoint(endpointDescription, configuration.getEndpointHost(), configuration.getEndpointPort() == null ? driverContext.getPort() : String.valueOf(configuration.getEndpointPort()), driverContext.getTransportEndpoint());
} else if (configuration.getEndpointPort() != null) {
return isMatchingEndpoint(endpointDescription, driverContext.getHost(), configuration.getEndpointPort().toString(), driverContext.getTransportEndpoint());
}
return false;
}

/**
* Checks each component of the return endpoint description against the connection string.
* If all are correct then return true.
*
* @param endpoint - EndpointDescription returned from server
* @param host Permitted host
* @param port Permitted port
* @param transportEndpoint Transport endpoint
* @return true if this endpoint matches our configuration
* @throws PlcRuntimeException - If the returned endpoint string doesn't match the format expected
*/
private static boolean isMatchingEndpoint(EndpointDescription endpoint, Collection<String> contactPoints) throws PlcRuntimeException {
// Split up the connection string into it's individual segments.
for (String contactPoint : contactPoints) {
if (endpoint.getEndpointUrl().getStringValue().startsWith(contactPoint)) {
return true;
}
}
return false;
private static boolean isMatchingEndpoint(EndpointDescription endpoint, String host, String port, String transportEndpoint) throws PlcRuntimeException {
String portAddition = port == null ? "" : ":" + port;
return endpoint.getEndpointUrl().getStringValue().startsWith("opc.tcp://" + host + portAddition + transportEndpoint);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.apache.plc4x.java.DefaultPlcDriverManager;
Expand Down Expand Up @@ -500,9 +501,11 @@ public void writeVariables(SecurityPolicy policy, MessageSecurity messageSecurit
public void multipleThreads() throws Exception {
class ReadWorker extends Thread {
private final PlcConnection connection;
private final CountDownLatch latch;

public ReadWorker(PlcConnection opcuaConnection) {
public ReadWorker(PlcConnection opcuaConnection, CountDownLatch latch) {
this.connection = opcuaConnection;
this.latch = latch;
}

@Override
Expand All @@ -516,21 +519,24 @@ public void run() {
PlcReadResponse read_response = read_request.execute().get();
assertThat(read_response.getResponseCode("Bool")).isEqualTo(PlcResponseCode.OK);
}

} catch (ExecutionException e) {
LOGGER.error("run aborted", e);
} catch (InterruptedException e) {
LOGGER.error("thread interrupted", e);
Thread.currentThread().interrupt();
throw new RuntimeException(e);
} finally {
this.latch.countDown();
}
}
}

class WriteWorker extends Thread {
private final PlcConnection connection;
private final CountDownLatch latch;

public WriteWorker(PlcConnection opcuaConnection) {
public WriteWorker(PlcConnection opcuaConnection, CountDownLatch latch) {
this.connection = opcuaConnection;
this.latch = latch;
}

@Override
Expand All @@ -547,8 +553,10 @@ public void run() {
} catch (ExecutionException e) {
LOGGER.error("run aborted", e);
} catch (InterruptedException e) {
LOGGER.error("thread interrupted", e);
Thread.currentThread().interrupt();
throw new RuntimeException(e);
} finally {
this.latch.countDown();
}
}
}
Expand All @@ -558,13 +566,13 @@ public void run() {
Condition<PlcConnection> is_connected = new Condition<>(PlcConnection::isConnected, "is connected");
assertThat(opcuaConnection).is(is_connected);

ReadWorker read_worker = new ReadWorker(opcuaConnection);
WriteWorker write_worker = new WriteWorker(opcuaConnection);
CountDownLatch latch = new CountDownLatch(2);
ReadWorker read_worker = new ReadWorker(opcuaConnection, latch);
WriteWorker write_worker = new WriteWorker(opcuaConnection, latch);
read_worker.start();
write_worker.start();

read_worker.join();
write_worker.join();
latch.await();

opcuaConnection.close();
assert !opcuaConnection.isConnected();
Expand Down
Loading