You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

1160 lines
47 KiB
Python

5 years ago
# -*- coding: utf-8 -*-
from __future__ import print_function
from __future__ import unicode_literals
from __future__ import division
from builtins import str, bytes, dict, int
from builtins import map, zip, filter
from builtins import object, range, next
from io import open
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import datetime
import codecs
import random
import unittest
from pattern import db
# Connection settings for the MySQL tests. Running them requires MySQLdb and a
# username + password with the right to create a database.
HOST = "localhost"
PORT = 3306
USERNAME = "root"
PASSWORD = ""

# Module-wide database handles, (re)created on demand by the create_db_*() helpers.
DB_MYSQL = None
DB_SQLITE = None
def create_db_mysql():
    """ Returns the shared MySQL test database.
        A new db.Database handle is created when there is none yet or when it
        has no open connection; in both cases every existing table is dropped,
        so the caller starts from an empty database.
    """
    global DB_MYSQL
    needs_connection = not DB_MYSQL or not DB_MYSQL._connection
    if needs_connection:
        DB_MYSQL = db.Database(
            type=db.MYSQL,
            name="pattern_unittest_db",
            host=HOST,
            port=PORT,
            username=USERNAME,
            password=PASSWORD)
    # Copy the table names up front: drop() presumably removes entries from
    # Database.tables while we are iterating.
    for name in list(DB_MYSQL.tables):
        DB_MYSQL.drop(name)
    return DB_MYSQL
def create_db_sqlite():
    """ Returns the shared SQLite test database.
        A new db.Database handle is created when there is none yet or when it
        has no open connection; in both cases every existing table is dropped,
        so the caller starts from an empty database.
    """
    global DB_SQLITE
    # Make sure the database handle is setup and connected
    if not DB_SQLITE or not DB_SQLITE._connection:
        DB_SQLITE = db.Database(
            type = db.SQLITE,
            name = "pattern_unittest_db",
            host = HOST,
            port = PORT,
            username = USERNAME,
            password = PASSWORD)
    # Drop all tables first. Iterate this handle's own tables (not DB_MYSQL's):
    # DB_MYSQL is None when only the SQLite tests run.
    for table in list(DB_SQLITE.tables):
        DB_SQLITE.drop(table)
    return DB_SQLITE
#---------------------------------------------------------------------------------------------------
class TestUnicode(unittest.TestCase):
    """ Tests for db.decode_utf8(), db.encode_utf8() and db.string(). """

    def setUp(self):
        # Test data with different (or wrong) encodings.
        self.strings = (
            "ünîcøde",
            "ünîcøde".encode("utf-16"),
            "ünîcøde".encode("latin-1"),
            "ünîcøde".encode("windows-1252"),
            # UTF-8 encoded bytes (a plain "ünîcøde" literal here would just
            # repeat the first entry and leave UTF-8 input untested).
            "ünîcøde".encode("utf-8"),
            "אוניקאָד"
        )

    def test_decode_utf8(self):
        # Assert unicode.
        for s in self.strings:
            self.assertTrue(isinstance(db.decode_utf8(s), str))
        print("pattern.db.decode_utf8()")

    def test_encode_utf8(self):
        # Assert Python bytestring.
        for s in self.strings:
            self.assertTrue(isinstance(db.encode_utf8(s), bytes))
        print("pattern.db.encode_utf8()")

    def test_string(self):
        # Assert string() with default for "" and None.
        for v, s in ((True, "True"), (1, "1"), (1.0, "1.0"), ("", "????"), (None, "????")):
            self.assertEqual(db.string(v, default="????"), s)
        print("pattern.db.string()")
#---------------------------------------------------------------------------------------------------
class TestEntities(unittest.TestCase):
    """ Tests for the HTML entity helpers db.encode_entities() / db.decode_entities(). """

    def setUp(self):
        pass

    def test_encode_entities(self):
        # Assert HTML entity encoder (e.g., "&" => "&amp;").
        # Non-ASCII characters such as "É" are left as they are.
        for a, b in (
          ("É", "É"),
          ("&", "&amp;"),
          ("<", "&lt;"),
          (">", "&gt;"),
          ('"', "&quot;"),
          ("'", "&#39;")):
            self.assertEqual(db.encode_entities(a), b)
        print("pattern.db.encode_entities()")

    def test_decode_entities(self):
        # Assert HTML entity decoder (e.g., "&amp;" => "&"):
        # decimal, named and hexadecimal entities; unknown entities are kept.
        for a, b in (
          ("&#38;", "&"),
          ("&amp;", "&"),
          ("&#x0026;", "&"),
          ("&#160;", "\xa0"),
          ("&foo;", "&foo;")):
            self.assertEqual(db.decode_entities(a), b)
        print("pattern.db.decode_entities()")
#---------------------------------------------------------------------------------------------------
class TestDate(unittest.TestCase):
    """ Tests for db.date(), Date formatting, Date.timestamp and db.time(). """

    def setUp(self):
        pass

    def test_date(self):
        # Assert string input and default date formats.
        for s in (
          "2010-09-21 09:27:01",
          "2010-09-21T09:27:01Z",
          "2010-09-21T09:27:01+0000",
          "2010-09-21 09:27",
          "2010-09-21",
          "21/09/2010",
          "21 September 2010",
          "September 21 2010",
          "September 21, 2010",
          1285054021):
            v = db.date(s)
            self.assertEqual(v.format, "%Y-%m-%d %H:%M:%S")
            self.assertEqual(v.year, 2010)
            self.assertEqual(v.month, 9)
            self.assertEqual(v.day, 21)
        # Assert NOW.
        for v in (db.date(), db.date(db.NOW)):
            # Read the clock once, so year, month and day come from one instant.
            now = datetime.datetime.now()
            self.assertEqual(v.year, now.year)
            self.assertEqual(v.month, now.month)
            self.assertEqual(v.day, now.day)
        self.assertEqual(db.date().year, db.YEAR)
        # Assert integer input.
        v1 = db.date(2010, 9, 21, format=db.DEFAULT_DATE_FORMAT)
        v2 = db.date(2010, 9, 21, 9, 27, 1, 0, db.DEFAULT_DATE_FORMAT)
        v3 = db.date(2010, 9, 21, hour=9, minute=27, second=1, format=db.DEFAULT_DATE_FORMAT)
        self.assertEqual(str(v1), "2010-09-21 00:00:00")
        self.assertEqual(str(v2), "2010-09-21 09:27:01")
        self.assertEqual(str(v3), "2010-09-21 09:27:01")
        # Assert week and weekday input
        v4 = db.date(2014, week=1, weekday=1, hour=12, format=db.DEFAULT_DATE_FORMAT)
        self.assertEqual(str(v4), "2013-12-30 12:00:00")
        # Assert Date input.
        v5 = db.date(db.date(2014, 1, 1))
        self.assertEqual(str(v5), "2014-01-01 00:00:00")
        # Assert timestamp input (round trip through Date.timestamp).
        v6 = db.date(db.date(2014, 1, 1).timestamp)
        self.assertEqual(str(v6), "2014-01-01 00:00:00")
        # Assert DateError for other input.
        self.assertRaises(db.DateError, db.date, None)
        print("pattern.db.date()")

    def test_format(self):
        # Assert custom input formats.
        v = db.date("2010-09", "%Y-%m")
        self.assertEqual(str(v), "2010-09-01 00:00:00")
        self.assertEqual(v.year, 2010)
        # Assert custom output formats.
        v = db.date("2010-09", "%Y-%m", format="%Y-%m")
        self.assertEqual(v.format, "%Y-%m")
        self.assertEqual(str(v), "2010-09")
        self.assertEqual(v.year, 2010)
        # Assert strftime() for date < 1900.
        v = db.date(1707, 4, 15)
        self.assertEqual(str(v), "1707-04-15 00:00:00")
        self.assertRaises(ValueError, lambda: v.timestamp)
        print("pattern.db.Date.__str__()")

    def test_timestamp(self):
        # Assert Date.timestamp.
        # The old hard-coded 1285020000 is midnight at UTC+2, so Date.timestamp
        # presumably uses local time; derive the expected value the same way
        # so the test passes in any timezone.
        import time
        v = db.date(2010, 9, 21, format=db.DEFAULT_DATE_FORMAT)
        expected = int(time.mktime(datetime.datetime(2010, 9, 21).timetuple()))
        self.assertEqual(v.timestamp, expected)
        print("pattern.db.Date.timestamp")

    def test_time(self):
        # Assert Date + time().
        v = db.date("2010-09-21 9:27:00")
        v = v - db.time(days=1, hours=1, minutes=1, seconds=1)
        self.assertEqual(str(v), "2010-09-20 08:25:59")
        # Assert Date + time(years, months); the day is clamped to the month's end.
        v = db.date(2014, 1, 31)
        v = v + db.time(years=1, months=1)
        self.assertEqual(str(v), "2015-02-28 00:00:00")
        print("pattern.db.time()")
