diff --git a/bindings/python/fdb/impl.py b/bindings/python/fdb/impl.py index 77c7a74d13..e69bda22cb 100644 --- a/bindings/python/fdb/impl.py +++ b/bindings/python/fdb/impl.py @@ -22,16 +22,17 @@ import ctypes import ctypes.util +import datetime import functools +import inspect +import multiprocessing +import os +import platform +import sys import threading import traceback -import inspect -import datetime -import platform -import os -import sys -import multiprocessing +import fdb from fdb import six _network_thread = None @@ -203,7 +204,9 @@ def transactional(*tr_args, **tr_kwargs): It is important to note that the wrapped method may be called multiple times in the event of a commit failure, until the commit - succeeds. + succeeds. This restriction requires that the wrapped function + may not be a generator, or a function that returns a closure that + contains the `tr` object. If given a Transaction, the Transaction will be passed into the wrapped code, and WILL NOT be committed at completion of the @@ -247,7 +250,6 @@ def transactional(*tr_args, **tr_kwargs): except FDBError as e: yield asyncio.From(tr.on_error(e.code)) else: - @functools.wraps(func) def wrapper(*args, **kwargs): if isinstance(args[index], TransactionRead): @@ -269,6 +271,9 @@ def transactional(*tr_args, **tr_kwargs): except FDBError as e: tr.on_error(e.code).wait() + if fdb.get_api_version() >= 620 and isinstance(ret, types.GeneratorType): + raise ValueError("Generators can not be wrapped with fdb.transactional") + # now = datetime.datetime.now() # td = now - last # elapsed = (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10**6) / float(10**6) diff --git a/bindings/python/tests/tester.py b/bindings/python/tests/tester.py index 32ae2c01a3..c430d699f9 100644 --- a/bindings/python/tests/tester.py +++ b/bindings/python/tests/tester.py @@ -125,6 +125,16 @@ class Instruction: self.stack.push(self.index, val) +def test_fdb_transactional_generator(db): + try: + @fdb.transactional + def function_that_yields(tr): + yield 0 + assert fdb.get_api_version() < 620, "Generators post-6.2.0 should throw" + except ValueError as e: + pass + + def test_db_options(db): db.options.set_max_watches(100001) db.options.set_datacenter_id("dc_id")