tests upgrade
This commit is contained in:
parent
a6b3b21894
commit
8a60bec73a
|
@ -1,119 +0,0 @@
|
||||||
import time
|
|
||||||
|
|
||||||
from sqlalchemy import (
|
|
||||||
JSON,
|
|
||||||
Boolean,
|
|
||||||
Column,
|
|
||||||
DateTime,
|
|
||||||
ForeignKey,
|
|
||||||
Integer,
|
|
||||||
String,
|
|
||||||
func,
|
|
||||||
)
|
|
||||||
from sqlalchemy.orm import relationship
|
|
||||||
|
|
||||||
from services.db import Base
|
|
||||||
|
|
||||||
|
|
||||||
class Permission(Base):
|
|
||||||
__tablename__ = "permission"
|
|
||||||
|
|
||||||
id = Column(String, primary_key=True, unique=True, nullable=False, default=None)
|
|
||||||
resource = Column(String, nullable=False)
|
|
||||||
operation = Column(String, nullable=False)
|
|
||||||
|
|
||||||
|
|
||||||
class Role(Base):
|
|
||||||
__tablename__ = "role"
|
|
||||||
|
|
||||||
id = Column(String, primary_key=True, unique=True, nullable=False, default=None)
|
|
||||||
name = Column(String, nullable=False)
|
|
||||||
permissions = relationship(Permission)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthorizerUser(Base):
|
|
||||||
__tablename__ = "authorizer_users"
|
|
||||||
|
|
||||||
id = Column(String, primary_key=True, unique=True, nullable=False, default=None)
|
|
||||||
key = Column(String)
|
|
||||||
email = Column(String, unique=True)
|
|
||||||
email_verified_at = Column(Integer)
|
|
||||||
family_name = Column(String)
|
|
||||||
gender = Column(String)
|
|
||||||
given_name = Column(String)
|
|
||||||
is_multi_factor_auth_enabled = Column(Boolean)
|
|
||||||
middle_name = Column(String)
|
|
||||||
nickname = Column(String)
|
|
||||||
password = Column(String)
|
|
||||||
phone_number = Column(String, unique=True)
|
|
||||||
phone_number_verified_at = Column(Integer)
|
|
||||||
# preferred_username = Column(String, nullable=False)
|
|
||||||
picture = Column(String)
|
|
||||||
revoked_timestamp = Column(Integer)
|
|
||||||
roles = Column(String, default="author,reader")
|
|
||||||
signup_methods = Column(String, default="magic_link_login")
|
|
||||||
created_at = Column(Integer, default=lambda: int(time.time()))
|
|
||||||
updated_at = Column(Integer, default=lambda: int(time.time()))
|
|
||||||
|
|
||||||
|
|
||||||
class UserRating(Base):
|
|
||||||
__tablename__ = "user_rating"
|
|
||||||
|
|
||||||
id = None
|
|
||||||
rater: Column = Column(ForeignKey("user.id"), primary_key=True, index=True)
|
|
||||||
user: Column = Column(ForeignKey("user.id"), primary_key=True, index=True)
|
|
||||||
value: Column = Column(Integer)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def init_table():
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class UserRole(Base):
|
|
||||||
__tablename__ = "user_role"
|
|
||||||
|
|
||||||
id = None
|
|
||||||
user = Column(ForeignKey("user.id"), primary_key=True, index=True)
|
|
||||||
role = Column(ForeignKey("role.id"), primary_key=True, index=True)
|
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
|
||||||
__tablename__ = "user"
|
|
||||||
default_user = None
|
|
||||||
|
|
||||||
email = Column(String, unique=True, nullable=False, comment="Email")
|
|
||||||
username = Column(String, nullable=False, comment="Login")
|
|
||||||
password = Column(String, nullable=True, comment="Password")
|
|
||||||
bio = Column(String, nullable=True, comment="Bio") # status description
|
|
||||||
about = Column(String, nullable=True, comment="About") # long and formatted
|
|
||||||
userpic = Column(String, nullable=True, comment="Userpic")
|
|
||||||
name = Column(String, nullable=True, comment="Display name")
|
|
||||||
slug = Column(String, unique=True, comment="User's slug")
|
|
||||||
links = Column(JSON, nullable=True, comment="Links")
|
|
||||||
oauth = Column(String, nullable=True)
|
|
||||||
oid = Column(String, nullable=True)
|
|
||||||
|
|
||||||
muted = Column(Boolean, default=False)
|
|
||||||
confirmed = Column(Boolean, default=False)
|
|
||||||
|
|
||||||
created_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Created at")
|
|
||||||
updated_at = Column(DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Updated at")
|
|
||||||
last_seen = Column(DateTime(timezone=True), nullable=False, server_default=func.now(), comment="Was online at")
|
|
||||||
deleted_at = Column(DateTime(timezone=True), nullable=True, comment="Deleted at")
|
|
||||||
|
|
||||||
ratings = relationship(UserRating, foreign_keys=UserRating.user)
|
|
||||||
roles = relationship(lambda: Role, secondary=UserRole.__tablename__)
|
|
||||||
|
|
||||||
def get_permission(self):
|
|
||||||
scope = {}
|
|
||||||
for role in self.roles:
|
|
||||||
for p in role.permissions:
|
|
||||||
if p.resource not in scope:
|
|
||||||
scope[p.resource] = set()
|
|
||||||
scope[p.resource].add(p.operation)
|
|
||||||
print(scope)
|
|
||||||
return scope
|
|
||||||
|
|
||||||
|
|
||||||
# if __name__ == "__main__":
|
|
||||||
# print(User.get_permission(user_id=1))
|
|
24
cache/triggers.py
vendored
24
cache/triggers.py
vendored
|
@ -88,7 +88,11 @@ def after_reaction_handler(mapper, connection, target):
|
||||||
with local_session() as session:
|
with local_session() as session:
|
||||||
shout = (
|
shout = (
|
||||||
session.query(Shout)
|
session.query(Shout)
|
||||||
.filter(Shout.id == shout_id, Shout.published_at.is_not(None), Shout.deleted_at.is_(None))
|
.filter(
|
||||||
|
Shout.id == shout_id,
|
||||||
|
Shout.published_at.is_not(None),
|
||||||
|
Shout.deleted_at.is_(None),
|
||||||
|
)
|
||||||
.first()
|
.first()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -108,15 +112,27 @@ def events_register():
|
||||||
|
|
||||||
event.listen(AuthorFollower, "after_insert", after_follower_handler)
|
event.listen(AuthorFollower, "after_insert", after_follower_handler)
|
||||||
event.listen(AuthorFollower, "after_update", after_follower_handler)
|
event.listen(AuthorFollower, "after_update", after_follower_handler)
|
||||||
event.listen(AuthorFollower, "after_delete", lambda *args: after_follower_handler(*args, is_delete=True))
|
event.listen(
|
||||||
|
AuthorFollower,
|
||||||
|
"after_delete",
|
||||||
|
lambda *args: after_follower_handler(*args, is_delete=True),
|
||||||
|
)
|
||||||
|
|
||||||
event.listen(TopicFollower, "after_insert", after_follower_handler)
|
event.listen(TopicFollower, "after_insert", after_follower_handler)
|
||||||
event.listen(TopicFollower, "after_update", after_follower_handler)
|
event.listen(TopicFollower, "after_update", after_follower_handler)
|
||||||
event.listen(TopicFollower, "after_delete", lambda *args: after_follower_handler(*args, is_delete=True))
|
event.listen(
|
||||||
|
TopicFollower,
|
||||||
|
"after_delete",
|
||||||
|
lambda *args: after_follower_handler(*args, is_delete=True),
|
||||||
|
)
|
||||||
|
|
||||||
event.listen(ShoutReactionsFollower, "after_insert", after_follower_handler)
|
event.listen(ShoutReactionsFollower, "after_insert", after_follower_handler)
|
||||||
event.listen(ShoutReactionsFollower, "after_update", after_follower_handler)
|
event.listen(ShoutReactionsFollower, "after_update", after_follower_handler)
|
||||||
event.listen(ShoutReactionsFollower, "after_delete", lambda *args: after_follower_handler(*args, is_delete=True))
|
event.listen(
|
||||||
|
ShoutReactionsFollower,
|
||||||
|
"after_delete",
|
||||||
|
lambda *args: after_follower_handler(*args, is_delete=True),
|
||||||
|
)
|
||||||
|
|
||||||
event.listen(Reaction, "after_update", mark_for_revalidation)
|
event.listen(Reaction, "after_update", mark_for_revalidation)
|
||||||
event.listen(Author, "after_update", mark_for_revalidation)
|
event.listen(Author, "after_update", mark_for_revalidation)
|
||||||
|
|
|
@ -90,7 +90,6 @@ class Author(Base):
|
||||||
Модель автора в системе.
|
Модель автора в системе.
|
||||||
|
|
||||||
Attributes:
|
Attributes:
|
||||||
user (str): Идентификатор пользователя в системе авторизации
|
|
||||||
name (str): Отображаемое имя
|
name (str): Отображаемое имя
|
||||||
slug (str): Уникальный строковый идентификатор
|
slug (str): Уникальный строковый идентификатор
|
||||||
bio (str): Краткая биография/статус
|
bio (str): Краткая биография/статус
|
||||||
|
@ -105,8 +104,6 @@ class Author(Base):
|
||||||
|
|
||||||
__tablename__ = "author"
|
__tablename__ = "author"
|
||||||
|
|
||||||
user = Column(String) # unbounded link with authorizer's User type
|
|
||||||
|
|
||||||
name = Column(String, nullable=True, comment="Display name")
|
name = Column(String, nullable=True, comment="Display name")
|
||||||
slug = Column(String, unique=True, comment="Author's slug")
|
slug = Column(String, unique=True, comment="Author's slug")
|
||||||
bio = Column(String, nullable=True, comment="Bio") # status description
|
bio = Column(String, nullable=True, comment="Bio") # status description
|
||||||
|
@ -124,12 +121,14 @@ class Author(Base):
|
||||||
|
|
||||||
# Определяем индексы
|
# Определяем индексы
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
|
# Индекс для быстрого поиска по имени
|
||||||
|
Index("idx_author_name", "name"),
|
||||||
# Индекс для быстрого поиска по slug
|
# Индекс для быстрого поиска по slug
|
||||||
Index("idx_author_slug", "slug"),
|
Index("idx_author_slug", "slug"),
|
||||||
# Индекс для быстрого поиска по идентификатору пользователя
|
|
||||||
Index("idx_author_user", "user"),
|
|
||||||
# Индекс для фильтрации неудаленных авторов
|
# Индекс для фильтрации неудаленных авторов
|
||||||
Index("idx_author_deleted_at", "deleted_at", postgresql_where=deleted_at.is_(None)),
|
Index(
|
||||||
|
"idx_author_deleted_at", "deleted_at", postgresql_where=deleted_at.is_(None)
|
||||||
|
),
|
||||||
# Индекс для сортировки по времени создания (для новых авторов)
|
# Индекс для сортировки по времени создания (для новых авторов)
|
||||||
Index("idx_author_created_at", "created_at"),
|
Index("idx_author_created_at", "created_at"),
|
||||||
# Индекс для сортировки по времени последнего посещения
|
# Индекс для сортировки по времени последнего посещения
|
||||||
|
|
|
@ -6,7 +6,6 @@ from sqlalchemy.orm import relationship
|
||||||
from orm.author import Author
|
from orm.author import Author
|
||||||
from orm.topic import Topic
|
from orm.topic import Topic
|
||||||
from services.db import Base
|
from services.db import Base
|
||||||
from orm.shout import Shout
|
|
||||||
|
|
||||||
|
|
||||||
class DraftTopic(Base):
|
class DraftTopic(Base):
|
||||||
|
|
1
services/__init__.py
Normal file
1
services/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
# This file makes services a Python package
|
|
@ -88,7 +88,7 @@ async def notify_draft(draft_data, action: str = "publish"):
|
||||||
"subtitle": getattr(draft_data, "subtitle", None),
|
"subtitle": getattr(draft_data, "subtitle", None),
|
||||||
"media": getattr(draft_data, "media", None),
|
"media": getattr(draft_data, "media", None),
|
||||||
"created_at": getattr(draft_data, "created_at", None),
|
"created_at": getattr(draft_data, "created_at", None),
|
||||||
"updated_at": getattr(draft_data, "updated_at", None)
|
"updated_at": getattr(draft_data, "updated_at", None),
|
||||||
}
|
}
|
||||||
|
|
||||||
# Если переданы связанные атрибуты, добавим их
|
# Если переданы связанные атрибуты, добавим их
|
||||||
|
@ -100,7 +100,12 @@ async def notify_draft(draft_data, action: str = "publish"):
|
||||||
|
|
||||||
if hasattr(draft_data, "authors") and draft_data.authors is not None:
|
if hasattr(draft_data, "authors") and draft_data.authors is not None:
|
||||||
draft_payload["authors"] = [
|
draft_payload["authors"] = [
|
||||||
{"id": a.id, "name": a.name, "slug": a.slug, "pic": getattr(a, "pic", None)}
|
{
|
||||||
|
"id": a.id,
|
||||||
|
"name": a.name,
|
||||||
|
"slug": a.slug,
|
||||||
|
"pic": getattr(a, "pic", None),
|
||||||
|
}
|
||||||
for a in draft_data.authors
|
for a in draft_data.authors
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
25
tests/auth/conftest.py
Normal file
25
tests/auth/conftest.py
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
import pytest
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def oauth_settings() -> Dict[str, Dict[str, str]]:
|
||||||
|
"""Тестовые настройки OAuth"""
|
||||||
|
return {
|
||||||
|
"GOOGLE": {"id": "test_google_id", "key": "test_google_secret"},
|
||||||
|
"GITHUB": {"id": "test_github_id", "key": "test_github_secret"},
|
||||||
|
"FACEBOOK": {"id": "test_facebook_id", "key": "test_facebook_secret"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def frontend_url() -> str:
|
||||||
|
"""URL фронтенда для тестов"""
|
||||||
|
return "http://localhost:3000"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_settings(monkeypatch, oauth_settings, frontend_url):
|
||||||
|
"""Подменяем настройки для тестов"""
|
||||||
|
monkeypatch.setattr("auth.oauth.OAUTH_CLIENTS", oauth_settings)
|
||||||
|
monkeypatch.setattr("auth.oauth.FRONTEND_URL", frontend_url)
|
224
tests/auth/test_oauth.py
Normal file
224
tests/auth/test_oauth.py
Normal file
|
@ -0,0 +1,224 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from starlette.responses import JSONResponse, RedirectResponse
|
||||||
|
|
||||||
|
from auth.oauth import get_user_profile, oauth_login, oauth_callback
|
||||||
|
|
||||||
|
# Подменяем настройки для тестов
|
||||||
|
with (
|
||||||
|
patch("auth.oauth.FRONTEND_URL", "http://localhost:3000"),
|
||||||
|
patch(
|
||||||
|
"auth.oauth.OAUTH_CLIENTS",
|
||||||
|
{
|
||||||
|
"GOOGLE": {"id": "test_google_id", "key": "test_google_secret"},
|
||||||
|
"GITHUB": {"id": "test_github_id", "key": "test_github_secret"},
|
||||||
|
"FACEBOOK": {"id": "test_facebook_id", "key": "test_facebook_secret"},
|
||||||
|
},
|
||||||
|
),
|
||||||
|
):
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_request():
|
||||||
|
"""Фикстура для мока запроса"""
|
||||||
|
request = MagicMock()
|
||||||
|
request.session = {}
|
||||||
|
request.path_params = {}
|
||||||
|
request.query_params = {}
|
||||||
|
return request
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_oauth_client():
|
||||||
|
"""Фикстура для мока OAuth клиента"""
|
||||||
|
client = AsyncMock()
|
||||||
|
client.authorize_redirect = AsyncMock()
|
||||||
|
client.authorize_access_token = AsyncMock()
|
||||||
|
client.get = AsyncMock()
|
||||||
|
return client
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_profile_google():
|
||||||
|
"""Тест получения профиля из Google"""
|
||||||
|
client = AsyncMock()
|
||||||
|
token = {
|
||||||
|
"userinfo": {
|
||||||
|
"sub": "123",
|
||||||
|
"email": "test@gmail.com",
|
||||||
|
"name": "Test User",
|
||||||
|
"picture": "https://lh3.googleusercontent.com/photo=s96",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
profile = await get_user_profile("google", client, token)
|
||||||
|
|
||||||
|
assert profile["id"] == "123"
|
||||||
|
assert profile["email"] == "test@gmail.com"
|
||||||
|
assert profile["name"] == "Test User"
|
||||||
|
assert profile["picture"] == "https://lh3.googleusercontent.com/photo=s600"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_profile_github():
|
||||||
|
"""Тест получения профиля из GitHub"""
|
||||||
|
client = AsyncMock()
|
||||||
|
client.get.side_effect = [
|
||||||
|
MagicMock(
|
||||||
|
json=lambda: {
|
||||||
|
"id": 456,
|
||||||
|
"login": "testuser",
|
||||||
|
"name": "Test User",
|
||||||
|
"avatar_url": "https://github.com/avatar.jpg",
|
||||||
|
}
|
||||||
|
),
|
||||||
|
MagicMock(
|
||||||
|
json=lambda: [
|
||||||
|
{"email": "other@github.com", "primary": False},
|
||||||
|
{"email": "test@github.com", "primary": True},
|
||||||
|
]
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
profile = await get_user_profile("github", client, {})
|
||||||
|
|
||||||
|
assert profile["id"] == "456"
|
||||||
|
assert profile["email"] == "test@github.com"
|
||||||
|
assert profile["name"] == "Test User"
|
||||||
|
assert profile["picture"] == "https://github.com/avatar.jpg"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_user_profile_facebook():
|
||||||
|
"""Тест получения профиля из Facebook"""
|
||||||
|
client = AsyncMock()
|
||||||
|
client.get.return_value = MagicMock(
|
||||||
|
json=lambda: {
|
||||||
|
"id": "789",
|
||||||
|
"name": "Test User",
|
||||||
|
"email": "test@facebook.com",
|
||||||
|
"picture": {"data": {"url": "https://facebook.com/photo.jpg"}},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
profile = await get_user_profile("facebook", client, {})
|
||||||
|
|
||||||
|
assert profile["id"] == "789"
|
||||||
|
assert profile["email"] == "test@facebook.com"
|
||||||
|
assert profile["name"] == "Test User"
|
||||||
|
assert profile["picture"] == "https://facebook.com/photo.jpg"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_login_success(mock_request, mock_oauth_client):
|
||||||
|
"""Тест успешного начала OAuth авторизации"""
|
||||||
|
mock_request.path_params["provider"] = "google"
|
||||||
|
|
||||||
|
# Настраиваем мок для authorize_redirect
|
||||||
|
redirect_response = RedirectResponse(url="http://example.com")
|
||||||
|
mock_oauth_client.authorize_redirect.return_value = redirect_response
|
||||||
|
|
||||||
|
with patch("auth.oauth.oauth.create_client", return_value=mock_oauth_client):
|
||||||
|
response = await oauth_login(mock_request)
|
||||||
|
|
||||||
|
assert isinstance(response, RedirectResponse)
|
||||||
|
assert mock_request.session["provider"] == "google"
|
||||||
|
assert "code_verifier" in mock_request.session
|
||||||
|
assert "state" in mock_request.session
|
||||||
|
|
||||||
|
mock_oauth_client.authorize_redirect.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_login_invalid_provider(mock_request):
|
||||||
|
"""Тест с неправильным провайдером"""
|
||||||
|
mock_request.path_params["provider"] = "invalid"
|
||||||
|
|
||||||
|
response = await oauth_login(mock_request)
|
||||||
|
|
||||||
|
assert isinstance(response, JSONResponse)
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "Invalid provider" in response.body.decode()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_callback_success(mock_request, mock_oauth_client):
|
||||||
|
"""Тест успешного OAuth callback"""
|
||||||
|
mock_request.session = {
|
||||||
|
"provider": "google",
|
||||||
|
"code_verifier": "test_verifier",
|
||||||
|
"state": "test_state",
|
||||||
|
}
|
||||||
|
mock_request.query_params["state"] = "test_state"
|
||||||
|
|
||||||
|
mock_oauth_client.authorize_access_token.return_value = {
|
||||||
|
"userinfo": {"sub": "123", "email": "test@gmail.com", "name": "Test User"}
|
||||||
|
}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("auth.oauth.oauth.create_client", return_value=mock_oauth_client),
|
||||||
|
patch("auth.oauth.local_session") as mock_session,
|
||||||
|
patch("auth.oauth.TokenStorage.create_session", return_value="test_token"),
|
||||||
|
):
|
||||||
|
# Мокаем сессию базы данных
|
||||||
|
session = MagicMock()
|
||||||
|
session.query.return_value.filter.return_value.first.return_value = None
|
||||||
|
mock_session.return_value.__enter__.return_value = session
|
||||||
|
|
||||||
|
response = await oauth_callback(mock_request)
|
||||||
|
|
||||||
|
assert isinstance(response, RedirectResponse)
|
||||||
|
assert response.status_code == 307
|
||||||
|
assert "auth/success" in response.headers["location"]
|
||||||
|
|
||||||
|
# Проверяем cookie
|
||||||
|
cookies = response.headers.getlist("set-cookie")
|
||||||
|
assert any("session_token=test_token" in cookie for cookie in cookies)
|
||||||
|
assert any("httponly" in cookie.lower() for cookie in cookies)
|
||||||
|
assert any("secure" in cookie.lower() for cookie in cookies)
|
||||||
|
|
||||||
|
# Проверяем очистку сессии
|
||||||
|
assert "code_verifier" not in mock_request.session
|
||||||
|
assert "provider" not in mock_request.session
|
||||||
|
assert "state" not in mock_request.session
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_callback_invalid_state(mock_request):
|
||||||
|
"""Тест с неправильным state параметром"""
|
||||||
|
mock_request.session = {"provider": "google", "state": "correct_state"}
|
||||||
|
mock_request.query_params["state"] = "wrong_state"
|
||||||
|
|
||||||
|
response = await oauth_callback(mock_request)
|
||||||
|
|
||||||
|
assert isinstance(response, JSONResponse)
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "Invalid state" in response.body.decode()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_oauth_callback_existing_user(mock_request, mock_oauth_client):
|
||||||
|
"""Тест OAuth callback с существующим пользователем"""
|
||||||
|
mock_request.session = {
|
||||||
|
"provider": "google",
|
||||||
|
"code_verifier": "test_verifier",
|
||||||
|
"state": "test_state",
|
||||||
|
}
|
||||||
|
mock_request.query_params["state"] = "test_state"
|
||||||
|
|
||||||
|
mock_oauth_client.authorize_access_token.return_value = {
|
||||||
|
"userinfo": {"sub": "123", "email": "test@gmail.com", "name": "Test User"}
|
||||||
|
}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("auth.oauth.oauth.create_client", return_value=mock_oauth_client),
|
||||||
|
patch("auth.oauth.local_session") as mock_session,
|
||||||
|
patch("auth.oauth.TokenStorage.create_session", return_value="test_token"),
|
||||||
|
):
|
||||||
|
# Мокаем существующего пользователя
|
||||||
|
existing_user = MagicMock()
|
||||||
|
session = MagicMock()
|
||||||
|
session.query.return_value.filter.return_value.first.return_value = (
|
||||||
|
existing_user
|
||||||
|
)
|
||||||
|
mock_session.return_value.__enter__.return_value = session
|
||||||
|
|
||||||
|
response = await oauth_callback(mock_request)
|
||||||
|
|
||||||
|
assert isinstance(response, RedirectResponse)
|
||||||
|
assert response.status_code == 307
|
||||||
|
|
||||||
|
# Проверяем обновление существующего пользователя
|
||||||
|
assert existing_user.name == "Test User"
|
||||||
|
assert existing_user.oauth == "google:123"
|
||||||
|
assert existing_user.email_verified is True
|
9
tests/auth/test_settings.py
Normal file
9
tests/auth/test_settings.py
Normal file
|
@ -0,0 +1,9 @@
|
||||||
|
"""Тестовые настройки для OAuth"""
|
||||||
|
|
||||||
|
FRONTEND_URL = "http://localhost:3000"
|
||||||
|
|
||||||
|
OAUTH_CLIENTS = {
|
||||||
|
"GOOGLE": {"id": "test_google_id", "key": "test_google_secret"},
|
||||||
|
"GITHUB": {"id": "test_github_id", "key": "test_github_secret"},
|
||||||
|
"FACEBOOK": {"id": "test_facebook_id", "key": "test_facebook_secret"},
|
||||||
|
}
|
|
@ -1,17 +1,7 @@
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import create_engine
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
from starlette.testclient import TestClient
|
|
||||||
|
|
||||||
from main import app
|
|
||||||
from services.db import Base
|
|
||||||
from services.redis import redis
|
from services.redis import redis
|
||||||
|
from tests.test_config import get_test_client
|
||||||
# Use SQLite for testing
|
|
||||||
TEST_DB_URL = "sqlite:///test.db"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
|
@ -23,38 +13,36 @@ def event_loop():
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def test_engine():
|
def test_app():
|
||||||
"""Create a test database engine."""
|
"""Create a test client and session factory."""
|
||||||
engine = create_engine(TEST_DB_URL)
|
client, SessionLocal = get_test_client()
|
||||||
Base.metadata.create_all(engine)
|
return client, SessionLocal
|
||||||
yield engine
|
|
||||||
Base.metadata.drop_all(engine)
|
|
||||||
os.remove("test.db")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def db_session(test_engine):
|
def db_session(test_app):
|
||||||
"""Create a new database session for a test."""
|
"""Create a new database session for a test."""
|
||||||
connection = test_engine.connect()
|
_, SessionLocal = test_app
|
||||||
transaction = connection.begin()
|
session = SessionLocal()
|
||||||
session = Session(bind=connection)
|
|
||||||
|
|
||||||
yield session
|
yield session
|
||||||
|
|
||||||
|
session.rollback()
|
||||||
session.close()
|
session.close()
|
||||||
transaction.rollback()
|
|
||||||
connection.close()
|
|
||||||
|
@pytest.fixture
|
||||||
|
def test_client(test_app):
|
||||||
|
"""Get the test client."""
|
||||||
|
client, _ = test_app
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def redis_client():
|
async def redis_client():
|
||||||
"""Create a test Redis client."""
|
"""Create a test Redis client."""
|
||||||
await redis.connect()
|
await redis.connect()
|
||||||
|
await redis.flushall() # Очищаем Redis перед каждым тестом
|
||||||
yield redis
|
yield redis
|
||||||
|
await redis.flushall() # Очищаем после теста
|
||||||
await redis.disconnect()
|
await redis.disconnect()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def test_client():
|
|
||||||
"""Create a TestClient instance."""
|
|
||||||
return TestClient(app)
|
|
||||||
|
|
67
tests/test_config.py
Normal file
67
tests/test_config.py
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
"""
|
||||||
|
Конфигурация для тестов
|
||||||
|
"""
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.pool import StaticPool
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.middleware import Middleware
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
# Используем in-memory SQLite для тестов
|
||||||
|
TEST_DB_URL = "sqlite:///:memory:"
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Middleware для внедрения сессии БД"""
|
||||||
|
|
||||||
|
def __init__(self, app, session_maker):
|
||||||
|
super().__init__(app)
|
||||||
|
self.session_maker = session_maker
|
||||||
|
|
||||||
|
async def dispatch(self, request, call_next):
|
||||||
|
session = self.session_maker()
|
||||||
|
request.state.db = session
|
||||||
|
try:
|
||||||
|
response = await call_next(request)
|
||||||
|
finally:
|
||||||
|
session.close()
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def create_test_app():
|
||||||
|
"""Create a test Starlette application."""
|
||||||
|
from services.db import Base
|
||||||
|
|
||||||
|
# Создаем движок и таблицы
|
||||||
|
engine = create_engine(
|
||||||
|
TEST_DB_URL,
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool,
|
||||||
|
echo=False,
|
||||||
|
)
|
||||||
|
Base.metadata.drop_all(bind=engine)
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
|
||||||
|
# Создаем фабрику сессий
|
||||||
|
SessionLocal = sessionmaker(bind=engine)
|
||||||
|
|
||||||
|
# Создаем middleware для сессий
|
||||||
|
middleware = [Middleware(DatabaseMiddleware, session_maker=SessionLocal)]
|
||||||
|
|
||||||
|
# Создаем тестовое приложение
|
||||||
|
app = Starlette(
|
||||||
|
debug=True,
|
||||||
|
middleware=middleware,
|
||||||
|
routes=[], # Здесь можно добавить тестовые маршруты если нужно
|
||||||
|
)
|
||||||
|
|
||||||
|
return app, SessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
def get_test_client():
|
||||||
|
"""Get a test client with initialized database."""
|
||||||
|
app, SessionLocal = create_test_app()
|
||||||
|
return TestClient(app), SessionLocal
|
|
@ -53,7 +53,11 @@ async def test_create_reaction(test_client, db_session, test_setup):
|
||||||
}
|
}
|
||||||
""",
|
""",
|
||||||
"variables": {
|
"variables": {
|
||||||
"reaction": {"shout": test_setup["shout"].id, "kind": ReactionKind.LIKE.value, "body": "Great post!"}
|
"reaction": {
|
||||||
|
"shout": test_setup["shout"].id,
|
||||||
|
"kind": ReactionKind.LIKE.value,
|
||||||
|
"body": "Great post!",
|
||||||
|
}
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -61,4 +65,6 @@ async def test_create_reaction(test_client, db_session, test_setup):
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert "error" not in data
|
assert "error" not in data
|
||||||
assert data["data"]["create_reaction"]["reaction"]["kind"] == ReactionKind.LIKE.value
|
assert (
|
||||||
|
data["data"]["create_reaction"]["reaction"]["kind"] == ReactionKind.LIKE.value
|
||||||
|
)
|
||||||
|
|
|
@ -1,70 +0,0 @@
|
||||||
from datetime import datetime, timedelta
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import ValidationError
|
|
||||||
|
|
||||||
from auth.validations import (
|
|
||||||
AuthInput,
|
|
||||||
AuthResponse,
|
|
||||||
TokenPayload,
|
|
||||||
UserRegistrationInput,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TestAuthValidations:
|
|
||||||
def test_auth_input(self):
|
|
||||||
"""Test basic auth input validation"""
|
|
||||||
# Valid case
|
|
||||||
auth = AuthInput(user_id="123", username="testuser", token="1234567890abcdef1234567890abcdef")
|
|
||||||
assert auth.user_id == "123"
|
|
||||||
assert auth.username == "testuser"
|
|
||||||
|
|
||||||
# Invalid cases
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
AuthInput(user_id="", username="test", token="x" * 32)
|
|
||||||
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
AuthInput(user_id="123", username="t", token="x" * 32)
|
|
||||||
|
|
||||||
def test_user_registration(self):
|
|
||||||
"""Test user registration validation"""
|
|
||||||
# Valid case
|
|
||||||
user = UserRegistrationInput(email="test@example.com", password="SecurePass123!", name="Test User")
|
|
||||||
assert user.email == "test@example.com"
|
|
||||||
assert user.name == "Test User"
|
|
||||||
|
|
||||||
# Test email validation
|
|
||||||
with pytest.raises(ValidationError) as exc:
|
|
||||||
UserRegistrationInput(email="invalid-email", password="SecurePass123!", name="Test")
|
|
||||||
assert "Invalid email format" in str(exc.value)
|
|
||||||
|
|
||||||
# Test password validation
|
|
||||||
with pytest.raises(ValidationError) as exc:
|
|
||||||
UserRegistrationInput(email="test@example.com", password="weak", name="Test")
|
|
||||||
assert "String should have at least 8 characters" in str(exc.value)
|
|
||||||
|
|
||||||
def test_token_payload(self):
|
|
||||||
"""Test token payload validation"""
|
|
||||||
now = datetime.utcnow()
|
|
||||||
exp = now + timedelta(hours=1)
|
|
||||||
|
|
||||||
payload = TokenPayload(user_id="123", username="testuser", exp=exp, iat=now)
|
|
||||||
assert payload.user_id == "123"
|
|
||||||
assert payload.username == "testuser"
|
|
||||||
assert payload.scopes == [] # Default empty list
|
|
||||||
|
|
||||||
def test_auth_response(self):
|
|
||||||
"""Test auth response validation"""
|
|
||||||
# Success case
|
|
||||||
success_resp = AuthResponse(success=True, token="valid_token", user={"id": "123", "name": "Test"})
|
|
||||||
assert success_resp.success is True
|
|
||||||
assert success_resp.token == "valid_token"
|
|
||||||
|
|
||||||
# Error case
|
|
||||||
error_resp = AuthResponse(success=False, error="Invalid credentials")
|
|
||||||
assert error_resp.success is False
|
|
||||||
assert error_resp.error == "Invalid credentials"
|
|
||||||
|
|
||||||
# Invalid case - отсутствует обязательное поле token при success=True
|
|
||||||
with pytest.raises(ValidationError):
|
|
||||||
AuthResponse(success=True, user={"id": "123", "name": "Test"})
|
|
Loading…
Reference in New Issue
Block a user