#---------------------------------------------------------------------------------------------------
class TestUtilityFunctions(unittest.TestCase):
    """ Tests for the stand-alone helpers in pattern.db: string encryption,
        JSON, ordering, statistics and the SQLite user-defined functions.
    """

    def setUp(self):
        pass

    def test_encryption(self):
        # Encrypting with a key must change the string; decrypting with the
        # same key must give the original back.
        plain = "test"
        secret = db.encrypt_string(plain, key="1234")
        restored = db.decrypt_string(secret, key="1234")
        self.assertTrue(secret != "test")
        self.assertTrue(restored == "test")
        print("pattern.db.encrypt_string()")
        print("pattern.db.decrypt_string()")

    def test_json(self):
        # A round trip through db.json must give back an equal structure.
        original = ["a,b", 1, 1.0, True, False, None, [1, 2], {"a:b": 1.2, "a,b": True, "a": [1, {"2": 3}], "1": "None"}]
        round_trip = db.json.loads(db.json.dumps(original))
        self.assertEqual(original, round_trip)
        print("pattern.db.json.dumps()")
        print("pattern.db.json.loads()")

    def test_order(self):
        # db.order() returns the indices of the items in sorted order.
        values = [3, 1, 2]
        for options, expected in (
          ({}, [1, 2, 0]),
          ({"reverse": True}, [0, 2, 1]),
          ({"cmp": lambda a, b: a - b}, [1, 2, 0]),
          ({"key": lambda i: i}, [1, 2, 0])):
            self.assertEqual(db.order(values, **options), expected)
        print("pattern.db.order()")

    def test_avg(self):
        # The mean of 1, 2, 3, 4 is 10 / 4 = 2.5.
        self.assertEqual(db.avg([1, 2, 3, 4]), 2.5)
        print("pattern.db.avg()")

    def test_variance(self):
        # The sample variance of 1..5 is 10 / 4 = 2.5.
        self.assertEqual(db.variance([1, 2, 3, 4, 5]), 2.5)
        print("pattern.db.variance()")

    def test_stdev(self):
        # The sample standard deviation is about 2.429.
        self.assertAlmostEqual(db.stdev([1, 5, 6, 7, 6, 8]), 2.429, places=3)
        print("pattern.db.stdev()")

    def test_sqlite_functions(self):
        # Date-part helpers registered with SQLite parse a "Y-m-d H:M:S" string.
        stamp = "1707-04-15 01:02:03"
        for function, expected in (
          (db.sqlite_year, 1707),
          (db.sqlite_month, 4),
          (db.sqlite_day, 15),
          (db.sqlite_hour, 1),
          (db.sqlite_minute, 2),
          (db.sqlite_second, 3)):
            self.assertEqual(function(stamp), expected)
        # Aggregate classes follow the step() / finalize() protocol.
        for aggregate, values, expected in (
          (db.sqlite_first, [1, 2, 3], 1),
          (db.sqlite_last, [1, 2, 3], 3),
          (db.sqlite_group_concat, [1, 2, 3], "1,2,3")):
            instance = aggregate()
            for value in values:
                instance.step(value)
            self.assertEqual(instance.finalize(), expected)
        for name in ("year", "month", "day", "hour", "minute", "second",
                     "first", "last", "group_concat"):
            print("pattern.db.sqlite_%s()" % name)
