diff --git a/activerecord/lib/active_record/transactions.rb b/activerecord/lib/active_record/transactions.rb index bcb2ba04d3566d861cc394dc7e01dfa71f34557f..5702d6ca001ef71bc7f55e6136c0c62a33f91647 100644 --- a/activerecord/lib/active_record/transactions.rb +++ b/activerecord/lib/active_record/transactions.rb @@ -340,6 +340,7 @@ def before_committed! # :nodoc: # Ensure that it is not called if the object was never persisted (failed create), # but call it after the commit of a destroyed object. def committed!(should_run_callbacks: true) #:nodoc: + force_clear_transaction_record_state if should_run_callbacks && (destroyed? || persisted?) @_committed_already_called = true _run_commit_without_transaction_enrollment_callbacks @@ -347,7 +348,6 @@ def committed!(should_run_callbacks: true) #:nodoc: end ensure @_committed_already_called = false - force_clear_transaction_record_state end # Call the #after_rollback callbacks. The +force_restore_state+ argument indicates if the record diff --git a/activerecord/test/cases/transaction_callbacks_test.rb b/activerecord/test/cases/transaction_callbacks_test.rb index c0be45eee72f9adf2b523904c137bbb2da21624a..f0e67333b40b0bd54110bfe1b94d40cd9e66c86c 100644 --- a/activerecord/test/cases/transaction_callbacks_test.rb +++ b/activerecord/test/cases/transaction_callbacks_test.rb @@ -350,6 +350,24 @@ def test_after_rollback_callback_should_not_swallow_errors_when_set_to_raise end end + def test_after_commit_callback_should_not_rollback_state_that_already_been_succeeded + klass = Class.new(TopicWithCallbacks) do + self.inheritance_column = nil + validates :title, presence: true + def self.name; "TopicWithCallbacks"; end + end + + first = klass.new(title: "foo") + first.after_commit_block { |r| r.update(title: nil) if r.persisted? } + first.save! + + assert_predicate first, :persisted? + assert_not_nil first.id + ensure + first.destroy! + end + uses_transaction :test_after_commit_callback_should_not_rollback_state_that_already_been_succeeded + def test_after_rollback_callback_when_raise_should_restore_state error_class = Class.new(StandardError)