diff --git a/spring-orm/src/main/java/org/springframework/orm/jpa/hibernate/SpringBeanContainer.java b/spring-orm/src/main/java/org/springframework/orm/jpa/hibernate/SpringBeanContainer.java index be1ebc3437d8..80b70a53ed6c 100644 --- a/spring-orm/src/main/java/org/springframework/orm/jpa/hibernate/SpringBeanContainer.java +++ b/spring-orm/src/main/java/org/springframework/orm/jpa/hibernate/SpringBeanContainer.java @@ -69,6 +69,7 @@ * integration will be registered out of the box. * * @author Juergen Hoeller + * @author Yanming Zhou * @since 7.0 * @see LocalSessionFactoryBean#setBeanFactory * @see LocalSessionFactoryBuilder#setBeanContainer @@ -139,17 +140,18 @@ public void stop() { } - private SpringContainedBean createBean( - Class beanType, LifecycleOptions lifecycleOptions, BeanInstanceProducer fallbackProducer) { + private SpringContainedBean createBean( + Class beanType, LifecycleOptions lifecycleOptions, BeanInstanceProducer fallbackProducer) { try { if (lifecycleOptions.useJpaCompliantCreation()) { return new SpringContainedBean<>( + beanType, this.beanFactory.createBean(beanType), this.beanFactory::destroyBean); } else { - return new SpringContainedBean<>(this.beanFactory.getBean(beanType)); + return new SpringContainedBean<>(beanType, this.beanFactory.getBean(beanType)); } } catch (BeansException ex) { @@ -158,7 +160,7 @@ private SpringContainedBean createBean( beanType + ": " + ex); } try { - return new SpringContainedBean<>(fallbackProducer.produceBeanInstance(beanType)); + return new SpringContainedBean<>(beanType, fallbackProducer.produceBeanInstance(beanType)); } catch (RuntimeException ex2) { if (ex instanceof BeanCreationException) { @@ -176,42 +178,44 @@ private SpringContainedBean createBean( } } - private SpringContainedBean createBean( - String name, Class beanType, LifecycleOptions lifecycleOptions, BeanInstanceProducer fallbackProducer) { + @SuppressWarnings("unchecked") + private SpringContainedBean createBean( + String name, Class beanType, LifecycleOptions lifecycleOptions, BeanInstanceProducer fallbackProducer) { try { if (lifecycleOptions.useJpaCompliantCreation()) { - Object bean = null; + B bean = null; if (fallbackProducer instanceof TypeBootstrapContext) { // Special Hibernate type construction rules, including TypeBootstrapContext resolution. bean = fallbackProducer.produceBeanInstance(name, beanType); } if (this.beanFactory.containsBean(name)) { if (bean == null) { - bean = this.beanFactory.autowire(beanType, AutowireCapableBeanFactory.AUTOWIRE_CONSTRUCTOR, false); + bean = (B) this.beanFactory.autowire(beanType, AutowireCapableBeanFactory.AUTOWIRE_CONSTRUCTOR, false); } this.beanFactory.autowireBeanProperties(bean, AutowireCapableBeanFactory.AUTOWIRE_NO, false); this.beanFactory.applyBeanPropertyValues(bean, name); - bean = this.beanFactory.initializeBean(bean, name); - return new SpringContainedBean<>(bean, beanInstance -> this.beanFactory.destroyBean(name, beanInstance)); + bean = (B) this.beanFactory.initializeBean(bean, name); + return new SpringContainedBean<>(beanType, bean, beanInstance -> this.beanFactory.destroyBean(name, beanInstance)); } else if (bean != null) { // No bean found by name but constructed with TypeBootstrapContext rules this.beanFactory.autowireBeanProperties(bean, AutowireCapableBeanFactory.AUTOWIRE_NO, false); - bean = this.beanFactory.initializeBean(bean, name); - return new SpringContainedBean<>(bean, this.beanFactory::destroyBean); + bean = (B) this.beanFactory.initializeBean(bean, name); + return new SpringContainedBean<>(beanType, bean, this.beanFactory::destroyBean); } else { // No bean found by name -> construct by type using createBean return new SpringContainedBean<>( + beanType, this.beanFactory.createBean(beanType), this.beanFactory::destroyBean); } } else { return (this.beanFactory.containsBean(name) ? - new SpringContainedBean<>(this.beanFactory.getBean(name, beanType)) : - new SpringContainedBean<>(this.beanFactory.getBean(beanType))); + new SpringContainedBean<>(beanType, this.beanFactory.getBean(name, beanType)) : + new SpringContainedBean<>(beanType, this.beanFactory.getBean(beanType))); } } catch (BeansException ex) { @@ -220,7 +224,7 @@ else if (bean != null) { beanType + " with name '" + name + "': " + ex); } try { - return new SpringContainedBean<>(fallbackProducer.produceBeanInstance(name, beanType)); + return new SpringContainedBean<>(beanType, fallbackProducer.produceBeanInstance(name, beanType)); } catch (RuntimeException ex2) { if (ex instanceof BeanCreationException) { @@ -241,15 +245,19 @@ else if (bean != null) { private static final class SpringContainedBean implements ContainedBean { + private final Class beanClass; + private final B beanInstance; private @Nullable Consumer destructionCallback; - public SpringContainedBean(B beanInstance) { + public SpringContainedBean(Class beanClass, B beanInstance) { + this.beanClass = beanClass; this.beanInstance = beanInstance; } - public SpringContainedBean(B beanInstance, Consumer destructionCallback) { + public SpringContainedBean(Class beanClass, B beanInstance, Consumer destructionCallback) { + this.beanClass = beanClass; this.beanInstance = beanInstance; this.destructionCallback = destructionCallback; } @@ -260,9 +268,8 @@ public B getBeanInstance() { } @Override - @SuppressWarnings("unchecked") public Class getBeanClass() { - return (Class) this.beanInstance.getClass(); + return this.beanClass; } public void destroyIfNecessary() { diff --git a/spring-orm/src/test/java/org/springframework/orm/jpa/hibernate/HibernateNativeEntityManagerFactorySpringBeanContainerIntegrationTests.java b/spring-orm/src/test/java/org/springframework/orm/jpa/hibernate/HibernateNativeEntityManagerFactorySpringBeanContainerIntegrationTests.java index cfd05142c597..94a2a3af3958 100644 --- a/spring-orm/src/test/java/org/springframework/orm/jpa/hibernate/HibernateNativeEntityManagerFactorySpringBeanContainerIntegrationTests.java +++ b/spring-orm/src/test/java/org/springframework/orm/jpa/hibernate/HibernateNativeEntityManagerFactorySpringBeanContainerIntegrationTests.java @@ -33,6 +33,7 @@ import org.springframework.orm.jpa.hibernate.beans.MultiplePrototypesInSpringContextTestBean; import org.springframework.orm.jpa.hibernate.beans.NoDefinitionInSpringContextTestBean; import org.springframework.orm.jpa.hibernate.beans.SinglePrototypeInSpringContextTestBean; +import org.springframework.orm.jpa.hibernate.beans.TestBean; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; @@ -42,6 +43,7 @@ * * @author Yoann Rodiere * @author Juergen Hoeller + * @author Yanming Zhou */ class HibernateNativeEntityManagerFactorySpringBeanContainerIntegrationTests extends AbstractEntityManagerFactoryIntegrationTests { @@ -275,6 +277,20 @@ void testOriginalExceptionInCaseOfFallbackProducerFailureByName() { )); } + @Test + void testRetrieveBeanShouldRetainOriginalBeanType() { + BeanContainer beanContainer = getBeanContainer(); + assertThat(beanContainer).isNotNull(); + + ContainedBean bean = beanContainer.getBean( + "single", TestBean.class, + NativeLifecycleOptions.INSTANCE, IneffectiveBeanInstanceProducer.INSTANCE + ); + + assertThat(bean).isNotNull(); + assertThat(bean.getBeanClass()).isSameAs(TestBean.class); + } + /** * The lifecycle options mandated by the JPA spec and used as a default in Hibernate ORM.