You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 

63 lines
1.8 KiB

import asyncio
import retaggr
import config
from config import SZURU_URL # Easier to use
import quart.flask_patch
from auth import requires_auth
from quart import Quart
from aiohttp_requests import requests
import aiohttp
# Quart application object; the route decorators further down attach to it.
app = Quart(__name__)

# Reverse-search client. Credentials and the minimum score are read from the
# config module; the app name and version are fixed here.
rconfig = retaggr.ReverseSearchConfig(
    danbooru_username=config.DANBOORU_USERNAME,
    danbooru_api_key=config.DANBOORU_API_KEY,
    e621_username=config.E621_USERNAME,
    app_name="szuru-retaggr",
    version="1.0.0",
    min_score=config.MIN_SCORE,
)
rsearch = retaggr.ReverseSearch(rconfig)

# Headers and HTTP basic-auth credentials sent with every szuru API request.
HEADERS = {
    "Accept": "application/json",
    "Content-Type": "application/json",
}
AUTHORIZATION = aiohttp.BasicAuth(
    config.SZURU_USERNAME,
    config.SZURU_PASSWORD,
)
async def process_task(post_id):
    """
    Reverse-search the image of a szuru post and merge the found tags into it.

    Fetches the post, runs the reverse search on its content URL, then PUTs
    the de-duplicated union of the existing and found tags back to the post.
    Progress and responses are printed to stdout.

    post_id: ID of the szuru post to process.

    Returns None. If the post cannot be fetched (HTTP status >= 400), a
    message is printed and nothing is updated.
    """
    post_url = SZURU_URL + "api/post/" + str(post_id)

    # Get existing post + tag data
    post_request = await requests.get(post_url, headers=HEADERS, auth=AUTHORIZATION)
    if post_request.status >= 400:
        # An error response is not expected to carry the post fields read
        # below, so stop with a readable message instead of a KeyError.
        error_body = await post_request.text()
        print(f"Could not fetch post {post_id}: HTTP {post_request.status} {error_body}")
        return
    post_data = await post_request.json()

    # Only the first name of each existing tag is sent back with the update.
    tags = [tag["names"][0] for tag in post_data["tags"]]
    image_url = SZURU_URL + post_data["contentUrl"]

    print(f"Reverse tagging {post_id}")
    found_tags = await rsearch.reverse_search(image_url, download=True)
    print(found_tags)

    tags.extend(found_tags)
    tags = list(set(tags))  # Stripping out the dupes

    # The update carries the "version" value read from the fetched post.
    update_request = await requests.put(
        post_url,
        json={
            "version": post_data["version"],
            "tags": tags,
        },
        headers=HEADERS,
        auth=AUTHORIZATION,
    )
    print(await update_request.json())
    print(f"Done processing {post_id}")
@app.route('/')
@requires_auth
async def launch():
    """
    Return the fixed string "Up!"; the handler is wrapped in requires_auth.
    """
    status_message = "Up!"
    return status_message
# Strong references to in-flight jobs. The event loop keeps only weak
# references to tasks, so a task nobody else references may be garbage
# collected before it finishes.
_background_tasks = set()


def _finish_background_task(task):
    """Forget a finished job and print its traceback if it failed."""
    _background_tasks.discard(task)
    if task.cancelled():
        return
    error = task.exception()
    if error is not None:
        import traceback  # local import: only needed on the failure path

        print("Background task failed:")
        traceback.print_exception(type(error), error, error.__traceback__)


# NOTE(review): unlike launch(), this route is not wrapped in @requires_auth,
# so any client can start a job -- confirm that this is intended.
@app.route('/start/<int:post_id>', methods=['POST'])
async def create_job(post_id):
    """
    Start a job for a szuru post with post_id.

    The job is scheduled on the event loop and this handler returns
    immediately without waiting for it; failures are reported by
    _finish_background_task.
    """
    task = asyncio.ensure_future(process_task(post_id))
    _background_tasks.add(task)
    task.add_done_callback(_finish_background_task)
    return "Started task."
# Start Quart's built-in server only when this file is executed directly, so
# importing the module (for example from an ASGI server) does not start a
# second server as an import side effect.
if __name__ == "__main__":
    app.run()