#---------------------------------------------------------------------------------------------------
class _TestDatabase(object):
    """ Database tests shared by TestMySQLDatabase and TestSQLiteDatabase.
        The concrete subclasses set self.db and self.type before calling
        _TestDatabase.setUp().
    """

    def _drop_all_tables(self):
        # Copy the names first: drop() presumably changes the table list.
        for name in list(self.db):
            self.db.drop(name)

    def setUp(self):
        # Start every test from an empty database.
        self._drop_all_tables()

    def tearDown(self):
        self._drop_all_tables()

    def test_escape(self):
        # Literal values: str, int, float, bool and None.
        for value, expected in (
          ("a", "'a'"),
          (1, "1"),
          (int(1), "1"),
          (1.0, "1.0"),
          (True, "1"),
          (False, "0"),
          (None, "null")):
            self.assertEqual(db._escape(value), expected)
        # A Date is quoted in the default date format.
        self.assertEqual(db._escape(db.date("1707-04-15")), "'1707-04-15 00:00:00'")
        # The current_timestamp keyword is passed through unquoted.
        self.assertEqual(db._escape("current_timestamp"), "current_timestamp")
        # A Query is escaped as a parenthesized subquery.
        subquery = self.db.create("dummy", fields=[db.pk()]).query()
        self.assertEqual(db._escape(subquery), "(select dummy.* from `dummy`)")
        # Each engine escapes a single quote in its own way.
        if self.db.type == db.MYSQL:
            self.assertEqual(self.db.escape("'"), "'\\''")
        if self.db.type == db.SQLITE:
            self.assertEqual(self.db.escape("'"), "''''")
        print("pattern.db._escape()")

    def test_database(self):
        # Connection settings and an empty schema are exposed as attributes.
        expected_attributes = (
            ("type", self.type),
            ("name", "pattern_unittest_db"),
            ("host", HOST),
            ("port", PORT),
            ("username", USERNAME),
            ("password", PASSWORD),
            ("tables", {}),
            ("relations", []))
        for attribute, expected in expected_attributes:
            self.assertTrue(getattr(self.db, attribute) == expected)
        self.assertTrue(self.db.connected)
        # Disconnecting clears the connection; reconnect afterwards.
        self.db.disconnect()
        self.assertTrue(self.db.connected == False)
        self.assertTrue(self.db.connection is None)
        self.db.connect()
        print("pattern.db.Database(type=%s)" % self.type.upper())

    def test_create_table(self):
        products = self.db.create("products", fields=[
            db.primary_key("pid"),
            db.field("name", db.STRING, index=True, optional=False),
            db.field("price", db.FLOAT)
        ])
        # The last executed query is kept; after create() it is the engine's
        # schema lookup for the new table.
        if self.db.type == db.SQLITE:
            self.assertEqual(self.db.query, "pragma table_info(`products`);")
        if self.db.type == db.MYSQL:
            self.assertEqual(self.db.query, "show columns from `products`;")
        # The new Table is registered on the Database in every lookup style.
        self.assertTrue(isinstance(products, db.Table))
        self.assertTrue(len(self.db) == 1)
        self.assertTrue(products.pk == "pid")
        self.assertTrue(products.fields == ["pid", "name", "price"])
        self.assertTrue(self.db[products.name] == products)
        self.assertTrue(self.db.tables[products.name] == products)
        self.assertTrue(getattr(self.db, products.name) == products)
        # Database._field_SQL() is the per-field subroutine of Database.create();
        # each case gives the (column SQL, index SQL) for MySQL and for SQLite.
        cases = (
            (db.primary_key("pid"),
                ("`pid` integer not null primary key auto_increment", None),
                ("`pid` integer not null primary key autoincrement", None)),
            (db.field("name", db.STRING, index=True, optional=False),
                ("`name` varchar(100) not null", "create index `products_name` on `products` (`name`);"),
                ("`name` varchar(100) not null", "create index `products_name` on `products` (`name`);")),
            (db.field("price", db.INTEGER),
                ("`price` integer null", None),
                ("`price` integer null", None)))
        for field, mysql_sql, sqlite_sql in cases:
            if self.db.type == db.MYSQL:
                self.assertEqual(self.db._field_SQL(self.db["products"].name, field), mysql_sql)
            if self.db.type == db.SQLITE:
                self.assertEqual(self.db._field_SQL(self.db["products"].name, field), sqlite_sql)
        # Creating a table that already exists raises TableError.
        self.assertRaises(db.TableError, self.db.create, "products")
        # Dropping the table leaves the database empty again.
        self.db.drop("products")
        self.assertTrue(len(self.db) == 0)
        print("pattern.db.Database.create()")
class TestDeleteMySQLDatabase(unittest.TestCase):
    """ Removes the MySQL unit test database through Database._delete(). """

    def runTest(self):
        handle = create_db_mysql()
        handle._delete()
class TestDeleteSQLiteDatabase(unittest.TestCase):
    """ Removes the SQLite unit test database through Database._delete(). """

    def runTest(self):
        handle = create_db_sqlite()
        handle._delete()
class TestMySQLDatabase(unittest.TestCase, _TestDatabase):
    """ Runs the shared _TestDatabase tests against MySQL. """

    def setUp(self):
        self.db, self.type = create_db_mysql(), db.MYSQL
        _TestDatabase.setUp(self)

    def tearDown(self):
        # unittest.TestCase comes before the mixin in the bases, so its no-op
        # tearDown() would otherwise shadow _TestDatabase.tearDown() in the MRO.
        _TestDatabase.tearDown(self)
class TestSQLiteDatabase(unittest.TestCase, _TestDatabase):
    """ Runs the shared _TestDatabase tests against SQLite. """

    def setUp(self):
        self.db, self.type = create_db_sqlite(), db.SQLITE
        _TestDatabase.setUp(self)

    def tearDown(self):
        # unittest.TestCase comes before the mixin in the bases, so its no-op
        # tearDown() would otherwise shadow _TestDatabase.tearDown() in the MRO.
        _TestDatabase.tearDown(self)
#---------------------------------------------------------------------------------------------------
class TestSchema(unittest.TestCase):
    """ Tests for db._String, db.field(), db.primary_key() and db.Schema. """

    def setUp(self):
        pass

    def test_string(self):
        # _String() itself compares equal to "string"; calling it with a length
        # gives a varchar whose length is clamped to 1..255.
        plain = db._String()
        sized = [db._String()(length) for length in (0, 200, 300)]
        self.assertEqual(plain, "string")
        for value, expected in zip(sized, ("varchar(1)", "varchar(200)", "varchar(255)")):
            self.assertEqual(value, expected)

    def test_field(self):
        # Each case pairs field() keyword arguments with the expected tuple:
        #   (NAME, TYPE, DEFAULT, INDEX, OPTIONAL)
        cases = (
            (dict(name="id", type=db.INT), ("id", "integer", None, False, True)),
            (dict(name="id", type=db.INT, index=db.PRIMARY), ("id", "integer", None, "primary", True)),
            (dict(name="id", type=db.INT, index=db.UNIQUE), ("id", "integer", None, "unique", True)),
            (dict(name="id", type=db.INT, index="0"), ("id", "integer", None, False, True)),
            (dict(name="id", type=db.INT, index="1"), ("id", "integer", None, True, True)),
            (dict(name="id", type=db.INT, index=True), ("id", "integer", None, True, True)),
            (dict(name="id", type=db.INT, default=0), ("id", "integer", 0, False, True)),
            (dict(name="name", type=db.STRING), ("name", "varchar(100)", None, False, True)),
            (dict(name="name", type=db.STRING, optional=False), ("name", "varchar(100)", None, False, False)),
            (dict(name="name", type=db.STRING, optional="0"), ("name", "varchar(100)", None, False, False)),
            (dict(name="name", type=db.STRING(50)), ("name", "varchar(50)", None, False, True)),
            (dict(name="price", type=db.FLOAT, default=0), ("price", "real", 0, False, True)),
            (dict(name="show", type=db.BOOL), ("show", "tinyint(1)", None, False, True)),
            (dict(name="show", type=db.BOOL, default=True), ("show", "tinyint(1)", True, False, True)),
            (dict(name="show", type=db.BOOL, default=False), ("show", "tinyint(1)", False, False, True)),
            (dict(name="date", type=db.DATE), ("date", "timestamp", "now", False, True)),
            (dict(name="date", type=db.DATE, default=db.NOW), ("date", "timestamp", "now", False, True)),
            (dict(name="date", type=db.DATE, default="1999-12-31 23:59:59"),
                ("date", "timestamp", "1999-12-31 23:59:59", False, True)))
        for options, expected in cases:
            self.assertEqual(db.field(**options), expected)
        # primary_key() and pk() both return the same non-optional "id" field.
        self.assertTrue(db.primary_key() == db.pk() == ("id", "integer", None, "primary", False))
        print("pattern.db.field()")

    def test_schema(self):
        # Schema normalizes engine-specific column descriptions.
        # Each case: Schema(*args) => (NAME, TYPE, DEFAULT, INDEX, OPTIONAL, LENGTH)
        now_lower = "current_timestamp"
        now_quoted = "'CURRENT_TIMESTAMP'"
        attributes = ("name", "type", "default", "index", "optional", "length")
        for args, expected in (
          (("id", "integer", None, "pri", False), ("id", db.INT, None, db.PRIMARY, False, None)),
          (("id", "integer", None, "uni", False), ("id", db.INT, None, db.UNIQUE, False, None)),
          (("id", "int", None, "yes", True), ("id", db.INT, None, True, True, None)),
          (("id", "real", None, "mul", True), ("id", db.FLOAT, None, True, True, None)),
          (("id", "real", None, "1", True), ("id", db.FLOAT, None, True, True, None)),
          (("id", "double", None, "0", True), ("id", db.FLOAT, None, False, True, None)),
          (("id", "double", 0, False, False), ("id", db.FLOAT, 0, False, False, None)),
          (("text", "varchar(10)", "?", False, True), ("text", db.STRING, "?", False, True, 10)),
          (("text", "char(20)", "", False, True), ("text", db.STRING, None, False, True, 20)),
          (("text", "text", None, False, True), ("text", db.TEXT, None, False, True, None)),
          (("text", "blob", None, False, True), ("text", db.BLOB, None, False, True, None)),
          (("show", "tinyint(1)", None, False, True), ("show", db.BOOL, None, False, True, None)),
          (("date", "timestamp", None, False, True), ("date", db.DATE, None, False, True, None)),
          (("date", "timestamp", now_lower, False, True), ("date", db.DATE, db.NOW, False, True, None)),
          (("date", "time", now_quoted, False, "YES"), ("date", db.DATE, db.NOW, False, True, None))):
            schema = db.Schema(*args)
            for attribute, value in zip(attributes, expected):
                self.assertEqual(getattr(schema, attribute), value)
        print("pattern.db.Schema()")
