Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions gpt4all-chat/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).

## [Unreleased]

### Fixed
- Make the the local server resistant to DNS rebind attacks ([#3587](https://github.com/nomic-ai/gpt4all/pull/3587))

## [3.10.0] - 2025-02-24

### Added
Expand Down Expand Up @@ -312,6 +317,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/).
- Fix several Vulkan resource management issues ([#2694](https://github.com/nomic-ai/gpt4all/pull/2694))
- Fix crash/hang when some models stop generating, by showing special tokens ([#2701](https://github.com/nomic-ai/gpt4all/pull/2701))

[Unreleased]: https://github.com/nomic-ai/gpt4all/compare/v3.10.0...HEAD
[3.10.0]: https://github.com/nomic-ai/gpt4all/compare/v3.9.0...v3.10.0
[3.9.0]: https://github.com/nomic-ai/gpt4all/compare/v3.8.0...v3.9.0
[3.8.0]: https://github.com/nomic-ai/gpt4all/compare/v3.7.0...v3.8.0
Expand Down
1 change: 1 addition & 0 deletions gpt4all-chat/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ qt_add_executable(chat
src/localdocsmodel.cpp src/localdocsmodel.h
src/logger.cpp src/logger.h
src/modellist.cpp src/modellist.h
src/mwhttpserver.cpp src/mwhttpserver.h
src/mysettings.cpp src/mysettings.h
src/network.cpp src/network.h
src/server.cpp src/server.h
Expand Down
15 changes: 15 additions & 0 deletions gpt4all-chat/src/mwhttpserver.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include <QTcpServer>

#include "mwhttpserver.h"


namespace gpt4all::ui {


MwHttpServer::MwHttpServer()
: m_httpServer()
, m_tcpServer (new QTcpServer(&m_httpServer))
{}


} // namespace gpt4all::ui
60 changes: 60 additions & 0 deletions gpt4all-chat/src/mwhttpserver.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#pragma once

#include <QHttpServer>
#include <QHttpServerRequest>

#include <functional>
#include <optional>
#include <utility>
#include <vector>

class QHttpServerResponse;
class QHttpServerRouterRule;
class QString;


namespace gpt4all::ui {


/// @brief QHttpServer wrapper with middleware support.
///
/// This class wraps QHttpServer and provides addBeforeRequestHandler() to add middleware.
class MwHttpServer
{
using BeforeRequestHandler = std::function<std::optional<QHttpServerResponse>(const QHttpServerRequest &)>;

public:
explicit MwHttpServer();

bool bind() { return m_httpServer.bind(m_tcpServer); }

void addBeforeRequestHandler(BeforeRequestHandler handler)
{ m_beforeRequestHandlers.push_back(std::move(handler)); }

template <typename Handler>
void addAfterRequestHandler(
const typename QtPrivate::ContextTypeForFunctor<Handler>::ContextType *context, Handler &&handler
) {
return m_httpServer.addAfterRequestHandler(context, std::forward<Handler>(handler));
}

template <typename... Args>
QHttpServerRouterRule *route(
const QString &pathPattern,
QHttpServerRequest::Methods method,
std::function<QHttpServerResponse(Args..., const QHttpServerRequest &)> viewHandler
);

QTcpServer *tcpServer() { return m_tcpServer; }

private:
QHttpServer m_httpServer;
QTcpServer *m_tcpServer;
std::vector<BeforeRequestHandler> m_beforeRequestHandlers;
};


} // namespace gpt4all::ui


#include "mwhttpserver.inl" // IWYU pragma: export
20 changes: 20 additions & 0 deletions gpt4all-chat/src/mwhttpserver.inl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
namespace gpt4all::ui {


template <typename... Args>
QHttpServerRouterRule *MwHttpServer::route(
const QString &pathPattern,
QHttpServerRequest::Methods method,
std::function<QHttpServerResponse(Args..., const QHttpServerRequest &)> viewHandler
) {
auto wrapped = [this, vh = std::move(viewHandler)](Args ...args, const QHttpServerRequest &req) {
for (auto &handler : m_beforeRequestHandlers)
if (auto resp = handler(req))
return *std::move(resp);
return vh(std::forward<Args>(args)..., req);
};
return m_httpServer.route(pathPattern, method, std::move(wrapped));
}


} // namespace gpt4all::ui
68 changes: 58 additions & 10 deletions gpt4all-chat/src/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
#include "chat.h"
#include "chatmodel.h"
#include "modellist.h"
#include "mwhttpserver.h"
#include "mysettings.h"
#include "utils.h" // IWYU pragma: keep

#include <fmt/format.h>
#include <gpt4all-backend/llmodel.h>

#include <QAbstractSocket>
#include <QByteArray>
#include <QCborArray>
#include <QCborMap>
Expand Down Expand Up @@ -51,6 +53,7 @@

using namespace std::string_literals;
using namespace Qt::Literals::StringLiterals;
using namespace gpt4all::ui;

//#define DEBUG

Expand Down Expand Up @@ -443,6 +446,8 @@ Server::Server(Chat *chat)
connect(chat, &Chat::collectionListChanged, this, &Server::handleCollectionListChanged, Qt::QueuedConnection);
}

Server::~Server() = default;

static QJsonObject requestFromJson(const QByteArray &request)
{
QJsonParseError err;
Expand All @@ -455,17 +460,60 @@ static QJsonObject requestFromJson(const QByteArray &request)
return document.object();
}

/// @brief Check if a host is safe to use to connect to the server.
///
/// GPT4All's local server is not safe to expose to the internet, as it does not provide
/// any form of authentication. DNS rebind attacks bypass CORS and without additional host
/// header validation, malicious websites can access the server in client-side js.
///
/// @param host The value of the "Host" header or ":authority" pseudo-header
/// @return true if the host is unsafe, false otherwise
static bool isHostUnsafe(const QString &host)
{
QHostAddress addr;
if (addr.setAddress(host) && addr.protocol() == QAbstractSocket::IPv4Protocol)
return false; // ipv4

// ipv6 host is wrapped in square brackets
static const QRegularExpression ipv6Re(uR"(^\[(.+)\]$)"_s);
if (auto match = ipv6Re.match(host); match.hasMatch()) {
auto ipv6 = match.captured(1);
if (addr.setAddress(ipv6) && addr.protocol() == QAbstractSocket::IPv6Protocol)
return false; // ipv6
}

if (!host.contains('.'))
return false; // dotless hostname

static const QStringList allowedTlds { u".local"_s, u".test"_s, u".internal"_s };
for (auto &tld : allowedTlds)
if (host.endsWith(tld, Qt::CaseInsensitive))
return false; // local TLD

return true; // unsafe
}

void Server::start()
{
m_server = std::make_unique<QHttpServer>(this);
auto *tcpServer = new QTcpServer(m_server.get());
m_server = std::make_unique<MwHttpServer>();

m_server->addBeforeRequestHandler([](const QHttpServerRequest &req) -> std::optional<QHttpServerResponse> {
// this works for HTTP/1.1 "Host" header and HTTP/2 ":authority" pseudo-header
auto host = req.url().host();
if (!host.isEmpty() && isHostUnsafe(host))
return QHttpServerResponse(
QJsonObject { { u"error"_s, u"Access to the server via non-local host %1 is forbidden."_s.arg(host) } },
QHttpServerResponder::StatusCode::Forbidden
);
return std::nullopt;
});

auto port = MySettings::globalInstance()->networkPort();
if (!tcpServer->listen(QHostAddress::LocalHost, port)) {
if (!m_server->tcpServer()->listen(QHostAddress::LocalHost, port)) {
qWarning() << "Server ERROR: Failed to listen on port" << port;
return;
}
if (!m_server->bind(tcpServer)) {
if (!m_server->bind()) {
qWarning() << "Server ERROR: Failed to HTTP server to socket" << port;
return;
}
Expand All @@ -490,7 +538,7 @@ void Server::start()
}
);

m_server->route("/v1/models/<arg>", QHttpServerRequest::Method::Get,
m_server->route<const QString &>("/v1/models/<arg>", QHttpServerRequest::Method::Get,
[](const QString &model, const QHttpServerRequest &) {
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
Expand Down Expand Up @@ -562,7 +610,7 @@ void Server::start()

// Respond with code 405 to wrong HTTP methods:
m_server->route("/v1/models", QHttpServerRequest::Method::Post,
[] {
[](const QHttpServerRequest &) {
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
return QHttpServerResponse(
Expand All @@ -573,8 +621,8 @@ void Server::start()
}
);

m_server->route("/v1/models/<arg>", QHttpServerRequest::Method::Post,
[](const QString &model) {
m_server->route<const QString &>("/v1/models/<arg>", QHttpServerRequest::Method::Post,
[](const QString &model, const QHttpServerRequest &) {
(void)model;
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
Expand All @@ -587,7 +635,7 @@ void Server::start()
);

m_server->route("/v1/completions", QHttpServerRequest::Method::Get,
[] {
[](const QHttpServerRequest &) {
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
return QHttpServerResponse(
Expand All @@ -598,7 +646,7 @@ void Server::start()
);

m_server->route("/v1/chat/completions", QHttpServerRequest::Method::Get,
[] {
[](const QHttpServerRequest &) {
if (!MySettings::globalInstance()->serverChat())
return QHttpServerResponse(QHttpServerResponder::StatusCode::Unauthorized);
return QHttpServerResponse(
Expand Down
6 changes: 3 additions & 3 deletions gpt4all-chat/src/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "chatllm.h"
#include "database.h"

#include <QHttpServer>
#include <QHttpServerResponse>
#include <QJsonObject>
#include <QList>
Expand All @@ -18,6 +17,7 @@
class Chat;
class ChatRequest;
class CompletionRequest;
namespace gpt4all::ui { class MwHttpServer; }


class Server : public ChatLLM
Expand All @@ -26,7 +26,7 @@ class Server : public ChatLLM

public:
explicit Server(Chat *chat);
~Server() override = default;
~Server() override;

public Q_SLOTS:
void start();
Expand All @@ -44,7 +44,7 @@ private Q_SLOTS:

private:
Chat *m_chat;
std::unique_ptr<QHttpServer> m_server;
std::unique_ptr<gpt4all::ui::MwHttpServer> m_server;
QList<ResultInfo> m_databaseResults;
QList<QString> m_collections;
};
Expand Down