[Question] aec_to_parallel_wrapper optimization #1273

Open
escolanogui opened this issue Mar 11, 2025 · 0 comments
Labels
question Further information is requested

Comments

@escolanogui

Question

Hello! I am still fairly new to PettingZoo and have some questions about the aec_to_parallel_wrapper, in particular its step method.

def step(self, actions):
    rewards = defaultdict(int)
    terminations = {}
    truncations = {}
    infos = {}
    observations = {}
    for agent in self.aec_env.agents:
        if agent != self.aec_env.agent_selection:
            if self.aec_env.terminations[agent] or self.aec_env.truncations[agent]:
                raise AssertionError(
                    f"expected agent {agent} got termination or truncation agent {self.aec_env.agent_selection}. Parallel environment wrapper expects all agent death (setting an agent's self.terminations or self.truncations entry to True) to happen only at the end of a cycle."
                )
            else:
                raise AssertionError(
                    f"expected agent {agent} got agent {self.aec_env.agent_selection}, Parallel environment wrapper expects agents to step in a cycle."
                )
        obs, rew, termination, truncation, info = self.aec_env.last()
        self.aec_env.step(actions[agent])
        for agent in self.aec_env.agents:
            rewards[agent] += self.aec_env.rewards[agent]

    terminations = dict(**self.aec_env.terminations)
    truncations = dict(**self.aec_env.truncations)
    infos = dict(**self.aec_env.infos)
    observations = {
        agent: self.aec_env.observe(agent) for agent in self.aec_env.agents
    }
    while self.aec_env.agents and (
        self.aec_env.terminations[self.aec_env.agent_selection]
        or self.aec_env.truncations[self.aec_env.agent_selection]
    ):
        self.aec_env.step(None)

    self.agents = self.aec_env.agents
    return observations, rewards, terminations, truncations, infos

When testing this wrapper I have observed that self.aec_env.observe(agent) ends up being called twice per agent: once inside obs, rew, termination, truncation, info = self.aec_env.last() and once in observations = {agent: self.aec_env.observe(agent) for agent in self.aec_env.agents}. So each agent's observation is computed twice per cycle, which can cause a noticeable performance drop when observations are expensive to compute and are not cached.

Would it be sensible to save the observation returned by obs, rew, termination, truncation, info = self.aec_env.last() in an auxiliary variable and reuse it when building observations, instead of calling self.aec_env.observe(agent) again for every agent?
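
To make the idea concrete, here is a rough, untested sketch of the change I have in mind (last_observations is just a placeholder name I made up). It assumes that the observation returned by last() for an agent is the same one observe(agent) would return at the end of the cycle; if that does not hold for some environments, the change would not be safe:

def step(self, actions):
    rewards = defaultdict(int)
    # Auxiliary dict holding the observation already returned by last(),
    # so it does not need to be recomputed by observe() below.
    last_observations = {}
    for agent in self.aec_env.agents:
        if agent != self.aec_env.agent_selection:
            if self.aec_env.terminations[agent] or self.aec_env.truncations[agent]:
                raise AssertionError(
                    f"expected agent {agent} got termination or truncation agent {self.aec_env.agent_selection}. Parallel environment wrapper expects all agent death (setting an agent's self.terminations or self.truncations entry to True) to happen only at the end of a cycle."
                )
            else:
                raise AssertionError(
                    f"expected agent {agent} got agent {self.aec_env.agent_selection}, Parallel environment wrapper expects agents to step in a cycle."
                )
        obs, rew, termination, truncation, info = self.aec_env.last()
        last_observations[agent] = obs  # save instead of discarding
        self.aec_env.step(actions[agent])
        for agent in self.aec_env.agents:
            rewards[agent] += self.aec_env.rewards[agent]

    terminations = dict(**self.aec_env.terminations)
    truncations = dict(**self.aec_env.truncations)
    infos = dict(**self.aec_env.infos)
    # Reuse the saved observation where available; call observe() only for
    # agents that have no saved entry (e.g. agents added during the cycle).
    observations = {
        agent: last_observations[agent]
        if agent in last_observations
        else self.aec_env.observe(agent)
        for agent in self.aec_env.agents
    }
    while self.aec_env.agents and (
        self.aec_env.terminations[self.aec_env.agent_selection]
        or self.aec_env.truncations[self.aec_env.agent_selection]
    ):
        self.aec_env.step(None)

    self.agents = self.aec_env.agents
    return observations, rewards, terminations, truncations, infos

With this version each agent's observation would only be computed once per cycle (inside last()) instead of twice.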

Sorry if I have missed or misunderstood something, and thank you for your time.

@escolanogui escolanogui added the question Further information is requested label Mar 11, 2025
@escolanogui escolanogui changed the title [Question] Question title [Question] aec_to_parallel_wrapper optimization Mar 12, 2025