#---------------------------------------------------------------------------------------------------
class _TestTable(object):
    """ Table tests shared by TestMySQLTable and TestSQLiteTable.

        The concrete subclasses set self.db and then call _TestTable.setUp().
        NOTE(review): those subclasses list unittest.TestCase before this mixin
        in their bases, so TestCase.tearDown() precedes _TestTable.tearDown() in
        the MRO; the tearDown() below only runs if a subclass delegates to it.
    """

    def setUp(self):
        """ Drops every table in self.db, then creates the empty test tables
            persons, products and orders.
        """
        # Delete all tables first
        for table in list(self.db):
            self.db.drop(table)
        # Create test tables.
        self.db.create("persons", fields=[
            db.primary_key("id"),
            db.field("name", db.STRING)
        ])
        self.db.create("products", fields=[
            db.primary_key("id"),
            db.field("name", db.STRING),
            db.field("price", db.FLOAT, default=0.0)
        ])
        # orders.person / orders.product store ids returned by persons / products inserts.
        self.db.create("orders", fields=[
            db.primary_key("id"),
            db.field("person", db.INTEGER, index=True),
            db.field("product", db.INTEGER, index=True),
        ])

    def tearDown(self):
        # Drop test tables.
        for table in list(self.db):
            self.db.drop(table)

    def test_table(self):
        # Assert Table properties.
        v = self.db.persons
        self.assertTrue(v.db == self.db)
        self.assertTrue(v.pk == "id")
        self.assertTrue(v.fields == ["id", "name"])
        self.assertTrue(v.name == "persons")
        self.assertTrue(v.abs("name") == "persons.name")
        self.assertTrue(v.rows() == [])
        self.assertTrue(v.schema["id"].type == db.INTEGER)
        self.assertTrue(v.schema["id"].index == db.PRIMARY)
        print("pattern.db.Table")

    def test_rename(self):
        # Assert ALTER TABLE when name changes.
        # Assigning Table.name issues the rename and re-registers the table under the new name.
        v = self.db.persons
        v.name = "clients"
        self.assertEqual(self.db.query, "alter table `persons` rename to `clients`;")
        self.assertEqual(self.db.tables.get("clients"), v)
        print("pattern.db.Table.name")

    def test_fields(self):
        # Assert ALTER TABLE when column is inserted.
        # Appending a field() tuple to Table.fields adds the column; fields then lists names only.
        v = self.db.products
        v.fields.append(db.field("description", db.TEXT))
        self.assertEqual(v.fields, ["id", "name", "price", "description"])
        print("pattern.db.Table.fields")

    def test_insert_update_delete(self):
        # Assert Table.insert().
        # insert() accepts keyword arguments or a dict and returns the new row id;
        # ids are counted per table (products starts again at 1).
        v1 = self.db.persons.insert(name="Kurt Gödel")
        v2 = self.db.products.insert(name="pizza", price=10.0)
        v3 = self.db.products.insert({"name": "garlic bread", "price": 3.0})
        v4 = self.db.orders.insert(person=v1, product=v3)
        self.assertEqual(v1, 1)
        self.assertEqual(v2, 1)
        self.assertEqual(v3, 2)
        self.assertEqual(v4, 1)
        self.assertEqual(self.db.persons.rows(), [(1, "Kurt Gödel")])
        self.assertEqual(self.db.products.rows(), [(1, "pizza", 10.0), (2, "garlic bread", 3.0)])
        self.assertEqual(self.db.orders.rows(), [(1, 1, 2)])
        self.assertEqual(self.db.orders.count(), 1)
        # Strip the extra="auto_increment" attribute (presumably emitted only by
        # MySQL) so one expected XML string serves both engines.
        self.assertEqual(self.db.products.xml.replace(' extra="auto_increment"', ""),
            '<?xml version="1.0" encoding="utf-8"?>\n'
            '<table name="products" fields="id, name, price" count="2">\n'
            '\t<schema>\n'
            '\t\t<field name="id" type="integer" index="primary" optional="no" />\n'
            '\t\t<field name="name" type="string" length="100" />\n'
            '\t\t<field name="price" type="float" default="0.0" />\n'
            '\t</schema>\n'
            '\t<rows>\n'
            '\t\t<row id="1" name="pizza" price="10.0" />\n'
            '\t\t<row id="2" name="garlic bread" price="3.0" />\n'
            '\t</rows>\n'
            '</table>'
        )
        # Assert transactions with commit=False.
        # Checked on SQLite only: an uncommitted insert is undone by rollback().
        if self.db.type == db.SQLITE:
            self.db.orders.insert(person=v1, product=v2, commit=False)
            self.db.rollback()
            self.assertEqual(len(self.db.orders), 1)
        # On both engines: a second order (uncommitted here). It is still present
        # after orders.delete(1) below, which is why one order remains at the end.
        self.db.orders.insert(person=v1, product=v2, commit=False)
        # Assert Table.update().
        # update() accepts a row id or a filter; "pi*" is a wildcard pattern.
        self.db.products.update(2, price=4.0)
        self.db.products.update(2, {"price": 4.5})
        self.db.products.update(db.all(db.filter("name", "pi*")), name="deeppan pizza")
        self.assertEqual(self.db.products.rows(), [(1, "deeppan pizza", 10.0), (2, "garlic bread", 4.5)])
        # Assert Table.delete().
        # delete() accepts a filter, db.ALL or a row id.
        self.db.products.delete(db.all(db.filter("name", "deeppan*")))
        self.db.products.delete(db.ALL)
        self.db.orders.delete(1)
        self.assertEqual(len(self.db.products), 0)
        self.assertEqual(len(self.db.orders), 1)
        print("pattern.db.Table.insert()")
        print("pattern.db.Table.update()")
        print("pattern.db.Table.delete()")

    def test_filter(self):
        # Assert Table.filter().
        # filter(fields, **conditions): a tuple value matches any of its items,
        # and "*" is a wildcard in string values.
        self.db.persons.insert(name="Kurt Gödel")
        self.db.persons.insert(name="M. C. Escher")
        self.db.persons.insert(name="Johann Sebastian Bach")
        f = self.db.persons.filter
        self.assertEqual(f(("name",), id=1), [("Kurt Gödel",)])
        self.assertEqual(f(db.ALL, id=(1, 2)), [(1, "Kurt Gödel"), (2, "M. C. Escher")])
        self.assertEqual(f({"id": (1, 2)}), [(1, "Kurt Gödel"), (2, "M. C. Escher")])
        self.assertEqual(f("id", name="Johan*"), [(3,)])
        self.assertEqual(f("id", name=("J*", "K*")), [(1,), (3,)])
        print("pattern.db.Table.filter()")

    def test_search(self):
        # Assert Table.search => Query object.
        v = self.db.persons.search()
        self.assertTrue(isinstance(v, db.Query))
        self.assertTrue(v.table == self.db.persons)

    def test_datasheet(self):
        # Assert Table.datasheet() => Datasheet object.
        # Datasheet.fields holds (name, type) pairs.
        v = self.db.persons.datasheet()
        self.assertTrue(isinstance(v, db.Datasheet))
        self.assertTrue(v.fields[0] == ("id", db.INTEGER))
        print("pattern.db.Table.datasheet()")
