diff --git a/src/pygambit/behavmixed.pxi b/src/pygambit/behavmixed.pxi index e37d465f7..1df47363c 100644 --- a/src/pygambit/behavmixed.pxi +++ b/src/pygambit/behavmixed.pxi @@ -19,6 +19,7 @@ # along with this program; if not, write to the Free Software # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. # +from pygambit.util import MultiIndexSeriesFormatter import cython from cython.operator cimport dereference as deref @@ -372,7 +373,7 @@ class MixedBehaviorProfile: raise ValueError("Cannot create a MixedBehaviorProfile outside a Game.") def __repr__(self) -> str: - return str([self[player] for player in self.game.players]) + return MultiIndexSeriesFormatter().format(self.as_dict()) def _repr_latex_(self) -> str: return ( @@ -382,6 +383,14 @@ class MixedBehaviorProfile: + r"\right]$" ) + def as_dict(self) -> dict[tuple, float]: + result = {} + for player in self.game.players: + for info in player.infosets: + for action in info.actions: + result[(player.label, info.label, action.label)] = self[action] + return result + @property def game(self) -> Game: """The game on which this mixed behavior profile is defined.""" diff --git a/src/pygambit/stratmixed.pxi b/src/pygambit/stratmixed.pxi index a5ee63707..b1faebffb 100644 --- a/src/pygambit/stratmixed.pxi +++ b/src/pygambit/stratmixed.pxi @@ -193,7 +193,7 @@ class MixedStrategyProfile: raise ValueError("Cannot create a MixedStrategyProfile outside a Game.") def __repr__(self) -> str: - return str([self[player] for player in self.game.players]) + return MultiIndexSeriesFormatter().format(self.as_dict()) def _repr_latex_(self) -> str: return ( @@ -203,6 +203,13 @@ class MixedStrategyProfile: r"\right]$" ) + def as_dict(self) -> dict[tuple, float]: + result = {} + for player in self.game.players: + for strategy in player.strategies: + result[(player.label, strategy.label)] = self[player][strategy] + return result + @property def game(self) -> Game: """The game on which this mixed strategy profile is defined.""" diff --git a/src/pygambit/util.py b/src/pygambit/util.py index 714f90fdd..64954cd89 100644 --- a/src/pygambit/util.py +++ b/src/pygambit/util.py @@ -22,3 +22,99 @@ def make_temporary(content: typing.Optional[str] = None) -> pathlib.Path: yield filepath finally: filepath.unlink(missing_ok=True) + + +class MultiIndexSeriesFormatter: + def __init__(self, max_total_width=120, max_column_width=30): + self.max_total_width = max_total_width + self.max_column_width = max_column_width + + def _truncate_text(self, text, max_width): + """ + Truncate text to specified maximum width using ellipsis. + + Args: + text (str): Text to truncate + max_width (int): Maximum allowed width + + Returns: + str: Truncated text + """ + # Convert to string and remove any newlines + text_str = str(text).replace("\n", " ") + + # Leave the text as is if its length is under max_width + if len(text_str) <= max_width: + return text_str + + # If max_width is less than 3, just return first characters + if max_width < 3: + return text_str[:max_width] + + # Truncate with ellipsis + return text_str[:max_width-3] + "..." + + def format(self, series: dict[tuple[str, ...], float]): + """ + Generate a string representation similar to pandas Series with multi-index. + + Args: + max_total_width (int): Maximum total width of the output + max_column_width (int): Maximum width for each column + + Returns: + str: Formatted string representation of the instance + """ + + # Prepare the output lines + output_lines = [] + + # Calculate column widths + # First, get the maximum width for each column + column_widths = [0] * len(next(iter(series.keys()))) + value_width = 0 + + # Determine column widths for index columns + for key in series: + for i, part in enumerate(key): + column_widths[i] = max(column_widths[i], len(str(part))) + + # Determine value width + value_width = max(len(str(val)) for val in series.values()) + + # Adjust column widths if total exceeds max_total_width + total_width = sum(column_widths) + value_width + len(next(iter(series.keys()))) + if total_width > self.max_total_width: + # Proportionally reduce column widths + reduction_factor = self.max_total_width / total_width + column_widths = [max(3, int(w * reduction_factor)) for w in column_widths] + + # Limit each column to max_column_width + column_widths = [min(w, self.max_column_width) for w in column_widths] + + # Format each line + keys = list(series.keys()) + for i, key in enumerate(keys): + # Truncate each part of the multi-index + formatted_parts = [ + self._truncate_text(str(part), width) + if i == 0 or part != previous_part else self._truncate_text(" " * width, width) + for part, width, previous_part in zip(key, column_widths, keys[i-1]) + ] + + # Pad or truncate each part to its designated width + formatted_index = " ".join( + part.ljust(width) if len(part) <= width else part[:width] + for part, width in zip(formatted_parts, column_widths) + ) + + # Format the value + value = self._truncate_text(str(series[key]), self.max_column_width) + + # Combine index and value + output_lines.append(f"{formatted_index} {value}") + + # Construct the final representation + repr_str = "\n".join(output_lines) + + return repr_str diff --git a/tests/test_io.py b/tests/test_io.py index 8e3536ab3..f78b3d9bb 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -89,3 +89,27 @@ def test_read_write_nfg(): deserialized_nfg_game = gbt.read_nfg(io.BytesIO(serialized_nfg_game.encode())) double_serialized_nfg_game = deserialized_nfg_game.to_nfg() assert serialized_nfg_game == double_serialized_nfg_game + + +def test_print_mixed_strategy_profile(): + game_path = os.path.join("tests", "test_games", "mixed_behavior_game.efg") + test_game = gbt.read_efg(game_path) + text_string = "\n".join(["Player 1 1 0.5", + " 2 0.5", + "Player 2 1 0.5", + " 2 0.5", + "Player 3 1 0.5", + " 2 0.5"]) + assert repr(test_game.mixed_strategy_profile()) == text_string + + +def test_print_mixed_behavior_profile(): + game_path = os.path.join("tests", "test_games", "mixed_behavior_game.efg") + test_game = gbt.read_efg(game_path) + text_string = "\n".join(["Player 1 Infoset 1:1 U1 0.5", + " D1 0.5", + "Player 2 Infoset 2:1 U2 0.5", + " D2 0.5", + "Player 3 Infoset 3:1 U3 0.5", + " D3 0.5"]) + assert repr(test_game.mixed_behavior_profile()) == text_string