You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
185 lines
5.8 KiB
185 lines
5.8 KiB
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
|
|
|
|
#include "utils/regex-match.h"
|
|
|
|
#include <memory>
|
|
|
|
#include "annotator/types.h"
|
|
|
|
#ifndef TC3_DISABLE_LUA
|
|
#include "utils/lua-utils.h"
|
|
#ifdef __cplusplus
|
|
extern "C" {
|
|
#endif
|
|
#include "lauxlib.h"
|
|
#include "lualib.h"
|
|
#ifdef __cplusplus
|
|
}
|
|
#endif
|
|
#endif
|
|
|
|
namespace libtextclassifier3 {
|
|
namespace {
|
|
|
|
#ifndef TC3_DISABLE_LUA
|
|
// Provide a lua environment for running regex match post verification.
|
|
// It sets up and exposes the match data as well as the context.
|
|
class LuaVerifier : public LuaEnvironment {
|
|
public:
|
|
static std::unique_ptr<LuaVerifier> Create(
|
|
const std::string& context, const std::string& verifier_code,
|
|
const UniLib::RegexMatcher* matcher);
|
|
|
|
bool Verify(bool* result);
|
|
|
|
private:
|
|
explicit LuaVerifier(const std::string& context,
|
|
const std::string& verifier_code,
|
|
const UniLib::RegexMatcher* matcher)
|
|
: context_(context), verifier_code_(verifier_code), matcher_(matcher) {}
|
|
bool Initialize();
|
|
|
|
// Provides details of a capturing group to lua.
|
|
int GetCapturingGroup();
|
|
|
|
const std::string& context_;
|
|
const std::string& verifier_code_;
|
|
const UniLib::RegexMatcher* matcher_;
|
|
};
|
|
|
|
bool LuaVerifier::Initialize() {
|
|
// Run protected to not lua panic in case of setup failure.
|
|
return RunProtected([this] {
|
|
LoadDefaultLibraries();
|
|
|
|
// Expose context of the match as `context` global variable.
|
|
PushString(context_);
|
|
lua_setglobal(state_, "context");
|
|
|
|
// Expose match array as `match` global variable.
|
|
// Each entry `match[i]` exposes the ith capturing group as:
|
|
// * `begin`: span start
|
|
// * `end`: span end
|
|
// * `text`: the text
|
|
PushLazyObject(&LuaVerifier::GetCapturingGroup);
|
|
lua_setglobal(state_, "match");
|
|
return LUA_OK;
|
|
}) == LUA_OK;
|
|
}
|
|
|
|
std::unique_ptr<LuaVerifier> LuaVerifier::Create(
|
|
const std::string& context, const std::string& verifier_code,
|
|
const UniLib::RegexMatcher* matcher) {
|
|
auto verifier = std::unique_ptr<LuaVerifier>(
|
|
new LuaVerifier(context, verifier_code, matcher));
|
|
if (!verifier->Initialize()) {
|
|
TC3_LOG(ERROR) << "Could not initialize lua environment.";
|
|
return nullptr;
|
|
}
|
|
return verifier;
|
|
}
|
|
|
|
int LuaVerifier::GetCapturingGroup() {
|
|
if (lua_type(state_, /*idx=*/-1) != LUA_TNUMBER) {
|
|
TC3_LOG(ERROR) << "Unexpected type for match group lookup: "
|
|
<< lua_type(state_, /*idx=*/-1);
|
|
lua_error(state_);
|
|
return 0;
|
|
}
|
|
const int group_id = static_cast<int>(lua_tonumber(state_, /*idx=*/-1));
|
|
int status = UniLib::RegexMatcher::kNoError;
|
|
const CodepointSpan span = {matcher_->Start(group_id, &status),
|
|
matcher_->End(group_id, &status)};
|
|
std::string text = matcher_->Group(group_id, &status).ToUTF8String();
|
|
if (status != UniLib::RegexMatcher::kNoError) {
|
|
TC3_LOG(ERROR) << "Could not extract span from capturing group.";
|
|
lua_error(state_);
|
|
return 0;
|
|
}
|
|
lua_newtable(state_);
|
|
lua_pushinteger(state_, span.first);
|
|
lua_setfield(state_, /*idx=*/-2, "begin");
|
|
lua_pushinteger(state_, span.second);
|
|
lua_setfield(state_, /*idx=*/-2, "end");
|
|
PushString(text);
|
|
lua_setfield(state_, /*idx=*/-2, "text");
|
|
return 1;
|
|
}
|
|
|
|
bool LuaVerifier::Verify(bool* result) {
|
|
if (luaL_loadbuffer(state_, verifier_code_.data(), verifier_code_.size(),
|
|
/*name=*/nullptr) != LUA_OK) {
|
|
TC3_LOG(ERROR) << "Could not load verifier snippet.";
|
|
return false;
|
|
}
|
|
|
|
if (lua_pcall(state_, /*nargs=*/0, /*nresults=*/1, /*errfunc=*/0) != LUA_OK) {
|
|
TC3_LOG(ERROR) << "Could not run verifier snippet.";
|
|
return false;
|
|
}
|
|
|
|
if (RunProtected(
|
|
[this, result] {
|
|
if (lua_type(state_, /*idx=*/-1) != LUA_TBOOLEAN) {
|
|
TC3_LOG(ERROR) << "Unexpected verification result type: "
|
|
<< lua_type(state_, /*idx=*/-1);
|
|
lua_error(state_);
|
|
return LUA_ERRRUN;
|
|
}
|
|
*result = lua_toboolean(state_, /*idx=*/-1);
|
|
return LUA_OK;
|
|
},
|
|
/*num_args=*/1) != LUA_OK) {
|
|
TC3_LOG(ERROR) << "Could not read lua result.";
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
#endif // TC3_DISABLE_LUA
|
|
|
|
} // namespace
|
|
|
|
// Returns the UTF-8 text of capturing group `group_id` of `matcher`.
// The returned Optional is empty if the matcher reports an error for the group
// or if the text of the group is empty.
Optional<std::string> GetCapturingGroupText(const UniLib::RegexMatcher* matcher,
                                            const int group_id) {
  int status = UniLib::RegexMatcher::kNoError;
  std::string group_text = matcher->Group(group_id, &status).ToUTF8String();
  const bool has_text =
      status == UniLib::RegexMatcher::kNoError && !group_text.empty();
  return has_text ? Optional<std::string>(group_text) : Optional<std::string>();
}
|
|
|
|
bool VerifyMatch(const std::string& context,
|
|
const UniLib::RegexMatcher* matcher,
|
|
const std::string& lua_verifier_code) {
|
|
bool status = false;
|
|
#ifndef TC3_DISABLE_LUA
|
|
auto verifier = LuaVerifier::Create(context, lua_verifier_code, matcher);
|
|
if (verifier == nullptr) {
|
|
TC3_LOG(ERROR) << "Could not create verifier.";
|
|
return false;
|
|
}
|
|
if (!verifier->Verify(&status)) {
|
|
TC3_LOG(ERROR) << "Could not create verifier.";
|
|
return false;
|
|
}
|
|
#endif // TC3_DISABLE_LUA
|
|
return status;
|
|
}
|
|
|
|
} // namespace libtextclassifier3
|