You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

230 lines
7.2 KiB
Python

3 years ago
""" Testing
"""
import os
import zlib
from io import BytesIO
from tempfile import mkstemp
from contextlib import contextmanager
import numpy as np
from numpy.testing import assert_, assert_equal
from pytest import raises as assert_raises
from scipy.io.matlab.streams import (make_stream,
GenericStream, ZlibInputStream,
_read_into, _read_string, BLOCK_SIZE)
@contextmanager
def setup_test_file():
    """Provide three readable streams that all hold the same test bytes.

    Writes ``b'a\\x00string'`` to a fresh temporary file and yields a tuple
    ``(fs, gs, cs)``:

    * ``fs`` -- the temporary file, opened in binary read mode;
    * ``gs``, ``cs`` -- two independent in-memory ``BytesIO`` objects holding
      the same bytes.

    The temporary file is removed when the context exits, whether the body
    finished normally or raised.
    """
    val = b'a\x00string'
    fd, fname = mkstemp()
    try:
        with os.fdopen(fd, 'wb') as fs:
            fs.write(val)
        with open(fname, 'rb') as fs:
            gs = BytesIO(val)
            cs = BytesIO(val)
            yield fs, gs, cs
    finally:
        # Reached after the read handle has been closed; without the
        # ``finally`` a failing test body would leave the file behind.
        os.unlink(fname)
def test_make_stream():
    """make_stream applied to the BytesIO fixture yields a GenericStream."""
    with setup_test_file() as (fs, gs, cs):
        # test stream initialization
        wrapped = make_stream(gs)
        assert_(isinstance(wrapped, GenericStream))
def test_tell_seek():
    """seek() returns 0 and tell() reports the new position on each stream."""
    # (arguments passed to seek, position tell() must report afterwards);
    # the steps are applied in this order to every stream.
    steps = (
        ((0,), 0),
        ((5,), 5),
        ((2, 1), 7),
        ((-2, 2), 6),
    )
    with setup_test_file() as (fs, gs, cs):
        for raw in (fs, gs, cs):
            stream = make_stream(raw)
            for seek_args, expected_pos in steps:
                status = stream.seek(*seek_args)
                assert_equal(status, 0)
                assert_equal(stream.tell(), expected_pos)
def test_read():
    """read(), _read_into and _read_string return the expected bytes."""
    with setup_test_file() as (fs, gs, cs):
        for raw in (fs, gs, cs):
            stream = make_stream(raw)

            # whole content, then only the first four bytes
            stream.seek(0)
            everything = stream.read(-1)
            assert_equal(everything, b'a\x00string')
            stream.seek(0)
            head = stream.read(4)
            assert_equal(head, b'a\x00st')

            # "read into" (_read_into) followed by "read alloc"
            # (_read_string): both are put through the same sequence.
            for reader in (_read_into, _read_string):
                stream.seek(0)
                first = reader(stream, 4)
                assert_equal(first, b'a\x00st')
                second = reader(stream, 4)
                assert_equal(second, b'ring')
                # all eight bytes are consumed, so a further read must fail
                assert_raises(IOError, reader, stream, 2)