class TestMySQLTable(unittest.TestCase, _TestTable):
    """ Runs the shared _TestTable tests against MySQL. """

    def setUp(self):
        self.db = create_db_mysql()
        _TestTable.setUp(self)

    def tearDown(self):
        # unittest.TestCase comes before the mixin in the bases, so its no-op
        # tearDown() would otherwise shadow _TestTable.tearDown() in the MRO.
        _TestTable.tearDown(self)
class TestSQLiteTable(unittest.TestCase, _TestTable):
    """ Runs the shared _TestTable tests against SQLite. """

    def setUp(self):
        # Use the helper (like every other SQLite test case), not the bare
        # DB_SQLITE global: that is None when this class runs on its own, and
        # may have lost its connection after TestDeleteSQLiteDatabase.
        self.db = create_db_sqlite()
        _TestTable.setUp(self)

    def tearDown(self):
        # unittest.TestCase comes before the mixin in the bases, so its no-op
        # tearDown() would otherwise shadow _TestTable.tearDown() in the MRO.
        _TestTable.tearDown(self)
#---------------------------------------------------------------------------------------------------
class _TestQuery(object):
    """ Query tests shared by TestMySQLQuery and TestSQLiteQuery.

        The concrete subclasses set self.db and then call _TestQuery.setUp().
        NOTE(review): those subclasses list unittest.TestCase before this mixin
        in their bases, so TestCase.tearDown() precedes _TestQuery.tearDown() in
        the MRO; the tearDown() below only runs if a subclass delegates to it.
    """

    def setUp(self):
        """ Drops every table in self.db, then creates and fills the persons
            and gender tables used by the query tests.
        """
        # Delete all tables first
        for table in list(self.db):
            self.db.drop(table)
        # Create test tables.
        self.db.create("persons", fields=[
            db.primary_key("id"),
            db.field("name", db.STRING),
            db.field("age", db.INTEGER),
            db.field("gender", db.INTEGER)
        ])
        self.db.create("gender", fields=[
            db.primary_key("id"),
            db.field("name", db.STRING)
        ])
        # Create test data.
        # persons.gender refers to gender.id; given the insert order below,
        # 1 = "female" and 2 = "male".
        self.db.persons.insert(name="john", age="30", gender=2)
        self.db.persons.insert(name="jack", age="20", gender=2)
        self.db.persons.insert(name="jane", age="30", gender=1)
        self.db.gender.insert(name="female")
        self.db.gender.insert(name="male")

    def tearDown(self):
        # Drop test tables.
        for table in list(self.db):
            self.db.drop(table)

    def _query(self, *args, **kwargs):
        """ Returns a pattern.db.Query object on a mock Table and Database.
        """
        # Minimal stand-ins that provide only escape(), relations, name, fields
        # and db -- presumably all that Query needs to build its SQL here.
        class Database(object):
            escape, relations = lambda self, v: db._escape(v), []
        class Table(object):
            name, fields, db = "persons", ["id", "name", "age", "sex"], Database()
        return db.Query(Table(), *args, **kwargs)

    def test_abs(self):
        # Assert absolute fieldname for trivial cases.
        self.assertEqual(db.abs("persons", "name"), "persons.name")
        self.assertEqual(db.abs("persons", ("id", "name")), ["persons.id", "persons.name"])
        # Assert absolute fieldname with SQL functions (e.g., avg(product.price)).
        for f in db.sql_functions.split("|"):
            self.assertEqual(db.abs("persons", "%s(name)" % f), "%s(persons.name)" % f)
        print("pattern.db.abs()")

    def test_cmp(self):
        # Assert WHERE-clause from cmp() function.
        # db.cmp(field, value, operator): "*" in a string value becomes the SQL
        # "%" wildcard, tuples become IN / BETWEEN, None becomes IS NULL, and a
        # Query becomes a subquery.
        q = self.db.persons.search(fields=["name"])
        self.assertTrue(isinstance(q, db.Query))
        for args, sql in (
          (("name", "Kurt%", db.LIKE), "name like 'Kurt%'"),
          (("name", "Kurt*", "="), "name like 'Kurt%'"),
          (("name", "*Gödel", "=="), "name like '%Gödel'"),
          (("name", "Kurt*", "!="), "name not like 'Kurt%'"),
          (("name", "Kurt*", "<>"), "name not like 'Kurt%'"),
          (("name", "Gödel", "i="), "name like 'Gödel'"), # case-insensitive search
          (("id", (1, 2), db.IN), "id in (1,2)"),
          (("id", (1, 2), "="), "id in (1,2)"),
          (("id", (1, 2), "=="), "id in (1,2)"),
          (("id", (1, 2), "!="), "id not in (1,2)"),
          (("id", (1, 2), "<>"), "id not in (1,2)"),
          (("id", (1, 3), db.BETWEEN), "id between 1 and 3"),
          (("id", (1, 3), ":"), "id between 1 and 3"),
          (("name", ("G", "K*"), "="), "(name='G' or name like 'K%')"),
          (("name", None, "="), "name is null"),
          (("name", None, "=="), "name is null"),
          (("name", None, "!="), "name is not null"),
          (("name", None, "<>"), "name is not null"),
          (("name", q, "="), "name in (select persons.name from `persons`)"),
          (("name", q, "=="), "name in (select persons.name from `persons`)"),
          (("name", q, "!="), "name not in (select persons.name from `persons`)"),
          (("name", q, "<>"), "name not in (select persons.name from `persons`)"),
          (("name", "Gödel", "="), "name='Gödel'"),
          (("id", 1, ">"), "id>1")):
            self.assertEqual(db.cmp(*args), sql)
        print("pattern.db.cmp()")

    def test_filterchain(self):
        # Assert WHERE with AND/OR combinations from FilterChain object().
        # Nested chains are wrapped in parentheses in the generated SQL.
        yesterday = db.date()
        yesterday -= db.time(days=1)
        f1 = db.FilterChain(("name", "garlic bread"))
        f2 = db.FilterChain(("name", "pizza"), ("price", 10, "<"), operator=db.AND)
        f3 = db.FilterChain(f1, f2, operator=db.OR)
        f4 = db.FilterChain(f3, ("date", yesterday, ">"), operator=db.AND)
        self.assertEqual(f1.SQL(), "name='garlic bread'")
        self.assertEqual(f2.SQL(), "name='pizza' and price<10")
        self.assertEqual(f3.SQL(), "(name='garlic bread') or (name='pizza' and price<10)")
        self.assertEqual(f4.SQL(), "((name='garlic bread') or (name='pizza' and price<10)) and date>'%s'" % yesterday)
        # Assert subquery in filter chain.
        q = self._query(fields=["name"])
        f = db.any(("name", "Gödel"), ("name", q))
        self.assertEqual(f.SQL(), "name='Gödel' or name in (select persons.name from `persons`)")
        print("pattern.db.FilterChain")

    def test_query(self):
        # Assert table query results from Table.search().
        # Each case: Table.search() keyword arguments, the expected SQL and the expected rows.
        for kwargs, sql, rows in (
          (dict(fields=db.ALL),
            "select persons.* from `persons`;",
            [(1, "john", 30, 2),
             (2, "jack", 20, 2),
             (3, "jane", 30, 1)]),
          (dict(fields=db.ALL, range=(0, 2)),
            "select persons.* from `persons` limit 0, 2;",
            [(1, "john", 30, 2),
             (2, "jack", 20, 2)]),
          (dict(fields=db.ALL, filters=[("age", 30, "<")]),
            "select persons.* from `persons` where persons.age<30;",
            [(2, "jack", 20, 2)]),
          (dict(fields=db.ALL, filters=db.any(("age", 30, "<"), ("name", "john"))),
            "select persons.* from `persons` where persons.age<30 or persons.name='john';",
            [(1, "john", 30, 2),
             (2, "jack", 20, 2)]),
          (dict(fields=["name", "gender.name"], relations=[db.relation("gender", "id", "gender")]),
            "select persons.name, gender.name from `persons` left join `gender` on persons.gender=gender.id;",
            [("john", "male"),
             ("jack", "male"),
             ("jane", "female")]),
          (dict(fields=["name", "age"], sort="name"),
            "select persons.name, persons.age from `persons` order by persons.name asc;",
            [("jack", 20),
             ("jane", 30),
             ("john", 30)]),
          (dict(fields=["name", "age"], sort=1, order=db.DESCENDING),
            "select persons.name, persons.age from `persons` order by persons.name desc;",
            [("john", 30),
             ("jane", 30),
             ("jack", 20)]),
          (dict(fields=["age", "name"], sort=["age", "name"], order=[db.ASCENDING, db.DESCENDING]),
            "select persons.age, persons.name from `persons` order by persons.age asc, persons.name desc;",
            [(20, "jack"),
             (30, "john"),
             (30, "jane")]),
          (dict(fields=["age", "name"], group="age", function=db.CONCATENATE),
            "select persons.age, group_concat(persons.name) from `persons` group by persons.age;",
            [(20, "jack"),
             (30, "john,jane")]),
          (dict(fields=["id", "name", "age"], group="age", function=[db.COUNT, db.CONCATENATE]),
            "select count(persons.id), group_concat(persons.name), persons.age from `persons` group by persons.age;",
            [(1, "jack", 20),
             (2, "john,jane", 30)])):
            v = self.db.persons.search(**kwargs)
            # Evaluated only to exercise the XML export for every case (result unused).
            v.xml
            self.assertEqual(v.SQL(), sql)
            self.assertEqual(v.rows(), rows)
        # Assert Database.link() permanent relations.
        # The query passes no relations argument; the join comes from the
        # relation registered with db.link(), even though it is linked after search().
        v = self.db.persons.search(fields=["name", "gender.name"])
        v.aliases["gender.name"] = "gender"
        self.db.link("persons", "gender", "gender", "id", join=db.LEFT)
        self.assertEqual(v.SQL(),
            "select persons.name, gender.name as gender from `persons` left join `gender` on persons.gender=gender.id;")
        self.assertEqual(v.rows(),
            [('john', 'male'),
             ('jack', 'male'),
             ('jane', 'female')])
        print("pattern.db.Table.search()")
        print("pattern.db.Table.Query")

    def test_xml(self):
        # Assert Query.xml dump.
        # Aliased fields appear under their alias in the XML schema and rows.
        v = self.db.persons.search(fields=["name", "gender.name"])
        v.aliases["gender.name"] = "gender"
        self.db.link("persons", "gender", "gender", "id", join=db.LEFT)
        self.assertEqual(v.xml,
            '<?xml version="1.0" encoding="utf-8"?>\n'
            '<query table="persons" fields="name, gender" count="3">\n'
            '\t<schema>\n'
            '\t\t<field name="name" type="string" length="100" />\n'
            '\t\t<field name="gender" type="string" length="100" />\n'
            '\t</schema>\n'
            '\t<rows>\n'
            '\t\t<row name="john" gender="male" />\n'
            '\t\t<row name="jack" gender="male" />\n'
            '\t\t<row name="jane" gender="female" />\n'
            '\t</rows>\n'
            '</query>'
        )
        # Assert Database.create() from XML.
        # Without name=, the table name comes from the XML ("persons"), which already exists.
        self.assertRaises(db.TableError, self.db.create, v.xml) # table 'persons' already exists
        self.db.create(v.xml, name="persons2")
        self.assertTrue("persons2" in self.db)
        self.assertTrue(self.db.persons2.fields == ["name", "gender"])
        self.assertTrue(len(self.db.persons2) == 3)
        print("pattern.db.Query.xml")
