You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

72 lines
2.0 KiB

  1. import unittest
  2. import functools
  3. import psycopg2
  4. import os
  5. from typing import Any
  6. def dbconnect(func):
  7. @functools.wraps(func)
  8. def inner(*args, **kwargs):
  9. inner.__wrapped__ = func
  10. params = {
  11. "host": "testdb",
  12. "port": 5432,
  13. "dbname": "test",
  14. "user": "postgres",
  15. "password": "password",
  16. }
  17. print(f"Connecting for {params}")
  18. # http://initd.org/psycopg/docs/usage.html#with-statement
  19. conn = None
  20. try:
  21. with psycopg2.connect(**params) as conn:
  22. func(*args, conn=conn, **kwargs)
  23. finally:
  24. if conn:
  25. print(f"Close connection for {params}")
  26. conn.close()
  27. return inner
  28. class DbTest(unittest.TestCase):
  29. @dbconnect
  30. def setUp(self, conn):
  31. print("Invoking setUp")
  32. print("Set up database schema")
  33. path_to_schema = os.path.join(
  34. os.path.dirname(__file__),
  35. "..",
  36. "sql",
  37. "schema.sql"
  38. )
  39. with conn.cursor() as cur:
  40. cur.execute("CREATE SCHEMA IF NOT EXISTS public;")
  41. schema_sql = self.read_file(path_to_schema)
  42. print(f"Loading {path_to_schema}")
  43. cur.execute(schema_sql)
  44. print(f"Loaded {path_to_schema}")
  45. @dbconnect
  46. def tearDown(self, conn):
  47. print("Invoking tearDown")
  48. print("Tore down database schema")
  49. with conn.cursor() as cur:
  50. print("Droping schema")
  51. cur.execute("DROP SCHEMA IF EXISTS public CASCADE;")
  52. print("Dropped schema")
  53. def load_fixtures(self, conn: Any, *path_to_sqls: str) -> None:
  54. for path_to_sql in path_to_sqls:
  55. sql = self.read_file(path_to_sql)
  56. with conn.cursor() as cur:
  57. print(f"Executing {path_to_sql}")
  58. cur.execute(sql)
  59. print(f"Executed {path_to_sql}")
  60. def read_file(self, path_to_file: str) -> str:
  61. with open(path_to_file, "r") as f:
  62. return f.read()