class TestZlibInputStream:
    """Exercise ZlibInputStream: reading, seeking and all_data_read()."""

    # Fixed seed: the random payloads are identical on every run, so a
    # failing test can be reproduced exactly.
    _SEED = 1234

    def _make_compressed(self, size):
        """Return ``(compressed_bytes, raw_bytes)`` for `size` random bytes."""
        rng = np.random.default_rng(self._SEED)
        data = rng.integers(0, 256, size=size, dtype=np.uint8).tobytes()
        return zlib.compress(data), data

    @staticmethod
    def _break_checksum(compressed_data):
        """Return a copy whose last byte is incremented modulo 256.

        The tests use this to make the end of a compressed buffer invalid
        ("break checksum").
        """
        last = (compressed_data[-1] + 1) & 255
        return compressed_data[:-1] + bytes([last])

    def _make_overlap_compressed(self):
        """Return ``(compressed_bytes, raw_bytes)`` for the overlap tests.

        Asserts that the compressed length is exactly ``BLOCK_SIZE + 2``,
        which is the precondition those tests rely on.
        """
        COMPRESSION_LEVEL = 6

        data = np.arange(33707000).astype(np.uint8).tobytes()
        compressed_data = zlib.compress(data, COMPRESSION_LEVEL)

        # check that part of the checksum overlaps
        assert_(len(compressed_data) == BLOCK_SIZE + 2)
        return compressed_data, data

    def _get_data(self, size):
        """Return ``(BytesIO of compressed bytes, compressed length, raw)``."""
        compressed_data, data = self._make_compressed(size)
        return BytesIO(compressed_data), len(compressed_data), data

    def test_read(self):
        SIZES = [0, 1, 10, BLOCK_SIZE//2, BLOCK_SIZE-1,
                 BLOCK_SIZE, BLOCK_SIZE+1, 2*BLOCK_SIZE-1]

        READ_SIZES = [BLOCK_SIZE//2, BLOCK_SIZE-1,
                      BLOCK_SIZE, BLOCK_SIZE+1]

        def check(size, read_size):
            compressed_stream, compressed_data_len, data = \
                self._get_data(size)
            stream = ZlibInputStream(compressed_stream, compressed_data_len)
            # Collect the blocks and join once instead of repeatedly
            # concatenating bytes objects.
            blocks = []
            so_far = 0
            while True:
                block = stream.read(min(read_size, size - so_far))
                if not block:
                    break
                so_far += len(block)
                blocks.append(block)
            assert_equal(data, b''.join(blocks))

        for size in SIZES:
            for read_size in READ_SIZES:
                check(size, read_size)

    def test_read_max_length(self):
        size = 1234
        compressed_data, data = self._make_compressed(size)
        # Trailing bytes after the compressed payload must not be consumed.
        compressed_stream = BytesIO(compressed_data + b"abbacaca")
        stream = ZlibInputStream(compressed_stream, len(compressed_data))

        stream.read(len(data))
        assert_equal(compressed_stream.tell(), len(compressed_data))

        assert_raises(IOError, stream.read, 1)

    def test_read_bad_checksum(self):
        compressed_data, data = self._make_compressed(10)
        compressed_data = self._break_checksum(compressed_data)

        compressed_stream = BytesIO(compressed_data)
        stream = ZlibInputStream(compressed_stream, len(compressed_data))

        assert_raises(zlib.error, stream.read, len(data))

    def test_seek(self):
        compressed_stream, compressed_data_len, data = self._get_data(1024)

        stream = ZlibInputStream(compressed_stream, compressed_data_len)

        stream.seek(123)
        p = 123
        assert_equal(stream.tell(), p)
        d1 = stream.read(11)
        assert_equal(d1, data[p:p+11])

        stream.seek(321, 1)
        p = 123+11+321
        assert_equal(stream.tell(), p)
        d2 = stream.read(21)
        assert_equal(d2, data[p:p+21])

        stream.seek(641, 0)
        p = 641
        assert_equal(stream.tell(), p)
        d3 = stream.read(11)
        assert_equal(d3, data[p:p+11])

        assert_raises(IOError, stream.seek, 10, 2)
        assert_raises(IOError, stream.seek, -1, 1)
        assert_raises(ValueError, stream.seek, 1, 123)

        stream.seek(10000, 1)
        assert_raises(IOError, stream.read, 12)

    def test_seek_bad_checksum(self):
        compressed_data, data = self._make_compressed(10)
        compressed_data = self._break_checksum(compressed_data)

        compressed_stream = BytesIO(compressed_data)
        stream = ZlibInputStream(compressed_stream, len(compressed_data))

        assert_raises(zlib.error, stream.seek, len(data))

    def test_all_data_read(self):
        compressed_stream, compressed_data_len, data = self._get_data(1024)
        stream = ZlibInputStream(compressed_stream, compressed_data_len)
        assert_(not stream.all_data_read())
        stream.seek(512)
        assert_(not stream.all_data_read())
        stream.seek(1024)
        assert_(stream.all_data_read())

    def test_all_data_read_overlap(self):
        compressed_data, data = self._make_overlap_compressed()

        compressed_stream = BytesIO(compressed_data)
        stream = ZlibInputStream(compressed_stream, len(compressed_data))
        assert_(not stream.all_data_read())
        stream.seek(len(data))
        assert_(stream.all_data_read())

    def test_all_data_read_bad_checksum(self):
        compressed_data, data = self._make_overlap_compressed()
        # Corrupting the last byte does not change the length.
        compressed_data_len = len(compressed_data)
        compressed_data = self._break_checksum(compressed_data)

        compressed_stream = BytesIO(compressed_data)
        stream = ZlibInputStream(compressed_stream, compressed_data_len)
        assert_(not stream.all_data_read())
        stream.seek(len(data))

        assert_raises(zlib.error, stream.all_data_read)