class TestMySQLQuery(unittest.TestCase, _TestQuery):
    """ Runs the shared _TestQuery tests against MySQL. """

    def setUp(self):
        self.db = create_db_mysql()
        _TestQuery.setUp(self)

    def tearDown(self):
        # unittest.TestCase comes before the mixin in the bases, so its no-op
        # tearDown() would otherwise shadow _TestQuery.tearDown() in the MRO.
        _TestQuery.tearDown(self)
class TestSQLiteQuery(unittest.TestCase, _TestQuery):
    """ Runs the shared _TestQuery tests against SQLite. """

    def setUp(self):
        self.db = create_db_sqlite()
        _TestQuery.setUp(self)

    def tearDown(self):
        # unittest.TestCase comes before the mixin in the bases, so its no-op
        # tearDown() would otherwise shadow _TestQuery.tearDown() in the MRO.
        _TestQuery.tearDown(self)
#---------------------------------------------------------------------------------------------------
class _TestView(object):
    """ View tests shared by TestMySQLView and TestSQLiteView.
        The concrete subclasses set self.db and then call _TestView.setUp().
    """

    def setUp(self):
        # Drop all tables first, as the other database test mixins do, so the
        # View below starts from an empty database: the rendered output is
        # expected to contain only the row that Products.__init__() inserts.
        for table in list(self.db):
            self.db.drop(table)

    def tearDown(self):
        # Drop test tables.
        for table in list(self.db):
            self.db.drop(table)

    def test_view(self):
        class Products(db.View):
            def __init__(self, database):
                db.View.__init__(self, database, "products", schema=[
                    db.pk(),
                    db.field("name", db.STRING),
                    db.field("price", db.FLOAT)
                ])
                # Create the backing table and seed it with one product.
                self.setup()
                self.table.insert(name="pizza", price=15.0)

            def render(self, query, **kwargs):
                # Render products whose name contains the query as an HTML table.
                q = self.table.search(fields=["name", "price"], filters=[("name", "*%s*" % query)])
                s = []
                for row in q.rows():
                    s.append("<tr>%s</tr>" % "".join(
                        ["<td class=\"%s\">%s</td>" % f for f in zip(q.fields, row)]))
                return "<table>" + "".join(s) + "</table>"
        # Assert View with automatic Table creation.
        v = Products(self.db)
        self.assertEqual(v.render("iz"),
            "<table>"
            "<tr>"
            "<td class=\"name\">pizza</td>"
            "<td class=\"price\">15.0</td>"
            "</tr>"
            "</table>"
        )
        print("pattern.db.View")
