You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
304 lines
8.6 KiB
304 lines
8.6 KiB
/*
|
|
* Copyright (C) 2017 The Android Open Source Project
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#include "chre_host/socket_server.h"
|
|
|
|
#include <poll.h>
|
|
|
|
#include <cassert>
|
|
#include <cinttypes>
|
|
#include <csignal>
|
|
#include <cstdlib>
|
|
#include <map>
|
|
#include <mutex>
|
|
|
|
#include <cutils/sockets.h>
|
|
|
|
#include "chre_host/log.h"
|
|
|
|
namespace android {
|
|
namespace chre {
|
|
|
|
// Flag raised by signalHandler() and checked by the serviceSocket() loop
// condition to decide when to stop polling.
std::atomic<bool> SocketServer::sSignalReceived{false};
|
|
|
|
namespace {
|
|
|
|
void maskAllSignals() {
|
|
sigset_t signalMask;
|
|
sigfillset(&signalMask);
|
|
if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
|
|
LOG_ERROR("Couldn't mask all signals", errno);
|
|
}
|
|
}
|
|
|
|
void maskAllSignalsExceptIntAndTerm() {
|
|
sigset_t signalMask;
|
|
sigfillset(&signalMask);
|
|
sigdelset(&signalMask, SIGINT);
|
|
sigdelset(&signalMask, SIGTERM);
|
|
if (sigprocmask(SIG_SETMASK, &signalMask, NULL) != 0) {
|
|
LOG_ERROR("Couldn't mask all signals except INT/TERM", errno);
|
|
}
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
/**
 * Marks every client slot of the poll set as unused.
 *
 * Client slots live at indices 1..kMaxActiveClients; index 0 is filled in with
 * the listening socket by serviceSocket().
 */
SocketServer::SocketServer() {
  // A negative fd makes poll skip the entry, and also keeps us from trying to
  // send on a slot that has no client attached.
  for (size_t slot = 1; slot <= kMaxActiveClients; ++slot) {
    auto &entry = mPollFds[slot];
    entry.fd = -1;
    entry.events = POLLIN;
  }
}
|
|
|
|
void SocketServer::run(const char *socketName, bool allowSocketCreation,
|
|
ClientMessageCallback clientMessageCallback) {
|
|
mClientMessageCallback = clientMessageCallback;
|
|
|
|
mSockFd = android_get_control_socket(socketName);
|
|
if (mSockFd == INVALID_SOCKET && allowSocketCreation) {
|
|
LOGI("Didn't inherit socket, creating...");
|
|
mSockFd = socket_local_server(socketName, ANDROID_SOCKET_NAMESPACE_RESERVED,
|
|
SOCK_SEQPACKET);
|
|
}
|
|
|
|
if (mSockFd == INVALID_SOCKET) {
|
|
LOGE("Couldn't get/create socket");
|
|
} else {
|
|
int ret = listen(mSockFd, kMaxPendingConnectionRequests);
|
|
if (ret < 0) {
|
|
LOG_ERROR("Couldn't listen on socket", errno);
|
|
} else {
|
|
serviceSocket();
|
|
}
|
|
|
|
{
|
|
std::lock_guard<std::mutex> lock(mClientsMutex);
|
|
for (const auto &pair : mClients) {
|
|
int clientSocket = pair.first;
|
|
if (close(clientSocket) != 0) {
|
|
LOGI("Couldn't close client %" PRIu16 "'s socket: %s",
|
|
pair.second.clientId, strerror(errno));
|
|
}
|
|
}
|
|
mClients.clear();
|
|
}
|
|
close(mSockFd);
|
|
}
|
|
}
|
|
|
|
void SocketServer::sendToAllClients(const void *data, size_t length) {
|
|
std::lock_guard<std::mutex> lock(mClientsMutex);
|
|
|
|
int deliveredCount = 0;
|
|
for (const auto &pair : mClients) {
|
|
int clientSocket = pair.first;
|
|
uint16_t clientId = pair.second.clientId;
|
|
if (sendToClientSocket(data, length, clientSocket, clientId)) {
|
|
deliveredCount++;
|
|
} else if (errno == EINTR) {
|
|
// Exit early if we were interrupted - we should only get this for
|
|
// SIGINT/SIGTERM, so we should exit quickly
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (deliveredCount == 0) {
|
|
LOGW("Got message but didn't deliver to any clients");
|
|
}
|
|
}
|
|
|
|
bool SocketServer::sendToClientById(const void *data, size_t length,
|
|
uint16_t clientId) {
|
|
std::lock_guard<std::mutex> lock(mClientsMutex);
|
|
|
|
bool sent = false;
|
|
for (const auto &pair : mClients) {
|
|
uint16_t thisClientId = pair.second.clientId;
|
|
if (thisClientId == clientId) {
|
|
int clientSocket = pair.first;
|
|
sent = sendToClientSocket(data, length, clientSocket, thisClientId);
|
|
break;
|
|
}
|
|
}
|
|
|
|
return sent;
|
|
}
|
|
|
|
void SocketServer::acceptClientConnection() {
|
|
int clientSocket = accept(mSockFd, NULL, NULL);
|
|
if (clientSocket < 0) {
|
|
LOG_ERROR("Couldn't accept client connection", errno);
|
|
} else if (mClients.size() >= kMaxActiveClients) {
|
|
LOGW("Rejecting client request - maximum number of clients reached");
|
|
close(clientSocket);
|
|
} else {
|
|
ClientData clientData;
|
|
clientData.clientId = mNextClientId++;
|
|
|
|
// We currently don't handle wraparound - if we're getting this many
|
|
// connects/disconnects, then something is wrong.
|
|
// TODO: can handle this properly by iterating over the existing clients to
|
|
// avoid a conflict.
|
|
if (clientData.clientId == 0) {
|
|
LOGE("Couldn't allocate client ID");
|
|
std::exit(-1);
|
|
}
|
|
|
|
bool slotFound = false;
|
|
for (size_t i = 1; i <= kMaxActiveClients; i++) {
|
|
if (mPollFds[i].fd < 0) {
|
|
mPollFds[i].fd = clientSocket;
|
|
slotFound = true;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (!slotFound) {
|
|
LOGE("Couldn't find slot for client!");
|
|
assert(slotFound);
|
|
close(clientSocket);
|
|
} else {
|
|
{
|
|
std::lock_guard<std::mutex> lock(mClientsMutex);
|
|
mClients[clientSocket] = clientData;
|
|
}
|
|
LOGI(
|
|
"Accepted new client connection (count %zu), assigned client ID "
|
|
"%" PRIu16,
|
|
mClients.size(), clientData.clientId);
|
|
}
|
|
}
|
|
}
|
|
|
|
void SocketServer::handleClientData(int clientSocket) {
|
|
const ClientData &clientData = mClients[clientSocket];
|
|
uint16_t clientId = clientData.clientId;
|
|
|
|
ssize_t packetSize =
|
|
recv(clientSocket, mRecvBuffer.data(), mRecvBuffer.size(), MSG_DONTWAIT);
|
|
if (packetSize < 0) {
|
|
LOGE("Couldn't get packet from client %" PRIu16 ": %s", clientId,
|
|
strerror(errno));
|
|
} else if (packetSize == 0) {
|
|
LOGI("Client %" PRIu16 " disconnected", clientId);
|
|
disconnectClient(clientSocket);
|
|
} else {
|
|
LOGV("Got %zd byte packet from client %" PRIu16, packetSize, clientId);
|
|
mClientMessageCallback(clientId, mRecvBuffer.data(), packetSize);
|
|
}
|
|
}
|
|
|
|
void SocketServer::disconnectClient(int clientSocket) {
|
|
{
|
|
std::lock_guard<std::mutex> lock(mClientsMutex);
|
|
mClients.erase(clientSocket);
|
|
}
|
|
close(clientSocket);
|
|
|
|
bool removed = false;
|
|
for (size_t i = 1; i <= kMaxActiveClients; i++) {
|
|
if (mPollFds[i].fd == clientSocket) {
|
|
mPollFds[i].fd = -1;
|
|
removed = true;
|
|
break;
|
|
}
|
|
}
|
|
|
|
if (!removed) {
|
|
LOGE("Out of sync");
|
|
assert(removed);
|
|
}
|
|
}
|
|
|
|
bool SocketServer::sendToClientSocket(const void *data, size_t length,
|
|
int clientSocket, uint16_t clientId) {
|
|
errno = 0;
|
|
ssize_t bytesSent = send(clientSocket, data, length, 0);
|
|
if (bytesSent < 0) {
|
|
LOGE("Error sending packet of size %zu to client %" PRIu16 ": %s", length,
|
|
clientId, strerror(errno));
|
|
} else if (bytesSent == 0) {
|
|
LOGW("Client %" PRIu16 " disconnected before message could be delivered",
|
|
clientId);
|
|
} else {
|
|
LOGV("Delivered message of size %zu bytes to client %" PRIu16, length,
|
|
clientId);
|
|
}
|
|
|
|
return (bytesSent > 0);
|
|
}
|
|
|
|
void SocketServer::serviceSocket() {
|
|
constexpr size_t kListenIndex = 0;
|
|
static_assert(kListenIndex == 0,
|
|
"Code assumes that the first index is always the listen "
|
|
"socket");
|
|
|
|
mPollFds[kListenIndex].fd = mSockFd;
|
|
mPollFds[kListenIndex].events = POLLIN;
|
|
|
|
// Signal mask used with ppoll() so we gracefully handle SIGINT and SIGTERM,
|
|
// and ignore other signals
|
|
sigset_t signalMask;
|
|
sigfillset(&signalMask);
|
|
sigdelset(&signalMask, SIGINT);
|
|
sigdelset(&signalMask, SIGTERM);
|
|
|
|
// Masking signals here ensure that after this point, we won't handle INT/TERM
|
|
// until after we call into ppoll()
|
|
maskAllSignals();
|
|
std::signal(SIGINT, signalHandler);
|
|
std::signal(SIGTERM, signalHandler);
|
|
|
|
LOGI("Ready to accept connections");
|
|
while (!sSignalReceived) {
|
|
int ret = ppoll(mPollFds, 1 + kMaxActiveClients, nullptr, &signalMask);
|
|
maskAllSignalsExceptIntAndTerm();
|
|
if (ret == -1) {
|
|
LOGI("Exiting poll loop: %s", strerror(errno));
|
|
break;
|
|
}
|
|
|
|
if (mPollFds[kListenIndex].revents & POLLIN) {
|
|
acceptClientConnection();
|
|
}
|
|
|
|
for (size_t i = 1; i <= kMaxActiveClients; i++) {
|
|
if (mPollFds[i].fd < 0) {
|
|
continue;
|
|
}
|
|
|
|
if (mPollFds[i].revents & POLLIN) {
|
|
handleClientData(mPollFds[i].fd);
|
|
}
|
|
}
|
|
|
|
// Mask all signals to ensure that sSignalReceived can't become true between
|
|
// checking it in the while condition and calling into ppoll()
|
|
maskAllSignals();
|
|
}
|
|
}
|
|
|
|
/**
 * Handler installed for SIGINT and SIGTERM by serviceSocket(): logs the signal
 * number and raises the flag that ends the poll loop.
 *
 * @param signal Number of the signal that was caught
 */
void SocketServer::signalHandler(int signal) {
  // NOTE(review): logging from a signal handler is generally not
  // async-signal-safe - confirm LOGD is acceptable in this context.
  LOGD("Caught signal %d", signal);
  sSignalReceived.store(true);
}
|
|
|
|
} // namespace chre
|
|
} // namespace android
|