class TestMySQLView(unittest.TestCase, _TestView):
    """ Runs the shared _TestView tests against MySQL. """

    def setUp(self):
        self.db = create_db_mysql()
        _TestView.setUp(self)

    def tearDown(self):
        # unittest.TestCase comes before the mixin in the bases, so its no-op
        # tearDown() would otherwise shadow _TestView.tearDown() in the MRO.
        _TestView.tearDown(self)
class TestSQLiteView(unittest.TestCase, _TestView):
    """ Runs the shared _TestView tests against SQLite. """

    def setUp(self):
        self.db = create_db_sqlite()
        _TestView.setUp(self)

    def tearDown(self):
        # unittest.TestCase comes before the mixin in the bases, so its no-op
        # tearDown() would otherwise shadow _TestView.tearDown() in the MRO.
        _TestView.tearDown(self)
#---------------------------------------------------------------------------------------------------
class TestCSV(unittest.TestCase):
    """ Tests for db.CSV save/load and the "name (TYPE)" header format. """

    def setUp(self):
        # Create test table.
        self.csv = db.CSV(
            rows=[
                ["Schrödinger", "cat", True, 3, db.date(2009, 11, 3)],
                ["Hofstadter", "labrador", True, 5, db.date(2007, 8, 4)]
            ],
            fields=[
                ["name", db.STRING],
                ["type", db.STRING],
                ["tail", db.BOOLEAN],
                ["age", db.INTEGER],
                ["date", db.DATE],
            ])

    def test_csv_header(self):
        # Assert field headers parser.
        v1 = db.csv_header_encode("age", db.INTEGER)
        v2 = db.csv_header_decode("age (INTEGER)")
        self.assertEqual(v1, "age (INTEGER)")
        self.assertEqual(v2, ("age", db.INTEGER))
        print("pattern.db.csv_header_encode()")
        print("pattern.db.csv_header_decode()")

    def test_csv(self):
        # Assert saving and loading data (field types are preserved).
        v = self.csv
        v.save("test.csv", headers=True)
        # Register the cleanup as soon as the file exists, so it is removed
        # even when an assertion below fails.
        self.addCleanup(os.unlink, "test.csv")
        v = db.CSV.load("test.csv", headers=True)
        self.assertTrue(isinstance(v, list))
        self.assertTrue(v.headers[0] == ("name", db.STRING))
        self.assertTrue(v[0] == ["Schrödinger", "cat", True, 3, db.date(2009, 11, 3)])
        print("pattern.db.CSV")
        print("pattern.db.CSV.save()")
        print("pattern.db.CSV.load()")

    def test_file(self):
        # Assert CSV file contents.
        v = self.csv
        v.save("test.csv", headers=True)
        self.addCleanup(os.unlink, "test.csv")
        # Read the raw bytes and close the handle right away.
        with open("test.csv", "rb") as f:
            v = f.read()
        # Drop a leading UTF-8 BOM and normalize line endings before comparing.
        v = db.decode_utf8(v.lstrip(codecs.BOM_UTF8))
        v = v.replace("\r\n", "\n")
        self.assertEqual(v,
            '"name (STRING)","type (STRING)","tail (BOOLEAN)","age (INTEGER)","date (DATE)"\n'
            '"Schrödinger","cat","True","3","2009-11-03 00:00:00"\n'
            '"Hofstadter","labrador","True","5","2007-08-04 00:00:00"'
        )
#---------------------------------------------------------------------------------------------------
class TestDatasheet(unittest.TestCase):
    """Tests for pattern.db.Datasheet and its rows/columns views and helpers."""

    def test_rows(self):
        # Assert Datasheet.rows DatasheetRows object.
        v = db.Datasheet(rows=[[1, 2], [3, 4]])
        v.rows += [5, 6]
        v.rows[0] = [0, 0]
        v.rows.swap(0, 1)
        v.rows.insert(1, [1, 1])
        v.rows.pop(1)
        self.assertTrue(isinstance(v.rows, db.DatasheetRows))
        self.assertEqual(v.rows, [[3, 4], [0, 0], [5, 6]])
        self.assertEqual(v.rows[0], [3, 4])
        self.assertEqual(v.rows[-1], [5, 6])
        self.assertEqual(v.rows.count([3, 4]), 1)
        self.assertEqual(v.rows.index([3, 4]), 0)
        self.assertEqual(sorted(v.rows, reverse=True), [[5, 6], [3, 4], [0, 0]])
        self.assertRaises(AttributeError, v._set_rows, [])
        # Assert default for new rows with missing columns.
        v.rows.extend([[7], [9]], default=0)
        self.assertEqual(v.rows, [[3, 4], [0, 0], [5, 6], [7, 0], [9, 0]])
        print("pattern.db.Datasheet.rows")

    def test_columns(self):
        # Assert Datasheet.columns DatasheetColumns object.
        v = db.Datasheet(rows=[[1, 3], [2, 4]])
        v.columns += [5, 6]
        v.columns[0] = [0, 0]
        v.columns.swap(0, 1)
        v.columns.insert(1, [1, 1])
        v.columns.pop(1)
        self.assertTrue(isinstance(v.columns, db.DatasheetColumns))
        self.assertEqual(v.columns, [[3, 4], [0, 0], [5, 6]])
        self.assertEqual(v.columns[0], [3, 4])
        self.assertEqual(v.columns[-1], [5, 6])
        self.assertEqual(v.columns.count([3, 4]), 1)
        self.assertEqual(v.columns.index([3, 4]), 0)
        self.assertEqual(sorted(v.columns, reverse=True), [[5, 6], [3, 4], [0, 0]])
        self.assertRaises(AttributeError, v._set_columns, [])
        # Assert default for new columns with missing rows.
        v.columns.extend([[7], [9]], default=0)
        self.assertEqual(v.columns, [[3, 4], [0, 0], [5, 6], [7, 0], [9, 0]])
        print("pattern.db.Datasheet.columns")

    def test_column(self):
        # Assert DatasheetColumn object.
        # It has a reference to the parent Datasheet, as long as it is not deleted from the datasheet.
        v = db.Datasheet(rows=[[1, 3], [2, 4]])
        column = v.columns[0]
        column.insert(1, 0, default=None)
        self.assertEqual(v, [[1, 3], [0, None], [2, 4]])
        del v.columns[0]
        # The original assertTrue(column._datasheet, None) passed None as the
        # failure *message* and only checked truthiness. What must hold is that
        # the deleted column no longer references the parent datasheet.
        self.assertIsNot(column._datasheet, v)
        print("pattern.db.DatasheetColumn")

    def test_fields(self):
        # Assert Datasheet with incomplete headers.
        v = db.Datasheet(rows=[["Schrödinger", "cat"]], fields=[("name", db.STRING)])
        self.assertEqual(v.fields, [("name", db.STRING)])
        # Assert (None, None) for missing headers.
        v.columns.swap(0, 1)
        self.assertEqual(v.fields, [(None, None), ("name", db.STRING)])
        v.columns[0] = ["dog"]
        self.assertEqual(v.fields, [(None, None), ("name", db.STRING)])
        # Assert removing a column removes the header.
        v.columns.pop(0)
        self.assertEqual(v.fields, [("name", db.STRING)])
        # Assert new columns with header description.
        v.columns.append(["cat"])
        v.columns.append([3], field=("age", db.INTEGER))
        self.assertEqual(v.fields, [("name", db.STRING), (None, None), ("age", db.INTEGER)])
        # Assert column by name.
        self.assertEqual(v.name, ["Schrödinger"])
        print("pattern.db.Datasheet.fields")

    def test_group(self):
        # Assert Datasheet.group().
        v1 = db.Datasheet(rows=[[1, 2, "a"], [1, 3, "b"], [1, 4, "c"], [0, 0, "d"]])
        v2 = v1.group(0)
        v3 = v1.group(0, function=db.LAST)
        v4 = v1.group(0, function=(db.FIRST, db.COUNT, db.CONCATENATE))
        v5 = v1.group(0, function=db.CONCATENATE, key=lambda j: j > 0)
        self.assertEqual(v2, [[1, 2, "a"], [0, 0, "d"]])
        self.assertEqual(v3, [[1, 4, "c"], [0, 0, "d"]])
        self.assertEqual(v4, [[1, 3, "a,b,c"], [0, 1, "d"]])
        self.assertEqual(v5, [[True, "2,3,4", "a,b,c"], [False, "0", "d"]])
        print("pattern.db.Datasheet.group()")

    def test_slice(self):
        # Assert Datasheet slices.
        v = db.Datasheet([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        v = v.copy()
        self.assertEqual(v.slice(0, 1, 3, 2), [[2, 3], [5, 6], [8, 9]])
        self.assertEqual(v[2], [7, 8, 9])
        self.assertEqual(v[2, 2], 9)
        self.assertEqual(v[2, 1:], [8, 9])
        self.assertEqual(v[0:2], [[1, 2, 3], [4, 5, 6]])
        self.assertEqual(v[0:2, 1], [2, 5])
        self.assertEqual(v[0:2, 0:2], [[1, 2], [4, 5]])
        # Assert new Datasheet for i:j slices.
        self.assertTrue(isinstance(v[0:2], db.Datasheet))
        self.assertTrue(isinstance(v[0:2, 0:2], db.Datasheet))
        print("pattern.db.Datasheet.slice()")

    def test_copy(self):
        # Assert Datasheet.copy().
        # assertEqual, not assertTrue: assertTrue(x, y) treats y as the failure
        # message, so the expected values were never compared.
        v = db.Datasheet([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        self.assertEqual(v.copy(), [[1, 2, 3], [4, 5, 6], [7, 8, 9]])
        self.assertEqual(v.copy(rows=[0]), [[1, 2, 3]])
        self.assertEqual(v.copy(rows=[0], columns=[0]), [[1]])
        self.assertEqual(v.copy(columns=[0]), [[1], [4], [7]])
        print("pattern.db.Datasheet.copy()")

    def test_map(self):
        # Assert Datasheet.map() (in-place).
        v = db.Datasheet(rows=[[1, 2], [3, 4]])
        v.map(lambda x: x + 1)
        self.assertEqual(v, [[2, 3], [4, 5]])
        print("pattern.db.Datasheet.map()")

    def test_json(self):
        # Assert JSON output.
        v = db.Datasheet(rows=[["Schrödinger", 3], ["Hofstadter", 5]])
        self.assertEqual(v.json, '[["Schrödinger", 3], ["Hofstadter", 5]]')
        # Assert JSON output with headers.
        v = db.Datasheet(rows=[["Schrödinger", 3], ["Hofstadter", 5]],
                         fields=[("name", db.STRING), ("age", db.INT)])
        random.seed(0)
        w = db.json.loads(v.json)
        self.assertTrue({"age": 3, "name": "Schrödinger"} in w)
        self.assertTrue({"age": 5, "name": "Hofstadter"} in w)
        print("pattern.db.Datasheet.json")

    def test_flip(self):
        # Assert flip matrix.
        v = db.flip(db.Datasheet([[1, 2], [3, 4]]))
        self.assertEqual(v, [[1, 3], [2, 4]])
        print("pattern.db.flip()")

    def test_truncate(self):
        # Assert string truncate().
        v1 = "a" * 50
        v2 = "a" * 150
        v3 = "aaa " * 50
        self.assertEqual(db.truncate(v1), (v1, ""))
        self.assertEqual(db.truncate(v2), ("a" * 99 + "-", "a" * 51))
        self.assertEqual(db.truncate(v3), (("aaa " * 25).strip(), "aaa " * 25))
        print("pattern.db.truncate()")

    def test_pprint(self):
        # Placeholder: Datasheet pretty-printing has no assertions yet.
        pass
#---------------------------------------------------------------------------------------------------
def suite(**kwargs):
    """Collect every pattern.db test case into one TestSuite.

    Cases are added in a fixed order. The backend-independent cases come
    first, then the MySQL cases, then the SQLite cases, then CSV and
    Datasheet. Within each backend, the TestDelete*Database case is added
    last. *kwargs* is accepted but not used.
    """
    ordered_cases = (
        TestUnicode,
        TestEntities,
        TestDate,
        TestUtilityFunctions,
        TestSchema,
        # MySQL
        TestMySQLDatabase,
        TestMySQLTable,
        TestMySQLQuery,
        TestMySQLView,
        TestDeleteMySQLDatabase,
        # SQLite
        TestSQLiteDatabase,
        TestSQLiteTable,
        TestSQLiteQuery,
        TestSQLiteView,
        TestDeleteSQLiteDatabase,
        TestCSV,
        TestDatasheet,
    )
    loader = unittest.TestLoader()
    collected = unittest.TestSuite()
    for case in ordered_cases:
        collected.addTest(loader.loadTestsFromTestCase(case))
    return collected
if __name__ == "__main__":
    # Run the full suite; exit status is 0 on success and 1 on any failure/error.
    runner = unittest.TextTestRunner(verbosity=1)
    outcome = runner.run(suite())
    sys.exit(not outcome.wasSuccessful())