diff --git a/qupulse/examples/expectation_checker.py b/qupulse/examples/expectation_checker.py new file mode 100644 index 000000000..caff2acfb --- /dev/null +++ b/qupulse/examples/expectation_checker.py @@ -0,0 +1,481 @@ +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.patheffects as pe +from typing import Dict, Tuple, Optional +import warnings +import time +from datetime import datetime +import xarray as xr + +from qupulse.pulses.pulse_template import PulseTemplate +from qupulse.program import waveforms +from qupulse.program.linspace import LinSpaceBuilder, to_increment_commands +from qupulse.program.loop import LoopBuilder +from qupulse.hardware.setup import HardwareSetup, PlaybackChannel, MarkerChannel, MeasurementMask +from qupulse.hardware.dacs import AlazarCard +from qupulse.pulses import PointPT, ConstantPT, RepetitionPT, ForLoopPT, TablePT, \ + FunctionPT, AtomicMultiChannelPT, SequencePT, MappingPT, ParallelConstantChannelPT +from qupulse.program import ProgramBuilder +from qupulse.plotting import plot, render, _render_loop + +from qupulse_hdawg.zihdawg import HDAWGRepresentation, HDAWGChannelGrouping + +import atsaverage +from atsaverage import alazar +from atsaverage.config import ScanlineConfiguration, CaptureClockConfiguration, EngineTriggerConfiguration,\ + TRIGInputConfiguration, InputConfiguration +import atsaverage.server +import atsaverage.client +import atsaverage.core +import atsaverage.config as config + + +#%% + +class HDAWGAlazar(): + ''' + gets locally installed alazar and hdawg in hardcoded config + meaning 2V-peak-to-peak range, channel name scheme "ZI0_A", "ZI0_B" etc. + ''' + def __init__(self, + awg_serial: str = "DEVXXXX", + awg_interface: str = "USB", #or 1GbE for LAN + awg_sample_rate: float = 1e9, + + ): + + self._hw_setup = HardwareSetup() + self._alazar = get_alazar() + self._init_alazar() + self._awg_serial = awg_serial + self._awg_sample_rate = awg_sample_rate + self._hdawg = get_hdawg(self._hw_setup,awg_serial,awg_interface) + self._init_hdawg() + + @property + def alazar(self) -> AlazarCard: + return self._alazar + + @property + def hdawg(self) -> HDAWGRepresentation: + return self._alazar + + def _init_hdawg(self): + self._hdawg.api_session.setDouble(f'/{self._awg_serial}/system/clocks/sampleclock/freq', self._awg_sample_rate) + + for i in range(8): + self._hdawg.api_session.setDouble(f'/{self._awg_serial}/sigouts/{i}/range', 2) + self._hdawg.api_session.setInt(f'/{self._awg_serial}/sigouts/{i}/on', 1) + + #Hacky piece of code to enable Hardware triggering on the second AWG, will probably be deprecated at some point + for awg in self._hw_setup.known_awgs: + # awg._program_manager._compiler_settings[0][1]['trigger_wait_code']='waitDigTrigger(1);' + awg._program_manager._compiler_settings[0][1]['trigger_wait_code']='' + + def _init_alazar(self): + + atsaverage.server.Server.default_instance.start(b'blabla') + atsaverage.server.Server.default_instance.connect_gui_window_from_new_process() + + warnings.filterwarnings("ignore", category=DeprecationWarning) + + + +class ExpectationChecker(): + ''' + trig-channels are all marker outputs from hdawg. + ensure more or less equal length cables from hdawg to alazar (including trig) + assumes channels in ascending order, e.g. 
first alazar channel to first hdawg + ''' + MARKER_CHANS = {f'ZI0_{x}_MARKER_FRONT' for x in 'ABCDEFGH'} + PLAY_CHANS = {f'ZI0_{x}' for x in 'ABCDEFGH'} + ALLOWED_CHANS = MARKER_CHANS | PLAY_CHANS + MEAS_CHANS = 'ABCD' + ALAZAR_RAW_RATE = 100_000_000 + MEAS_NAME = 'prog' + + def __init__(self, + devices: HDAWGAlazar|None, + pt: PulseTemplate|None, + program_builder: ProgramBuilder|None, + sample_rate_alazar: float = 1e-1, #ns + sample_rate_pt_plot: float = 1e0, #ns + save_path: str|None = None, + data_offsets: Dict[str,float] = {'t_offset':100,'v_offset':0.008,'v_scale':0.9975} + ): + + self._devices = devices + if pt is not None: + self._pt = self._approve_pt(pt) + else: + self._pt = None + self.program_builder = program_builder + + self._number_samples_per_average = int(sample_rate_alazar*1e9 // self.ALAZAR_RAW_RATE) + assert sample_rate_alazar*1e9 % self.ALAZAR_RAW_RATE == 0, f'only sample rates as multiples of {1/self.ALAZAR_RAW_RATE*1e9} ns' + assert self._number_samples_per_average >= 1, 'alazar sample rate too high' + + self._save_path = save_path + + self._data_offsets = data_offsets + + @classmethod + def from_file(cls, path: str, data_offsets: Optional[Dict[str,float]] = None) -> 'ExpectationChecker': + ds = xr.open_dataset(path) + file_data_offsets = { + 't_offset': ds.attrs['offsets_to_be_applied'][0], + 'v_offset': ds.attrs['offsets_to_be_applied'][1], + 'v_scale': ds.attrs['offsets_to_be_applied'][2] + } + + + ec = cls(None,None,None,save_path=None,data_offsets=data_offsets if data_offsets is not None else file_data_offsets) + ec._result = ds + return ec + + def _approve_pt(self, pt: PulseTemplate) -> bool: + + assert pt.defined_channels <= self.ALLOWED_CHANS, 'name chans according to hardcoded naming scheme (ZI0_X)' + + if any(ch in self.MARKER_CHANS for ch in pt.defined_channels): + print('overwriting marker channel(s) to be always on') + pt_mapped = ParallelConstantChannelPT(pt, {x:1. 
for x in self.MARKER_CHANS}, identifier=pt.identifier) + + + return pt_mapped + + def run(self): + self._register_program() + self._run_aqcuire() + self._save() + self._plot_result() + + def _save(self): + if self._save_path: + name = self._pt.identifier+'_' if self._pt.identifier else '' + name += str(datetime.now())[:19].replace(':','_').replace(' ','_') + self._result.to_netcdf(self._save_path+'//'+name+'.nc') + + def _plot_result(self): + + fig = plt.figure(figsize=(5,3),dpi=300) + gs0 = fig.add_gridspec(1, 1, height_ratios=[1.0,]) + ax = fig.add_subplot(gs0[0]) + + SMALL_SIZE = 8 + MEDIUM_SIZE = 8 + BIGGER_SIZE = 12 + + plt.rc('font', size=SMALL_SIZE) # controls default text sizes + plt.rc('axes', titlesize=SMALL_SIZE) # fontsize of the axes title + plt.rc('axes', labelsize=MEDIUM_SIZE) # fontsize of the x and y labels + plt.rc('xtick', labelsize=SMALL_SIZE) # fontsize of the tick labels + plt.rc('ytick', labelsize=SMALL_SIZE) # fontsize of the tick labels + plt.rc('legend', fontsize=SMALL_SIZE-1) # legend fontsize + plt.rc('figure', titlesize=BIGGER_SIZE) # fontsize of the figure title + plt.rc("xtick.minor", visible=True) + plt.rc("ytick.minor", visible=True) + + colors = plt.cm.Dark2.colors + + + #measured: + for ch in [ch for ch in self._result.data_vars if ch.endswith('meas')]: + ax.plot(self._result[ch].coords['time'].data+self._data_offsets.get('t_offset',0.), + self._data_offsets.get('v_scale',1.)*self._result[ch].data+self._data_offsets.get('v_offset',0.), + color=colors[self.MEAS_CHANS.find(ch[0])], + linestyle='', + marker="o",markerfacecolor=(1.,1.,1.,0.),#colors[i]+(0.5,), + markeredgecolor=colors[self.MEAS_CHANS.find(ch[0])]+(0.5,),markeredgewidth=.3,markersize=2.5, + # marker + label=ch + ) + + #expectation: + for ch in [ch for ch in self._result.data_vars if ch.endswith('exp')]: + ax.plot(self._result[ch].coords['time'].data,self._result[ch].data, + color=colors[self.MEAS_CHANS.find(ch[0])], + linestyle='-', + linewidth=1., + path_effects=[pe.Stroke(linewidth=1.5, foreground='black'), pe.Normal()], + zorder=2, + # marker="o",markerfacecolor=None, + # markeredgecolor=colors[i],markeredgewidth=.6,markersize=1., + label=ch + ) + + + ax.set_xlabel('Time (ns)') + ax.set_ylabel('Voltage (V)') + + if n:=self._result.attrs['name']: + ax.set_title(n) + + ax.legend(ncols=2) + + fig.tight_layout() + + return fig, ax + + + def _examine_diff(self): + pass + + def _register_program(self): + + self._devices._hw_setup.clear_programs() + + def run_callback(): + awgs = [] + for a in self._devices._hw_setup.known_awgs: + if a.__class__.__name__ != 'TaborChannelPair': #TODO: not negative check (but zihdawg channels potentially have different names) + awgs += [a] + for awg in awgs[::]: + awg.enable(True) + + return + + name = self.MEAS_NAME + for i,meas_chan in enumerate(self.MEAS_CHANS): + self._devices.alazar.register_mask_for_channel(name+meas_chan,i) + + operations = [] + measurement_masks = [] + for meas_chan in self.MEAS_CHANS: + mask_name = name + meas_chan + operations.append(atsaverage.operations.ChunkedAverage( + maskID=mask_name, identifier=mask_name, chunkSize=self._number_samples_per_average + )) + measurement_masks.append(MeasurementMask(self._devices.alazar, mask_name)) + self._devices._hw_setup.set_measurement(name, + measurement_masks, + allow_multiple_registration=True) + self._devices.alazar.register_operations(name,operations) + + prog = self._pt.create_program(program_builder=self.program_builder,) + # measurements = {mw_name: (begin_length_list[:, 0], 
begin_length_list[:, 1])} + # measurements = {name+meas_chan: ([0.,],[float(self._pt.duration),]) for meas_chan in self.MEAS_CHANS} + measurements = {name: ([0.,],[float(self._pt.duration),])} + + self._devices._hw_setup.register_program(self.MEAS_NAME, prog, run_callback=run_callback, + measurements=measurements + ) + + + def _run_aqcuire(self): + + _results_dict = alazar_measure(self._devices,program_id=self.MEAS_NAME) + + dur = float(self._pt.duration) + times = np.linspace(0.,dur,len(_results_dict['progA'][0])) + + _, voltages, _ = render_pt_with_times(self._pt,times,{ch for ch in self.PLAY_CHANS}) + + self._result = xr.Dataset( + { + key.removeprefix(self.MEAS_NAME)+'_meas': (["time",], arrarr[0]) + for key,arrarr in _results_dict.items() + }| + { + key.removeprefix('ZI0_')+'_exp': (["time",], val) for key,val in voltages.items() + } + , + coords={ + "time": times, + }, + attrs={'name': self._pt.identifier if self._pt.identifier else '', + 'offsets_to_be_applied': (self._data_offsets.get('t_offset',0.), + self._data_offsets.get('v_offset',0.), + self._data_offsets.get('v_scale',1.)) + } + ) + + return + + +#%% + + +def render_pt_with_times(pt,times,plot_channels + ) -> Tuple[np.ndarray, Dict[str, np.ndarray]]: + + waveform, measurements = _render_loop(pt.create_program(), render_measurements=False) + + if waveform is None: + return np.array([]), dict(), measurements + + channels = waveform.defined_channels + + voltages = {ch: waveforms._ALLOCATION_FUNCTION(times, **waveforms._ALLOCATION_FUNCTION_KWARGS) + for ch in channels if ch in plot_channels} + # print(voltages.keys()) + # print(plot_channels) + for ch, ch_voltage in voltages.items(): + waveform.get_sampled(channel=ch, sample_times=times, output_array=ch_voltage) + + return times, voltages, measurements + + +#%% measure program + +def alazar_measure(devices: HDAWGAlazar, program_id: str = None, + *args, **kwargs, + ) -> Dict[str,np.ndarray]: + + # alazar_hack_kwargs = {} + #Error 582 "ApiBufferOverflow": + alazar_hack_kwargs = {'extend_to_all_channels':True} + + try: + devices._hw_setup.run_program(program_id, + # **alazar_hack_kwargs + ) + print('\n--- Program run ---') + results = devices.alazar.measure_program() + print('--- Program measured ---') + + except RuntimeError as err: + print(err) + print('Error occured, second run') + # pass + time.sleep(2.0) + devices.alazar.update_settings=True + devices._hw_setup.run_program(program_id, + # **alazar_hack_kwargs + ) + results = devices.alazar.measure_program() + + #dict with keys single_measurement_name+[A/B/C/D] for channels, + #of length corresponding to + #some datapoints for every measurement window + + if 'average_n_consecutive' in kwargs.keys(): + average_n_consecutive = kwargs['average_n_consecutive'] + for key, res_arr in results.items(): + assert len(res_arr)%average_n_consecutive==0,'should have been divisible by average_n_consecutive' + reshaped_arr = res_arr.reshape(-1, average_n_consecutive) + results[key] = reshaped_arr.mean(axis=1) + + return results + + +#%% setup alazar + +def get_alazar(masks=None,operations=None,number_of_samples=None): + + r = 2.5 + rid = alazar.TriggerRangeID.etr_TTL + + # trig_level = int((r + 0.15) / (2*r) * 255) + # I don't know if this is correct? 
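The commented-out formula follows the usual ATS trigger-level convention that the 8-bit code spans the selected trigger range symmetrically (0 = negative limit, 128 = 0 V, 255 = positive limit). A minimal sketch of that mapping, assuming the symmetric ±2.5 V range set above; this is an illustration only and the hardcoded `trig_level = 4` that follows bypasses it:

```python
def trigger_level_code(threshold_volts: float, trigger_range_volts: float = 2.5) -> int:
    """Map a threshold voltage onto the 8-bit ATS trigger level code.

    Assumes the convention 0 -> negative limit of the trigger range,
    128 -> 0 V, 255 -> positive limit. Not part of the patch.
    """
    code = int(round((threshold_volts + trigger_range_volts) / (2 * trigger_range_volts) * 255))
    if not 0 <= code <= 255:
        raise ValueError("threshold lies outside the trigger range")
    return code

# e.g. a 1.5 V threshold on a +-2.5 V external trigger input:
assert trigger_level_code(1.5) == 204
```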
+ + trig_level = 4 + + assert 0 <= trig_level < 256 + + + + + config = ScanlineConfiguration() + config.triggerInputConfiguration = TRIGInputConfiguration(triggerRange=rid) + config.triggerConfiguration = EngineTriggerConfiguration(triggerOperation=alazar.TriggerOperation.J, + triggerEngine1=alazar.TriggerEngine.J, + triggerSource1=alazar.TriggerSource.external, + triggerSlope1=alazar.TriggerSlope.positive, + triggerLevel1=trig_level, + triggerEngine2=alazar.TriggerEngine.K, + triggerSource2=alazar.TriggerSource.disable, + triggerSlope2=alazar.TriggerSlope.positive, + triggerLevel2=trig_level) + + config.captureClockConfiguration = CaptureClockConfiguration( + # source=alazar.CaptureClockType.external_clock, + # alazar.CaptureClockType.fast_external_clock, + source=alazar.CaptureClockType.external_clock_10MHz_ref, + # source=alazar.CaptureClockType.internal_clock, + # source=alazar.CaptureClockType.slow_external_clock, + # samplerate=alazar.SampleRateID.rate_100MSPS, + # ATS API: + # "If the clock source chosen is INTERNAL_CLOCK, this value is a member + # of ALAZAR_SAMPLE_RATES that defines the internal sample rate to choose. + # Valid values for each board vary. If the clock source chosen is + # EXTERNAL_CLOCK_10MHZ_REF, pass the value of the sample clock to + # generate from the reference in hertz. The values that can be generated + # depend on the board model. Otherwise, the clock source is external, + # pass SAMPLE_RATE_USER_DEF to this parameter." + # So: + samplerate=100_000_000, #for 100MSPS + + # samplerate=alazar.SampleRateID.rate_125MSPS, + + # decimation=1, + # samplerate=alazar.SampleRateID.user_def + ) + + config.inputConfiguration = 4*[InputConfiguration(input_range=alazar.InputRangeID.range_2_V)] + # config.totalRecordSize = 0 + # config.aimedBufferSize = 10*config.aimedBufferSize + + # is set automatically + assert config.totalRecordSize == 0 + + config.autoExtendTotalRecordSize = 0 + config.AUTO_EXTEND_TOTAL_RECORD_SIZE = 0 #what does th is even do? + + #test from m.b.'s example scripts + if masks is not None and operations is not None and number_of_samples is not None: + print('masks & operations set') + config.masks = masks + config.operations = operations + config.totalRecordSize = number_of_samples + config.autoExtendTotalRecordSize = 1 + + print(config) + + alazar_DAC = AlazarCard(atsaverage.core.getLocalCard(1, 1),config,) + # alazar_DAC.card.applyConfiguration(config) + # alazar_DAC.card.apply_board_configuration(config) + #alazar_DAC.card. 
+ alazar_DAC.card.triggerTimeout = 1200000 + alazar_DAC.card.acquisitionTimeout = 12000000 + alazar_DAC.card.computationTimeout = 1200000 + + #alazar_DAC.config.rawDataMask =atsaverage.atsaverage.ChannelMask(0) + + # channels=['A','B','C','D'] + # for i,channel in enumerate(channels): + # alazar_DAC.register_mask_for_channel(channel, i) + + # return alazar_DAC, config + return(alazar_DAC) + + +#%% get hdawg + + + +def get_hdawg(hardware_setup: HardwareSetup, Serial='DEVXXXX', device_interface="USB", address='localhost',idx=None): + if idx is None: + idx=str(len(hardware_setup.known_awgs)) + hdawg = HDAWGRepresentation(Serial, device_interface=device_interface, data_server_addr=address) + + hdawg.reset() + + hdawg.channel_grouping = HDAWGChannelGrouping.CHAN_GROUP_1x8 + + CH_NAMES = 'ABCDEFGH' + + for awg in hdawg.channel_tuples: + n_channels = awg.num_channels + awg_channels = CH_NAMES[:n_channels] + CH_NAMES = CH_NAMES[n_channels:] + + for ch_i, ch_name in enumerate(awg_channels): + playback_name = 'ZI'+str(idx)+'_{ch_name}'.format(ch_name=ch_name) + hardware_setup.set_channel(playback_name, PlaybackChannel(awg, ch_i)) + hardware_setup.set_channel(playback_name + '_MARKER_FRONT', MarkerChannel(awg, 2 * ch_i)) + hardware_setup.set_channel(playback_name + '_MARKER_BACK', MarkerChannel(awg, 2 * ch_i + 1)) + + return hdawg + + diff --git a/qupulse/examples/expectation_checker_example_script.py b/qupulse/examples/expectation_checker_example_script.py new file mode 100644 index 000000000..9b7b3b58e --- /dev/null +++ b/qupulse/examples/expectation_checker_example_script.py @@ -0,0 +1,142 @@ +# -*- coding: utf-8 -*- +import numpy as np + +from qupulse.pulses import PointPT, ConstantPT, RepetitionPT, ForLoopPT, TablePT,\ + FunctionPT, AtomicMultiChannelPT, SequencePT, MappingPT, ParallelConstantChannelPT +from qupulse.program.linspace import LinSpaceBuilder + +from qupulse.plotting import plot +from qupulse.utils import to_next_multiple + +from expectation_checker import ExpectationChecker, HDAWGAlazar + +#%% Get Devices + +# Get the Alazar and the HDAWG in hardcoded configuration (2V p-p range) +# Connect any Marker from the HDAWG to Trig in of the Alazar +# Connect one dummy channel from HDAWG to the external clock of the Alazar, +# then set 0.5V-10MHz-oscillator on this channel (e.g. in HDAWG-Webinterface) + +ha = HDAWGAlazar("DEVXXXX","USB",) + +#%% Example pulse definitions + +# PulseTemplates must be defined on channels named as 'ZI0_X', X in A to H +# (ensure correct mapping to Alazar) +# Markers will be overwritten to play a marker on each channel to trigger the Alazar +# identifiers of the final PT will be the names of the plotted object + + +class ShortSingleRampTest(): + def __init__(self, base_time=1e3): + hold = ConstantPT(base_time, {'a': '-1. + idx * 0.01'}) + pt = hold.with_iteration('idx', 200) + self.pulse_template = MappingPT(pt, + channel_mapping={'a':'ZI0_A',}, + identifier=self.__class__.__name__ + ) + +class ShortSingleRampTestWithPlay(): + def __init__(self,base_time=1e3+8): + # init = PointPT([(1.0,1e4)],channel_names=('ZI0_MARKER_FRONT',)) + init = FunctionPT('1.0+1e-9*t',base_time,channel='ZI0_A_MARKER_FRONT')#.pad_to(to_next_multiple(1.0,16,4),) + + hold = ConstantPT(base_time, {'a': '-1. 
+ idx * 0.01'})#.pad_to(to_next_multiple(1.0,16,4)) + pt = ParallelConstantChannelPT(init,dict(ZI0_A=0.))@(ParallelConstantChannelPT(hold,dict(ZI0_A_MARKER_FRONT=0.)).with_iteration('idx', 200)) + self.pulse_template = MappingPT(pt, + channel_mapping={'a':'ZI0_A',}, + identifier=self.__class__.__name__ + ) + +class SequencedRepetitionTest(): + def __init__(self,base_time=1e2,rep_factor=2): + wait = AtomicMultiChannelPT( + ConstantPT(f'64*{base_time}', {'a': '-1. + idx_a * 0.01 + y_gain', }), + ConstantPT(f'64*{base_time}', {'b': '-0.5 + idx_b * 0.02'}) + ) + + dependent_constant = AtomicMultiChannelPT( + ConstantPT(64*base_time, {'a': '-1.0 + y_gain'}), + ConstantPT(64*base_time, {'b': '-0.5 + idx_b*0.02',}), + ) + + dependent_constant2 = AtomicMultiChannelPT( + ConstantPT(64*base_time, {'a': '-0.5 + y_gain'}), + ConstantPT(64*base_time, {'b': '-0.3 + idx_b*0.02',}), + ) + + + pt = (dependent_constant @ dependent_constant2.with_repetition(rep_factor) @ (wait.with_iteration('idx_a', rep_factor))).with_iteration('idx_b', rep_factor)\ + + self.pulse_template = MappingPT(pt,parameter_mapping=dict(y_gain=0.3,), + channel_mapping={'a':'ZI0_A','b':'ZI0_C'}, + identifier=self.__class__.__name__ + ) + + +class SteppedRepetitionTest(): + def __init__(self,base_time=1e2,rep_factor=2): + + wait = ConstantPT(f'64*{base_time}*(1+idx_t)', {'a': '-0.5 + idx_a * 0.15', 'b': '-.5 + idx_a * 0.3'}) + normal_pt = ParallelConstantChannelPT(FunctionPT("sin(t/1000)","t_sin",channel='a'),{'b':-0.2}) + amp_pt = ParallelConstantChannelPT("amp*1/8"*FunctionPT("sin(t/2000)","t_sin",channel='a'),{'b':-0.5}) + # amp_pt2 = ParallelConstantChannelPT("amp2*1/8"*FunctionPT("sin(t/1000)","t_sin",channel='a'),{'b':-0.5}) + amp_inner = ParallelConstantChannelPT(FunctionPT(f"(1+amp)*1/(2*{rep_factor})*sin(4*pi*t/t_sin)","t_sin",channel='a'),{'b':-0.5}) + amp_inner2 = ParallelConstantChannelPT(FunctionPT(f"(1+amp2)*1/(2*{rep_factor})*sin((1*freq)*4*pi*t/t_sin)+off/(2*{rep_factor})","t_sin",channel='a'),{'b':-0.3}) + + pt = ((((normal_pt@amp_inner2).with_iteration('off', rep_factor)@normal_pt@wait)\ + .with_repetition(rep_factor))@amp_inner.with_iteration('amp', rep_factor))\ + .with_iteration('amp2', rep_factor).with_iteration('freq', rep_factor).with_iteration('idx_a',rep_factor) + + self.pulse_template = MappingPT(pt,parameter_mapping=dict(t_sin=64*base_time,idx_t=1, + #idx_a=1,#freq=1,#amp2=1 + ), + channel_mapping={'a':'ZI0_A','b':'ZI0_C'}, + identifier=self.__class__.__name__) + +class TimeSweepTest(): + def __init__(self,base_time=1e2,rep_factor=3): + wait = ConstantPT(f'64*{base_time}*(1+idx_t)', + {'a': '-1. 
+ idx_a * 0.01', 'b': '-.5 + idx_b * 0.02'}) + + random_constant = ConstantPT(64*base_time, {'a': -.4, 'b': -.3}) + meas = ConstantPT(64*base_time, {'a': 0.05, 'b': 0.06}) + + singlet_scan = (SequencePT(random_constant,wait,meas,identifier='s')).with_iteration('idx_a', rep_factor)\ + .with_iteration('idx_b', rep_factor)\ + .with_iteration('idx_t', rep_factor) + + self.pulse_template = MappingPT(singlet_scan,channel_mapping={'a':'ZI0_A','b':'ZI0_C'}, + identifier=self.__class__.__name__) + + +#%% Instantiate Checker + +# select exemplary pulse +pulse = ShortSingleRampTest(1e3+8) +# pulse = ShortSingleRampTestWithPlay() +# pulse = SequencedRepetitionTest(1e3,4) +# pulse = SteppedRepetitionTest(1e2,3) +# pulse = TimeSweepTest(1e2,3) + +# Define a program builder to test program with: +program_builder = LinSpaceBuilder( + #set to True to ensure triggering at Program start if program starts with constant pulse + play_marker_when_constant=True, + #in case stepped repetitions are needed, insert variables here: + to_stepping_repeat={'example',}, + ) + +# Data will be saved as xr.Dataset in save_path +# data_offsets corrects for offsets in Alazar (not in saved data, only in plotting) +checker = ExpectationChecker(ha, pulse.pulse_template, + program_builder=program_builder, + save_path=SAVE_HERE, + data_offsets={'t_offset':-100.,'v_offset':0.008,'v_scale':0.9975} + ) + +assert float(pulse.pulse_template.duration) < 1e7, "Ensure you know what you're doing when recording long data" + +#%% Run the checker + +checker.run() diff --git a/qupulse/expressions/simple.py b/qupulse/expressions/simple.py new file mode 100644 index 000000000..60923a126 --- /dev/null +++ b/qupulse/expressions/simple.py @@ -0,0 +1,141 @@ +import numpy as np +from numbers import Real, Number +from typing import Optional, Union, Sequence, ContextManager, Mapping, Tuple, Generic, TypeVar, Iterable, Dict, List +from dataclasses import dataclass + +from functools import total_ordering +from qupulse.utils.sympy import _lambdify_modules +from qupulse.expressions import sympy as sym_expr, Expression +from qupulse.utils.types import MeasurementWindow, TimeType, FrozenMapping + + +NumVal = TypeVar('NumVal', bound=Real) + + +@total_ordering +@dataclass +class SimpleExpression(Generic[NumVal]): + """This is a potential hardware evaluable expression of the form + + C + C1*R1 + C2*R2 + ... + where R1, R2, ... are potential runtime parameters. + + The main use case is the expression of for loop dependent variables where the Rs are loop indices. There the + expressions can be calculated via simple increments. + """ + + base: NumVal + offsets: Mapping[str, NumVal] + + def __post_init__(self): + assert isinstance(self.offsets, Mapping) + + def value(self, scope: Mapping[str, NumVal]) -> NumVal: + value = self.base + for name, factor in self.offsets.items(): + value += scope[name] * factor + return value + + def __abs__(self): + return abs(self.base)+sum([abs(o) for o in self.offsets.values()]) + + def __eq__(self, other): + #there is no good way to compare it without having a value, + #but cannot require more parameters in magic method? 
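Setting the comparison caveats aside for a moment, the linear form documented at the top of `SimpleExpression` (C + C1*R1 + C2*R2 + ...) can be illustrated with a short sketch based on the class as defined in this file; the numbers are arbitrary:

```python
from qupulse.expressions.simple import SimpleExpression

# -1.0 + 0.01 * idx, the kind of expression produced by ConstantPT(..., {'a': '-1. + idx * 0.01'})
expr = SimpleExpression(base=-1.0, offsets={'idx': 0.01})

# arithmetic keeps the linear form: base and per-index factors are tracked separately
shifted = expr + 0.5      # SimpleExpression(base=-0.5, offsets={'idx': 0.01})
scaled = 2 * expr         # SimpleExpression(base=-2.0, offsets={'idx': 0.02})

# a concrete value only exists once the runtime parameters (loop indices) are known
print(expr.value({'idx': 10}))   # -0.9 (up to float rounding)
```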
+ #so have this weird full equality for now which doesn logically make sense + #in most cases to catch unintended consequences + + if isinstance(other, (float, int, TimeType)): + return self.base==other and all([o==other for o in self.offsets]) + + if type(other) == type(self): + if len(self.offsets)!=len(other.offsets): return False + return self.base==other.base and all([o1==o2 for o1,o2 in zip(self.offsets,other.offsets)]) + + return NotImplemented + + def __gt__(self, other): + return all([b for b in self._return_greater_comparison_bools(other)]) + + def __lt__(self, other): + return all([not b for b in self._return_greater_comparison_bools(other)]) + + def _return_greater_comparison_bools(self, other) -> List[bool]: + #there is no good way to compare it without having a value, + #but cannot require more parameters in magic method? + #so have this weird full equality for now which doesn logically make sense + #in most cases to catch unintended consequences + if isinstance(other, (float, int, TimeType)): + return [self.base>other] + [o>other for o in self.offsets.values()] + + if type(other) == type(self): + if len(self.offsets)!=len(other.offsets): return [False] + return [self.base>other.base] + [o1>o2 for o1,o2 in zip(self.offsets.values(),other.offsets.values())] + + return NotImplemented + + def __add__(self, other): + if isinstance(other, (float, int, TimeType)): + return SimpleExpression(self.base + other, self.offsets) + + if type(other) == type(self): + offsets = self.offsets.copy() + for name, value in other.offsets.items(): + offsets[name] = value + offsets.get(name, 0) + return SimpleExpression(self.base + other.base, offsets) + + return NotImplemented + + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + return self.__add__(-other) + + def __rsub__(self, other): + return (-self).__add__(other) + + def __neg__(self): + return SimpleExpression(-self.base, {name: -value for name, value in self.offsets.items()}) + + def __mul__(self, other: NumVal): + if isinstance(other, (float, int, TimeType)): + return SimpleExpression(self.base * other, {name: other * value for name, value in self.offsets.items()}) + + return NotImplemented + + def __rmul__(self, other): + return self.__mul__(other) + + def __truediv__(self, other): + inv = 1 / other + return self.__mul__(inv) + + def __hash__(self): + return hash((self.base,frozenset(sorted(self.offsets.items())))) + + @property + def free_symbols(self): + return () + + def _sympy_(self): + return self + + def replace(self, r, s): + return self + + def evaluate_in_scope_(self, *args, **kwargs): + # TODO: remove. It is currently required to avoid nesting this class in an expression for the MappedScope + # We can maybe replace is with a HardwareScope or something along those lines + return self + + +#alibi class to allow instance check? 
+@dataclass +class SimpleExpressionStepped(SimpleExpression): + step_nesting_level: int + rng: range + reverse: int|bool + + +_lambdify_modules.append({'SimpleExpression': SimpleExpression, 'SimpleExpressionStepped': SimpleExpressionStepped}) diff --git a/qupulse/hardware/awgs/base.py b/qupulse/hardware/awgs/base.py index 5b1bb7c74..f0a41a4d8 100644 --- a/qupulse/hardware/awgs/base.py +++ b/qupulse/hardware/awgs/base.py @@ -17,9 +17,9 @@ from qupulse.hardware.util import get_sample_times, not_none_indices from qupulse.utils.types import ChannelID from qupulse.program.linspace import LinSpaceNode, LinSpaceArbitraryWaveform, to_increment_commands, Command, \ - Increment, Set as LSPSet, LoopLabel, LoopJmp, Wait, Play + Increment, Set as LSPSet, LoopLabel, LoopJmp, Wait, Play, DEFAULT_INCREMENT_RESOLUTION, DepDomain from qupulse.program.loop import Loop -from qupulse.program.waveforms import Waveform +from qupulse.program.waveforms import Waveform, WaveformCollection from qupulse.comparable import Comparable from qupulse.utils.types import TimeType @@ -30,6 +30,8 @@ Program = Loop +SAMPLE_TIME_TOLERANCE = 1e-10 + class AWGAmplitudeOffsetHandling: IGNORE_OFFSET = 'ignore_offset' # Offset is ignored. @@ -191,6 +193,7 @@ def __init__(self, program: AllowedProgramTypes, voltage_transformations: Tuple[Optional[Callable], ...], sample_rate: TimeType, waveforms: Sequence[Waveform] = None, + # voltage_resolution: Optional[float] = None, program_type: _ProgramType = _ProgramType.Loop): """ @@ -204,6 +207,8 @@ def __init__(self, program: AllowedProgramTypes, sample_rate: waveforms: These waveforms are sampled and stored in _waveforms. If None the waveforms are extracted from loop + # voltage_resolution: voltage resolution for LinSpaceProgram, i.e. 2**(-16) for 16 bit AWG + program_type: type of program from _ProgramType, determined by the ProgramBuilder used. """ assert len(channels) == len(amplitudes) == len(offsets) == len(voltage_transformations) @@ -218,9 +223,12 @@ def __init__(self, program: AllowedProgramTypes, self._program_type = program_type self._program = program - if program_type == _ProgramType.Linspace: - self._transformed_commands = self._transform_linspace_commands(to_increment_commands(program)) + # self._voltage_resolution = voltage_resolution + if program_type == _ProgramType.Linspace: + #!!! the voltage resolution may not be adequately represented if voltage transformations are not None? 
+ self._transformed_commands = self._transform_linspace_commands(to_increment_commands(program,)) + if waveforms is None: if program_type is _ProgramType.Loop: waveforms = OrderedDict((node.waveform, None) @@ -228,8 +236,18 @@ def __init__(self, program: AllowedProgramTypes, elif program_type is _ProgramType.Linspace: #not so clean #TODO: also marker handling not optimal - waveforms = OrderedDict((command.waveform, None) - for command in self._transformed_commands if isinstance(command,Play)).keys() + waveforms_d = OrderedDict() + for command in self._transformed_commands: + if not isinstance(command,Play): + continue + if isinstance(command.waveform,Waveform): + waveforms_d[command.waveform] = None + elif isinstance(command.waveform,WaveformCollection): + for w in command.waveform.flatten(): + waveforms_d[w] = None + else: + raise NotImplementedError() + waveforms = waveforms_d.keys() else: raise NotImplementedError() @@ -267,21 +285,35 @@ def _channel_transformations(self) -> Mapping[ChannelID, ChannelTransformation]: def _transform_linspace_commands(self, command_list: List[Command]) -> List[Command]: # all commands = Union[Increment, Set, LoopLabel, LoopJmp, Wait, Play] - trafos_by_channel_idx = list(self._channel_transformations().values()) - + # TODO: voltage resolution + + # trafos_by_channel_idx = list(self._channel_transformations().values()) + # increment_domains_to_transform = {DepDomain.VOLTAGE, DepDomain.WF_SCALE, DepDomain.WF_OFFSET} + for command in command_list: if isinstance(command, (LoopLabel, LoopJmp, Play, Wait)): # play is handled by transforming the sampled waveform continue elif isinstance(command, Increment): - ch_trafo = trafos_by_channel_idx[command.channel] + if command.key.domain is not DepDomain.VOLTAGE or \ + command.channel not in self._channels: + #for sweeps of wf-scale and wf-offset, the channel amplitudes/offsets are already considered in the wf sampling. + continue + + ch_trafo = self._channel_transformations()[command.channel] if ch_trafo.voltage_transformation: - raise RuntimeError("Cannot apply a voltage transformation to a linspace increment command") + if ch_trafo.voltage_transformation(1.0) != 1.0: + raise RuntimeError("Cannot apply a voltage transformation to a linspace increment command") command.value /= ch_trafo.amplitude elif isinstance(command, LSPSet): - ch_trafo = trafos_by_channel_idx[command.channel] + if command.key.domain is not DepDomain.VOLTAGE or \ + command.channel not in self._channels: + #for sweeps of wf-scale and wf-offset, the channel amplitudes/offsets are already considered in the wf sampling. 
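For the voltage-domain commands that are not skipped, the normalisation performed a few lines below reduces to a simple affine rescaling. A stand-alone sketch, where `amplitude` and `offset` stand in for one `ChannelTransformation` entry (names are illustrative, and for Set commands any voltage transformation is applied before this step):

```python
def normalise_set_value(value: float, amplitude: float, offset: float) -> float:
    # Set commands carry absolute voltages, so both offset and amplitude apply.
    return (value - offset) / amplitude

def normalise_increment_value(value: float, amplitude: float) -> float:
    # Increment commands are differences, so only the amplitude scaling applies.
    return value / amplitude

assert normalise_set_value(1.5, amplitude=2.0, offset=0.5) == 0.5
assert normalise_increment_value(0.25, amplitude=2.0) == 0.125
```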
+ continue + ch_trafo = self._channel_transformations()[command.channel] if ch_trafo.voltage_transformation: - command.value = float(ch_trafo.voltage_transformation(command.value)) + # for the case of swept parameters, this is defaulted to identity + command.value = ch_trafo.voltage_transformation(command.value) command.value -= ch_trafo.offset command.value /= ch_trafo.amplitude else: @@ -293,7 +325,7 @@ def _sample_waveforms(self, waveforms: Sequence[Waveform]) -> List[Tuple[Tuple[n Tuple[numpy.ndarray, ...]]]: sampled_waveforms = [] - time_array, segment_lengths = get_sample_times(waveforms, self._sample_rate) + time_array, segment_lengths = get_sample_times(waveforms, self._sample_rate, SAMPLE_TIME_TOLERANCE) sample_memory = numpy.zeros_like(time_array, dtype=float) n_samples = numpy.sum(segment_lengths) @@ -308,7 +340,9 @@ def _sample_waveforms(self, waveforms: Sequence[Waveform]) -> List[Tuple[Tuple[n segment_length = int(segment_length) segment_end = segment_begin + segment_length - wf_time = time_array[:segment_length] + # wf_time = time_array[:segment_length] + #hacky (time_array left intact in get_sample_times such that now multiply by undersampling) + wf_time = time_array[:segment_length] * 2**waveform._pow_2_divisor wf_sample_memory = sample_memory[:segment_length] sampled_channels = [] diff --git a/qupulse/hardware/dacs/alazar.py b/qupulse/hardware/dacs/alazar.py index c694c4be6..49ae885e3 100644 --- a/qupulse/hardware/dacs/alazar.py +++ b/qupulse/hardware/dacs/alazar.py @@ -390,13 +390,16 @@ def register_mask_for_channel(self, mask_id: str, hw_channel: int, mask_type='au raise NotImplementedError('Currently only can do cross buffer mask') self._mask_prototypes[mask_id] = (hw_channel, mask_type) - def measure_program(self, channels: Iterable[str]) -> Dict[str, np.ndarray]: + def measure_program(self, channels: Iterable[str] = None) -> Dict[str, np.ndarray]: """ Get all measurements at once and write them in a dictionary. """ scanline_data = self.__card.extractNextScanline() - + + if channels is None: + channels = scanline_data.operationResults.keys() + scanline_definition = scanline_data.definition operation_definitions = {operation.identifier: operation for operation in scanline_definition.operations} diff --git a/qupulse/hardware/setup.py b/qupulse/hardware/setup.py index e976a6743..719872ff2 100644 --- a/qupulse/hardware/setup.py +++ b/qupulse/hardware/setup.py @@ -94,6 +94,7 @@ def register_program(self, name: str, program: Loop, run_callback=lambda: None, update: bool = False, + channels = None, measurements: Mapping[str, Tuple[np.ndarray, np.ndarray]] = None) -> None: """Register a program under a given name at the hardware setup. The program will be uploaded to the participating AWGs and DACs. The run callback is used for triggering the program after arming. 
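A hypothetical call using the new optional `channels` argument; the fallback to `program.get_defined_channels()` appears in the next hunk, and `hw_setup`, `prog` and `run_callback` are assumed to exist as in the ExpectationChecker above:

```python
hw_setup.register_program(
    'prog',
    prog,                          # Loop or LinSpace program created from a PulseTemplate
    run_callback=run_callback,
    channels={'ZI0_A', 'ZI0_C'},   # explicit; if omitted, program.get_defined_channels() is used
)
```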
@@ -109,7 +110,9 @@ def register_program(self, name: str, if not callable(run_callback): raise TypeError('The provided run_callback is not callable') - channels = next(program.get_depth_first_iterator()).waveform.defined_channels + if channels is None: + # channels = next(program.get_depth_first_iterator()).waveform.defined_channels + channels = program.get_defined_channels() if channels - set(self._channel_map.keys()): raise KeyError('The following channels are unknown to the HardwareSetup: {}'.format( channels - set(self._channel_map.keys()))) diff --git a/qupulse/hardware/util.py b/qupulse/hardware/util.py index 2d656d529..8a4a7a6b1 100644 --- a/qupulse/hardware/util.py +++ b/qupulse/hardware/util.py @@ -128,7 +128,9 @@ def get_waveform_length(waveform: Waveform, Returns: Number of samples for the waveform """ - segment_length = waveform.duration * sample_rate_in_GHz + # segment_length = waveform.duration * sample_rate_in_GHz + #hacky + segment_length = waveform.duration * sample_rate_in_GHz / 2**waveform._pow_2_divisor # __round__ is implemented for Fraction and gmpy2.mpq rounded_segment_length = round(segment_length) @@ -143,7 +145,7 @@ def get_waveform_length(waveform: Waveform, deviation=deviation, rounded_segment_length=rounded_segment_length)) if rounded_segment_length <= 0: - raise ValueError("Error while sampling waveform. Waveform has a length <= zero at the given sample " + raise ValueError(f"Error while sampling waveform. Waveform has a length {rounded_segment_length} <= zero at the given sample " "rate of %rGHz" % sample_rate_in_GHz) segment_length = np.uint64(rounded_segment_length) @@ -246,7 +248,7 @@ def check_invalid_values(ch_data): if ch2 is None: ch2 = np.zeros(size) else: - check_invalid_values(ch1) + check_invalid_values(ch2) marker_data = np.zeros(size, dtype=np.uint16) for idx, marker in enumerate(markers): if marker is not None: diff --git a/qupulse/plotting.py b/qupulse/plotting.py index 0a62451fa..45b85f221 100644 --- a/qupulse/plotting.py +++ b/qupulse/plotting.py @@ -28,7 +28,7 @@ from qupulse.pulses.pulse_template import PulseTemplate from qupulse.program.waveforms import Waveform from qupulse.program.loop import Loop, to_waveform - +from qupulse.program.linspace import LinSpaceTopLevel __all__ = ["render", "plot", "PlottingNotPossibleException"] @@ -37,7 +37,9 @@ def render(program: Union[Loop], sample_rate: Real = 10.0, render_measurements: bool = False, time_slice: Tuple[Real, Real] = None, - plot_channels: Optional[Set[ChannelID]] = None) -> Tuple[np.ndarray, Dict[ChannelID, np.ndarray], + plot_channels: Optional[Set[ChannelID]] = None, + individualize_times: bool = False, + ) -> Tuple[np.ndarray, Dict[ChannelID, np.ndarray], List[MeasurementWindow]]: """'Renders' a pulse program. @@ -50,7 +52,7 @@ def render(program: Union[Loop], render_measurements: If True, the third return value is a list of measurement windows. time_slice: The time slice to be rendered. If None, the entire pulse will be shown. plot_channels: Only channels in this set are rendered. If None, all will. - + individualize_times: return individual time-voltage array pairs for constant value cleanup Returns: A tuple (times, values, measurements). times is a numpy.ndarray of dimensions sample_count where containing the time values. 
voltages is a dictionary of one numpy.ndarray of dimensions sample_count per @@ -100,10 +102,36 @@ def render(program: Union[Loop], for ch in channels} for ch, ch_voltage in voltages.items(): waveform.get_sampled(channel=ch, sample_times=times, output_array=ch_voltage) - + + + if individualize_times: + # new_dict = {ch: (np.copy(times),volts) for ch,volts in voltages.items()} + new_dict = {} + for ch in channels: + new_dict[ch] = deduplicate_with_aux(voltages[ch],np.copy(times),) + return times, new_dict, measurements + return times, voltages, measurements +def deduplicate_with_aux(arr, aux, threshold=1e-4): + # Calculate the absolute differences between consecutive elements + diffs = np.abs(np.diff(arr, prepend=arr[0])) + + # Use cumsum to track the cumulative differences + cumulative_diffs = np.cumsum(diffs) + + # Find indices where cumulative differences exceed the threshold + mask = np.concatenate(([0],np.where(np.diff(np.floor_divide(cumulative_diffs,threshold),prepend=cumulative_diffs[0])>0)[0],[-1])) + + # Apply the mask to both the main and auxiliary arrays + dedup_arr = arr[mask] + dedup_aux = aux[mask] + + return dedup_aux, dedup_arr + + + def _render_loop(loop: Loop, render_measurements: bool,) -> Tuple[Waveform, List[MeasurementWindow]]: """Transform program into single waveform and measurement windows. @@ -132,6 +160,7 @@ def plot(pulse: Union[PulseTemplate, Loop], stepped: bool=True, maximum_points: int=10**6, time_slice: Tuple[Real, Real]=None, + individualize_times: bool = False, **kwargs) -> Any: # pragma: no cover """Plots a pulse using matplotlib. @@ -182,7 +211,8 @@ def plot(pulse: Union[PulseTemplate, Loop], times, voltages, measurements = render(program, sample_rate, render_measurements=bool(plot_measurements), - time_slice=time_slice) + time_slice=time_slice, + individualize_times=individualize_times) else: times, voltages, measurements = np.array([]), dict(), [] @@ -214,9 +244,15 @@ def plot(pulse: Union[PulseTemplate, Loop], for ch_name, voltage in voltages.items(): label = 'channel {}'.format(ch_name) if stepped: - line, = axes.step(times, voltage, **{**dict(where='post', label=label), **kwargs}) + if individualize_times: + line, = axes.step(voltage[0], voltage[1], **{**dict(where='post', label=label), **kwargs}) + else: + line, = axes.step(times, voltage, **{**dict(where='post', label=label), **kwargs}) else: - line, = axes.plot(times, voltage, **{**dict(label=label), **kwargs}) + if individualize_times: + axes.plot(voltage[0], voltage[1], **{**dict(label=label), **kwargs}) + else: + line, = axes.plot(times, voltage, **{**dict(label=label), **kwargs}) legend_handles.append(line) if plot_measurements: @@ -235,8 +271,8 @@ def plot(pulse: Union[PulseTemplate, Loop], axes.legend(handles=legend_handles) - max_voltage = max((max(channel, default=0) for channel in voltages.values()), default=0) - min_voltage = min((min(channel, default=0) for channel in voltages.values()), default=0) + max_voltage = max((max(channel if not individualize_times else channel[1], default=0) for channel in voltages.values()), default=0) + min_voltage = min((min(channel if not individualize_times else channel[1], default=0) for channel in voltages.values()), default=0) # add some margins in the presentation axes.set_xlim(-0.5+time_slice[0], time_slice[1] + 0.5) diff --git a/qupulse/program/__init__.py b/qupulse/program/__init__.py index 611a96fcd..6f0064cbc 100644 --- a/qupulse/program/__init__.py +++ b/qupulse/program/__init__.py @@ -1,102 +1,13 @@ -import contextlib -from abc import ABC, 
abstractmethod -from dataclasses import dataclass from typing import Optional, Union, Sequence, ContextManager, Mapping, Tuple, Generic, TypeVar, Iterable, Dict -from numbers import Real, Number - -import numpy as np +from typing import Protocol, runtime_checkable, Set from qupulse._program.waveforms import Waveform -from qupulse.utils.types import MeasurementWindow, TimeType, FrozenMapping +from qupulse.utils.types import MeasurementWindow, TimeType from qupulse._program.volatile import VolatileRepetitionCount from qupulse.parameter_scope import Scope -from qupulse.expressions import sympy as sym_expr, Expression -from qupulse.utils.sympy import _lambdify_modules - -from typing import Protocol, runtime_checkable - - -NumVal = TypeVar('NumVal', bound=Real) - - -@dataclass -class SimpleExpression(Generic[NumVal]): - """This is a potential hardware evaluable expression of the form - - C + C1*R1 + C2*R2 + ... - where R1, R2, ... are potential runtime parameters. - - The main use case is the expression of for loop dependent variables where the Rs are loop indices. There the - expressions can be calculated via simple increments. - """ - - base: NumVal - offsets: Mapping[str, NumVal] - - def __post_init__(self): - assert isinstance(self.offsets, Mapping) - - def value(self, scope: Mapping[str, NumVal]) -> NumVal: - value = self.base - for name, factor in self.offsets: - value += scope[name] * factor - return value - - def __add__(self, other): - if isinstance(other, (float, int, TimeType)): - return SimpleExpression(self.base + other, self.offsets) - - if type(other) == type(self): - offsets = self.offsets.copy() - for name, value in other.offsets.items(): - offsets[name] = value + offsets.get(name, 0) - return SimpleExpression(self.base + other.base, offsets) - - return NotImplemented - - def __radd__(self, other): - return self.__add__(other) - - def __sub__(self, other): - return self.__add__(-other) - - def __rsub__(self, other): - (-self).__add__(other) - - def __neg__(self): - return SimpleExpression(-self.base, {name: -value for name, value in self.offsets.items()}) - - def __mul__(self, other: NumVal): - if isinstance(other, (float, int, TimeType)): - return SimpleExpression(self.base * other, {name: other * value for name, value in self.offsets.items()}) - - return NotImplemented - - def __rmul__(self, other): - return self.__mul__(other) - - def __truediv__(self, other): - inv = 1 / other - return self.__mul__(inv) - - @property - def free_symbols(self): - return () - - def _sympy_(self): - return self - - def replace(self, r, s): - return self - - def evaluate_in_scope_(self, *args, **kwargs): - # TODO: remove. It is currently required to avoid nesting this class in an expression for the MappedScope - # We can maybe replace is with a HardwareScope or something along those lines - return self - - -_lambdify_modules.append({'SimpleExpression': SimpleExpression}) - +from qupulse.expressions import sympy as sym_expr +from qupulse.expressions.simple import SimpleExpression +from qupulse import ChannelID RepetitionCount = Union[int, VolatileRepetitionCount, SimpleExpression[int]] HardwareTime = Union[TimeType, SimpleExpression[TimeType]] @@ -155,10 +66,18 @@ def new_subprogram(self, global_transformation: 'Transformation' = None) -> Cont it is not empty.""" def with_iteration(self, index_name: str, rng: range, + pt_obj: 'ForLoopPT', #hack this in for now. + # can be placed more suitably, like in pulsemetadata later on, but need some working thing now. 
measurements: Optional[Sequence[MeasurementWindow]] = None) -> Iterable['ProgramBuilder']: pass - - def to_program(self) -> Optional[Program]: + + def evaluate_nested_stepping(self, scope: Scope, parameter_names: set[str]) -> bool: + return False + + def time_reversed(self) -> ContextManager['ProgramBuilder']: + pass + + def to_program(self, defined_channels: Set[ChannelID]) -> Optional[Program]: """Further addition of new elements might fail after finalizing the program.""" diff --git a/qupulse/program/linspace.py b/qupulse/program/linspace.py index 0d454c090..bc3d27d0f 100644 --- a/qupulse/program/linspace.py +++ b/qupulse/program/linspace.py @@ -2,20 +2,159 @@ import contextlib import dataclasses import numpy as np -from dataclasses import dataclass -from typing import Mapping, Optional, Sequence, ContextManager, Iterable, Tuple, Union, Dict, List, Iterator +from dataclasses import dataclass, field +from typing import Mapping, Optional, Sequence, ContextManager, Iterable, Tuple, Union, Dict, List, Iterator, Generic,\ + Set as TypingSet, Callable +from enum import Enum +from itertools import dropwhile, count +from numbers import Real, Number +from collections import defaultdict + from qupulse import ChannelID, MeasurementWindow from qupulse.parameter_scope import Scope, MappedScope, FrozenDict -from qupulse.program import (ProgramBuilder, HardwareTime, HardwareVoltage, Waveform, RepetitionCount, TimeType, - SimpleExpression) -from qupulse.program.waveforms import MultiChannelWaveform +# from qupulse.pulses.pulse_template import PulseTemplate +# from qupulse.pulses import ForLoopPT +from qupulse.program import ProgramBuilder, HardwareTime, HardwareVoltage, Waveform, RepetitionCount, TimeType +from qupulse.expressions.simple import SimpleExpression, NumVal, SimpleExpressionStepped +from qupulse.program.waveforms import MultiChannelWaveform, TransformingWaveform, WaveformCollection, SequenceWaveform + +from qupulse.program.transformation import ChainedTransformation, ScalingTransformation, OffsetTransformation,\ + IdentityTransformation, ParallelChannelTransformation, Transformation # this resolution is used to unify increments # the increments themselves remain floats +# !!! translated: this is NOT a hardware resolution, +# just a programmatic 'small epsilon' to avoid rounding errors. 
DEFAULT_INCREMENT_RESOLUTION: float = 1e-9 +DEFAULT_TIME_RESOLUTION: float = 1e-3 + +class DepDomain(Enum): + VOLTAGE = 0 + TIME_LIN = -1 + TIME_LOG = -2 + FREQUENCY = -3 + WF_SCALE = -4 + WF_OFFSET = -5 + STEP_INDEX = -6 + NODEP = None + + +class InstanceCounterMeta(type): + def __init__(cls, name, bases, dct): + super().__init__(name, bases, dct) + cls._instance_tracker = {} + + def __call__(cls, *args, **kwargs): + normalized_args = cls._normalize_args(*args, **kwargs) + # Create a key based on the arguments + key = tuple(sorted(normalized_args.items())) + cls._instance_tracker.setdefault(key,count(start=0)) + instance = super().__call__(*args, **kwargs) + instance._channel_num = next(cls._instance_tracker[key]) + return instance + + def _normalize_args(cls, *args, **kwargs): + # Get the parameter names from the __init__ method + param_names = cls.__init__.__code__.co_varnames[1:cls.__init__.__code__.co_argcount] + # Create a dictionary with default values + normalized_args = dict(zip(param_names, args)) + # Update with any kwargs + normalized_args.update(kwargs) + return normalized_args - +@dataclass +class StepRegister(metaclass=InstanceCounterMeta): + #set this as name of sweepval var + register_name: str + register_nesting: int + #should be increased by metaclass every time the class is instantiated with the same arguments + _channel_num: int = dataclasses.field(default_factory=lambda: None) + + @property + def reg_var_name(self): + return self.register_name+'_'+str(self.register_num)+'_'+str(self._channel_num) + + def __hash__(self): + return hash((self.register_name,self.register_nesting,self._channel_num)) + + +GeneralizedChannel = Union[DepDomain,ChannelID,StepRegister] + +# is there any way to cast the numpy cumprod to int? +int_type = Union[np.int64,np.int32,int] + +class ResolutionDependentValue(Generic[NumVal]): + + def __init__(self, + bases: Tuple[NumVal], + multiplicities: Tuple[int], + offset: NumVal): + + self.bases = tuple(bases) + self.multiplicities = tuple(multiplicities) + self.offset = offset + self.__is_time_or_int = all(isinstance(b,(TimeType,int_type)) for b in bases) and isinstance(offset,(TimeType,int_type)) + + #this is not to circumvent float errors in python, but rounding errors from awg-increment commands. + #python float are thereby accurate enough if no awg with a 500 bit resolution is invented. + def __call__(self, resolution: Optional[float]) -> Union[NumVal,TimeType]: + #with resolution = None handle TimeType/int case? + if resolution is None: + assert self.__is_time_or_int + return sum(b*m for b,m in zip(self.bases,self.multiplicities)) + self.offset + #resolution as float value of granularity of base val. + #to avoid conflicts between positive and negative vals from casting half to even, + #use abs val + return sum(np.sign(b) * round(abs(b) / resolution) * m * resolution for b,m in zip(self.bases,self.multiplicities))\ + + np.sign(self.offset) * round(abs(self.offset) / resolution) * resolution + #cast the offset only once? + + def __bool__(self): + return any(bool(b) for b in self.bases) or bool(self.offset) + + def __add__(self, other): + # this should happen in the context of an offset being added to it, not the bases being modified. 
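How `ResolutionDependentValue.__call__` snaps each base onto the resolution grid before applying the multiplicities can be seen in a short sketch; the values are arbitrary and the class is the one defined in this file:

```python
from qupulse.program.linspace import ResolutionDependentValue

# two repetitions of a 0.2500000001 increment on top of a 0.125 offset
rdv = ResolutionDependentValue(bases=(0.2500000001,), multiplicities=(2,), offset=0.125)

# each base is rounded to the 1e-9 grid before the multiplicity is applied,
# so the 1e-10 residue (below the grid) does not accumulate over the repetitions
print(rdv(resolution=1e-9))   # 0.625

# resolution=None is reserved for exactly representable values (TimeType / int bases)
```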
+ if isinstance(other, (float, int, TimeType)): + return ResolutionDependentValue(self.bases,self.multiplicities,self.offset+other) + return NotImplemented + + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + return self.__add__(-other) + + def __mul__(self, other): + # this should happen when the amplitude is being scaled + if isinstance(other, (float, int, TimeType)): + return ResolutionDependentValue(tuple(b*other for b in self.bases),self.multiplicities,self.offset*other) + return NotImplemented + + def __rmul__(self,other): + return self.__mul__(other) + + def __truediv__(self,other): + return self.__mul__(1/other) + + def __float__(self): + return float(self(resolution=None)) + + def __str__(self): + return f"RDP of {sum(b*m for b,m in zip(self.bases,self.multiplicities)) + self.offset}" + + def __repr__(self): + return "RDP("+",".join([f"{k}="+v.__str__() for k,v in vars(self).items()])+")" + + def __eq__(self,o): + if not isinstance(o,ResolutionDependentValue): + return False + return self.__dict__ == o.__dict__ + + def __hash__(self): + return hash((self.bases,self.offset,self.multiplicities,self.__is_time_or_int)) + + @dataclass(frozen=True) class DepKey: """The key that identifies how a certain set command depends on iteration indices. The factors are rounded with a @@ -24,44 +163,241 @@ class DepKey: These objects allow backends which support it to track multiple amplitudes at once. """ factors: Tuple[int, ...] + domain: DepDomain + _free_upon_loop_exit: Optional[int] = field(hash=False,compare=False) + # _free_upon_loop_exit: Optional[int] + # strategy: DepStrategy + @classmethod - def from_voltages(cls, voltages: Sequence[float], resolution: float): - # remove trailing zeros - while voltages and voltages[-1] == 0: - voltages = voltages[:-1] - return cls(tuple(int(round(voltage / resolution)) for voltage in voltages)) + def from_domain(cls, factors, resolution, domain, free_upon_loop_exit): + # # remove trailing zeros + #why was this done in the first place? this seems to introduce more bugs than it solves + # while factors and factors[-1] == 0: + # factors = factors[:-1] + return cls(tuple(int(round(factor / resolution)) for factor in factors), + domain, _free_upon_loop_exit=free_upon_loop_exit) + + @classmethod + def from_voltages(cls, voltages: Sequence[float], resolution: float, free_upon_loop_exit: Optional[int]=None): + return cls.from_domain(voltages, resolution, DepDomain.VOLTAGE, free_upon_loop_exit) + + @classmethod + def from_lin_times(cls, times: Sequence[float], resolution: float): + return cls.from_domain(times, resolution, DepDomain.TIME_LIN, None) + + +@dataclass +class DummyMeasurementMemory: + measurements: List[MeasurementWindow] = field(default_factory=lambda: []) + + def add_measurements(self, measurements: List[MeasurementWindow]): + self.measurements.extend(measurements) @dataclass class LinSpaceNode: """AST node for a program that supports linear spacing of set points as well as nested sequencing and repetitions""" - - def dependencies(self) -> Mapping[int, set]: + + _cached_body_duration: TimeType|None = field(default=None, kw_only=True) + _measurement_memory: DummyMeasurementMemory|None = field(default_factory=lambda:DummyMeasurementMemory(), kw_only=True) + + def dependencies(self) -> Mapping[GeneralizedChannel, set]: + # doing this as a set _should_ get rid of non-active deps that are one level above? + #!!! 
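The quantisation that `DepKey.from_domain` applies to dependency factors (defined above) can be illustrated with arbitrary numbers; two step sizes that differ by less than the increment resolution end up with the same key:

```python
from qupulse.program.linspace import DepKey, DepDomain, DEFAULT_INCREMENT_RESOLUTION

# a hold whose voltage depends on two loop indices with step sizes 10 mV and 0 V
key = DepKey.from_voltages((0.01, 0.0), resolution=DEFAULT_INCREMENT_RESOLUTION)
print(key.factors, key.domain)   # (10000000, 0) DepDomain.VOLTAGE

# factors are compared after rounding to the resolution, so sub-resolution noise
# in the computed step sizes does not split one sweep into several DepKeys
assert key == DepKey.from_voltages((0.01 + 1e-13, 0.0), resolution=DEFAULT_INCREMENT_RESOLUTION)
```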
+ raise NotImplementedError + + @property + def body_duration(self) -> TimeType: + raise NotImplementedError + + def _get_measurement_windows(self) -> Mapping[str, np.ndarray]: + """Private implementation of get_measurement_windows with a slightly different data format for easier tiling. + Returns: + A dictionary (measurement_name -> array) with begin == array[:, 0] and length == array[:, 1] + """ raise NotImplementedError + def get_measurement_windows(self, drop: bool = False) -> Dict[str, Tuple[np.ndarray, np.ndarray]]: + """Iterates over all children and collect the begin and length arrays of each measurement window. + Args: + drop: NO EFFECT CURRENTLY + Returns: + A dictionary (measurement_name -> (begin, length)) with begin and length being :class:`numpy.ndarray` + """ + return {mw_name: (begin_length_list[:, 0], begin_length_list[:, 1]) + for mw_name, begin_length_list in self._get_measurement_windows().items()} + + def _reverse_body(self): + #default: do nothing + return + @dataclass -class LinSpaceHold(LinSpaceNode): +class LinSpaceTopLevel(LinSpaceNode): + + body: Tuple[LinSpaceNode, ...] + _play_marker_when_constant: bool|set[ChannelID] + _defined_channels: TypingSet[ChannelID] + + @property + def play_marker_when_constant(self) -> bool|set[ChannelID]: + return self._play_marker_when_constant + + @property + def body_duration(self) -> TimeType: + if self._cached_body_duration is None: + self._cached_body_duration = sum(b.body_duration for b in self.body) + return self._cached_body_duration + + def _get_measurement_windows(self) -> Mapping[str, np.ndarray]: + """Private implementation of get_measurement_windows with a slightly different data format for easier tiling. + Returns: + A dictionary (measurement_name -> array) with begin == array[:, 0] and length == array[:, 1] + """ + return _get_measurement_windows_loop(self._measurement_memory.measurements,1,self.body) + + def get_defined_channels(self) -> TypingSet[ChannelID]: + return self._defined_channels + + def _reverse_body(self): + self.body = self.body[::-1] + for node in self.body: + node._reverse_body() + +@dataclass +class LinSpaceNodeChannelSpecific(LinSpaceNode): + + channels: Tuple[GeneralizedChannel, ...] + + @property + def play_channels(self) -> Tuple[ChannelID, ...]: + return tuple(ch for ch in self.channels if isinstance(ch,ChannelID)) + + def _get_measurement_windows(self,) -> Mapping[str, np.ndarray]: + """Private implementation of get_measurement_windows with a slightly different data format for easier tiling. + Returns: + A dictionary (measurement_name -> array) with begin == array[:, 0] and length == array[:, 1] + """ + return _get_measurement_windows_leaf(self._measurement_memory.measurements) + + +@dataclass +class LinSpaceHold(LinSpaceNodeChannelSpecific): """Hold voltages for a given time. The voltages and the time may depend on the iteration index.""" - bases: Tuple[float, ...] - factors: Tuple[Optional[Tuple[float, ...]], ...] 
+ bases: Dict[GeneralizedChannel, float] + factors: Dict[GeneralizedChannel, Optional[Tuple[float, ...]]] duration_base: TimeType duration_factors: Optional[Tuple[TimeType, ...]] - - def dependencies(self) -> Mapping[int, set]: - return {idx: {factors} - for idx, factors in enumerate(self.factors) + + def dependencies(self) -> Mapping[DepDomain, Mapping[ChannelID, set]]: + return {dom: {ch: {factors}} + for dom, ch_to_factors in self._dep_by_domain().items() + for ch, factors in ch_to_factors.items() if factors} - + + def _dep_by_domain(self) -> Mapping[DepDomain, Mapping[GeneralizedChannel, set]]: + return {DepDomain.VOLTAGE: self.factors, + DepDomain.TIME_LIN: {DepDomain.TIME_LIN:self.duration_factors}, + } + + @property + def body_duration(self) -> TimeType: + if self.duration_factors: + raise NotImplementedError + if self._cached_body_duration is None: + self._cached_body_duration = self.duration_base + return self._cached_body_duration + @dataclass -class LinSpaceArbitraryWaveform(LinSpaceNode): +class LinSpaceArbitraryWaveform(LinSpaceNodeChannelSpecific): """This is just a wrapper to pipe arbitrary waveforms through the system.""" waveform: Waveform - channels: Tuple[ChannelID, ...] + + def dependencies(self): + return {} + + @property + def body_duration(self) -> TimeType: + if self._cached_body_duration is None: + self._cached_body_duration = self.waveform.duration + return self._cached_body_duration + + +@dataclass +class LinSpaceArbitraryWaveformIndexed(LinSpaceNodeChannelSpecific): + """This is just a wrapper to pipe arbitrary waveforms through the system.""" + waveform: Union[Waveform,WaveformCollection] + + scale_bases: Dict[ChannelID, float] + scale_factors: Dict[ChannelID, Optional[Tuple[float, ...]]] + + offset_bases: Dict[ChannelID, float] + offset_factors: Dict[ChannelID, Optional[Tuple[float, ...]]] + + index_factors: Optional[Dict[StepRegister,Tuple[int, ...]]] = dataclasses.field(default_factory=lambda: None) + + def __post_init__(self): + #somewhat assert the integrity in this case. + if isinstance(self.waveform,WaveformCollection): + assert self.index_factors is not None + + def dependencies(self) -> Mapping[DepDomain, Mapping[GeneralizedChannel, set]]: + return {dom: {ch: {factors}} + for dom, ch_to_factors in self._dep_by_domain().items() + for ch, factors in ch_to_factors.items() + if factors} + + def _dep_by_domain(self) -> Mapping[DepDomain, Mapping[GeneralizedChannel, set]]: + return {DepDomain.WF_SCALE: self.scale_factors, + DepDomain.WF_OFFSET: self.offset_factors, + DepDomain.STEP_INDEX: self.index_factors} + + @property + def step_channels(self) -> Optional[Tuple[StepRegister]]: + return tuple(self.index_factors.keys()) if self.index_factors else () + + @property + def body_duration(self) -> TimeType: + if self._cached_body_duration is None: + self._cached_body_duration = self.waveform.duration + return self._cached_body_duration + + +#!!! this is merely to catch measurements added to a sequencePT... +#(same as LinSpaceRepeat but with count=1; do not inherit to not confuse isinstance checks unintentionally) +@dataclass +class LinSpaceSequence(LinSpaceNode): + body: Tuple[LinSpaceNode, ...] 
+ + def dependencies(self): + dependencies = {} + for node in self.body: + for dom, ch_to_deps in node.dependencies().items(): + for ch, deps in ch_to_deps.items(): + dependencies.setdefault(dom,{}).setdefault(ch, set()).update(deps) + return dependencies + + @property + def body_duration(self) -> TimeType: + if self._cached_body_duration is None: + self._cached_body_duration = sum(b.body_duration for b in self.body) + return self._cached_body_duration + + def _get_measurement_windows(self,) -> Mapping[str, np.ndarray]: + """Private implementation of get_measurement_windows with a slightly different data format for easier tiling. + Returns: + A dictionary (measurement_name -> array) with begin == array[:, 0] and length == array[:, 1] + """ + return _get_measurement_windows_loop(self._measurement_memory.measurements,1,self.body) + + def _reverse_body(self): + self.body = self.body[::-1] + for node in self.body: + node._reverse_body() @dataclass @@ -73,11 +409,30 @@ class LinSpaceRepeat(LinSpaceNode): def dependencies(self): dependencies = {} for node in self.body: - for idx, deps in node.dependencies().items(): - dependencies.setdefault(idx, set()).update(deps) + for dom, ch_to_deps in node.dependencies().items(): + for ch, deps in ch_to_deps.items(): + dependencies.setdefault(dom,{}).setdefault(ch, set()).update(deps) return dependencies - - + + @property + def body_duration(self) -> TimeType: + if self._cached_body_duration is None: + self._cached_body_duration = self.count*sum(b.body_duration for b in self.body) + return self._cached_body_duration + + def _get_measurement_windows(self,) -> Mapping[str, np.ndarray]: + """Private implementation of get_measurement_windows with a slightly different data format for easier tiling. + Returns: + A dictionary (measurement_name -> array) with begin == array[:, 0] and length == array[:, 1] + """ + return _get_measurement_windows_loop(self._measurement_memory.measurements,self.count,self.body) + + def _reverse_body(self): + self.body = self.body[::-1] + for node in self.body: + node._reverse_body() + + @dataclass class LinSpaceIter(LinSpaceNode): """Iteration in linear space are restricted to range 0 to length. @@ -85,16 +440,106 @@ class LinSpaceIter(LinSpaceNode): Offsets and spacing are stored in the hold node.""" body: Tuple[LinSpaceNode, ...] length: int - + + to_be_stepped: bool + def dependencies(self): dependencies = {} for node in self.body: - for idx, deps in node.dependencies().items(): - # remove the last elemt in index because this iteration sets it -> no external dependency - shortened = {dep[:-1] for dep in deps} - if shortened != {()}: - dependencies.setdefault(idx, set()).update(shortened) + for dom, ch_to_deps in node.dependencies().items(): + for ch, deps in ch_to_deps.items(): + # remove the last element in index because this iteration sets it -> no external dependency + shortened = {dep[:-1] for dep in deps} + if shortened != {()}: + dependencies.setdefault(dom,{}).setdefault(ch, set()).update(shortened) return dependencies + + @property + def body_duration(self) -> TimeType: + if self._cached_body_duration is None: + self._cached_body_duration = self.length*sum(b.body_duration for b in self.body) + return self._cached_body_duration + + def _get_measurement_windows(self,) -> Mapping[str, np.ndarray]: + """Private implementation of get_measurement_windows with a slightly different data format for easier tiling. 
+ Returns: + A dictionary (measurement_name -> array) with begin == array[:, 0] and length == array[:, 1] + """ + return _get_measurement_windows_loop(self._measurement_memory.measurements,self.length,self.body) + + def _reverse_body(self): + self.body = self.body[::-1] + for node in self.body: + node._reverse_body() + + +def _get_measurement_windows_leaf(measurements: List[MeasurementWindow]) -> Mapping[str, np.ndarray]: + """Private implementation of get_measurement_windows with a slightly different data format for easier tiling. + Returns: + A dictionary (measurement_name -> array) with begin == array[:, 0] and length == array[:, 1] + """ + temp_meas_windows = defaultdict(list) + if measurements: + for (mw_name, begin, length) in measurements: + temp_meas_windows[mw_name].append((begin, length)) + + for mw_name, begin_length_list in temp_meas_windows.items(): + temp_meas_windows[mw_name] = np.asarray(begin_length_list, dtype=float) + + return temp_meas_windows + + +def _get_measurement_windows_loop(measurements: List[MeasurementWindow], count: int, + body: List[LinSpaceNode]) -> Mapping[str, np.ndarray]: + """Private implementation of get_measurement_windows with a slightly different data format for easier tiling. + Returns: + A dictionary (measurement_name -> array) with begin == array[:, 0] and length == array[:, 1] + """ + temp_meas_windows = defaultdict(list) + temp_meas_windows_children = defaultdict(list) + if measurements: + for (mw_name, begin, length) in measurements: + temp_meas_windows[mw_name].append((begin, length)) + + for mw_name, begin_length_list in temp_meas_windows.items(): + temp_meas_windows[mw_name] = [np.asarray(begin_length_list, dtype=float)] + + offset = TimeType(0) + for child in body: + for mw_name, begins_length_array in child._get_measurement_windows().items(): + begins_length_array[:, 0] += float(offset) + temp_meas_windows_children[mw_name].append(begins_length_array) + offset += child.body_duration + + body_duration = float(offset) + + # formatting like this for easier debugging + result = {} + + # repeat and add repetition based offset + #!!! but only for children + for mw_name, begin_length_list in temp_meas_windows_children.items(): + result[mw_name] = _repeat_loop_measurements(begin_length_list, count, body_duration) + + for mw_name, begin_length_list in temp_meas_windows.items(): + result[mw_name] = _repeat_loop_measurements(begin_length_list, 1, body_duration) + + return result + + +def _repeat_loop_measurements(begin_length_list: List[np.ndarray], + repetition_count: int, + body_duration: float + ) -> np.ndarray: + temp_begin_length_array = np.concatenate(begin_length_list) + + begin_length_array = np.tile(temp_begin_length_array, (repetition_count, 1)) + + shaped_begin_length_array = np.reshape(begin_length_array, (repetition_count, -1, 2)) + + shaped_begin_length_array[:, :, 0] += (np.arange(repetition_count) * body_duration)[:, np.newaxis] + + return begin_length_array class LinSpaceBuilder(ProgramBuilder): @@ -106,43 +551,89 @@ class LinSpaceBuilder(ProgramBuilder): Arbitrary waveforms are not implemented yet """ - def __init__(self, channels: Tuple[ChannelID, ...]): + def __init__(self, + # identifier, loop_index or ForLoopPT which is to be stepped: + to_stepping_repeat: TypingSet[Union[str,'ForLoopPT']] = None, + # this can indicate ChannelIDs (of _MarkerChannel_ on which to activate marker output in constant waveforms.) + # this can only be on or off for the entire program as those waveform samples are reused for efficiency. 
+ # TODO: that could be adapted but the necessity so far was marginal. + play_marker_when_constant: bool|set[ChannelID] = False, + #loop indices that are supposed to be rolled out: + to_rollout: set[str]|None = None, + #tuple of (max_wf_len,max_total_len) of successive wfs that should be + #concatenated to save waveform memory + concatenate_wfs: tuple[TimeType|float,TimeType|float]|None = None, + ): super().__init__() - self._name_to_idx = {name: idx for idx, name in enumerate(channels)} - self._idx_to_name = channels + # self._name_to_idx = {name: idx for idx, name in enumerate(channels)} + # self._voltage_idx_to_name = channels self._stack = [[]] self._ranges = [] - + self._to_stepping_repeat = to_stepping_repeat or set() + self._play_marker_when_constant = play_marker_when_constant + self._pt_channels = None + self._meas_queue = [] + self._reversed_counter = 0 + self._to_rollout = to_rollout or set() + self._current_rollout_vars = {} + self._concatenate_wfs = concatenate_wfs + + assert not any(d in self._to_rollout for d in self._to_stepping_repeat) + def _root(self): return self._stack[0] def _get_rng(self, idx_name: str) -> range: return self._get_ranges()[idx_name] - def inner_scope(self, scope: Scope) -> Scope: + def inner_scope(self, scope: Scope, pt_obj: Optional['ForLoopPT']=None) -> Scope: """This function is necessary to inject program builder specific parameter implementations into the build process.""" if self._ranges: - name, _ = self._ranges[-1] - return scope.overwrite({name: SimpleExpression(base=0, offsets={name: 1})}) + name, rng, reverse = self._ranges[-1] + + #hack in: if a loop was rolled out, skip the simpleexpression assignment + if name in self._current_rollout_vars: + pass + + elif pt_obj and (pt_obj in self._to_stepping_repeat or pt_obj.identifier in self._to_stepping_repeat \ + or pt_obj.loop_index in self._to_stepping_repeat): + # the nesting level should be simply the amount of this type in the scope. 
+ nest = len(tuple(v for v in scope.values() if isinstance(v,SimpleExpressionStepped))) + return_scope = scope.overwrite({name:SimpleExpressionStepped( + base=0,offsets={name: 1},step_nesting_level=nest+1,rng=rng,reverse=reverse)}) + else: + if isinstance(scope.get(name,None),SimpleExpressionStepped): + return_scope = scope + else: + return_scope = scope.overwrite({name: SimpleExpression(base=0, offsets={name: 1})}) else: - return scope + return_scope = scope + + return return_scope.overwrite(self._current_rollout_vars) def _get_ranges(self): - return dict(self._ranges) - + return dict(r[:2] for r in self._ranges) + + def _get_range_reversals(self): + return {r[0]:r[2] for r in self._ranges} + def hold_voltage(self, duration: HardwareTime, voltages: Mapping[ChannelID, HardwareVoltage]): - voltages = sorted((self._name_to_idx[ch_name], value) for ch_name, value in voltages.items()) - voltages = [value for _, value in voltages] + # voltages = sorted((self._name_to_idx[ch_name], value) for ch_name, value in voltages.items()) + # voltages = [value for _, value in voltages] ranges = self._get_ranges() - factors = [] - bases = [] - for value in voltages: - if isinstance(value, float): - bases.append(value) - factors.append(None) + reversals = self._get_range_reversals() + factors = {} + bases = {} + duration_base = duration + duration_factors = None + + for ch_name,value in voltages.items(): + if isinstance(value, (float, int)): + bases[ch_name] = float(value) + factors[ch_name] = None continue offsets = value.offsets base = value.base @@ -152,33 +643,206 @@ def hold_voltage(self, duration: HardwareTime, voltages: Mapping[ChannelID, Hard step = 0. offset = offsets.get(rng_name, None) if offset: - start += rng.start * offset - step += rng.step * offset + if reversals[rng_name]: + start += (rng.stop-1) * offset + step += -rng.step * offset + else: + start += rng.start * offset + step += rng.step * offset base += start incs.append(step) - factors.append(tuple(incs)) - bases.append(base) + factors[ch_name] = tuple(incs) + bases[ch_name] = base if isinstance(duration, SimpleExpression): - duration_factors = duration.offsets + # duration_factors = duration.offsets + # duration_base = duration.base + duration_offsets = duration.offsets duration_base = duration.base - else: - duration_base = duration - duration_factors = None - - set_cmd = LinSpaceHold(bases=tuple(bases), - factors=tuple(factors), + duration_factors = [] + for rng_name, rng in ranges.items(): + start = TimeType(0) + step = TimeType(0) + offset = duration_offsets.get(rng_name, None) + if offset: + if reversals[rng_name]: + start += (rng.stop-1) * offset + step += -rng.step * offset + else: + start += rng.start * offset + step += rng.step * offset + duration_base += start + duration_factors.append(step) + + + set_cmd = LinSpaceHold(channels=tuple(voltages.keys()), + bases=bases, + factors=factors, duration_base=duration_base, - duration_factors=duration_factors) + duration_factors=tuple(duration_factors) if duration_factors else None, + ) self._stack[-1].append(set_cmd) - - def play_arbitrary_waveform(self, waveform: Waveform): - return self._stack[-1].append(LinSpaceArbitraryWaveform(waveform, self._idx_to_name)) + if self._meas_queue: + meas = self._meas_queue.pop() + if self._reversed_counter%2: + if duration_factors: + raise NotImplementedError + duration = duration_base + meas = [ + (name, duration - (begin + length), length) + for name, begin, length in meas + ] + self._stack[-1][-1]._measurement_memory.add_measurements(meas) + + + 
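# Illustrative sketch (not part of this diff): how hold_voltage above linearises a swept
# value.  Assume the scope maps channel 'ZI0_A' to 0.2 + 0.05*i for a unit-step loop index
# i over range(3); the hold then stores the first-iteration value as base and the
# per-iteration change as increment factor, with the signs flipped for ranges entered
# under time_reversed().
base, offset, rng = 0.2, 0.05, range(3)

forward_base = base + rng.start * offset        # 0.2, value of the first iteration
forward_step = rng.step * offset                # +0.05 added per iteration

reversed_base = base + (rng.stop - 1) * offset  # 0.3, value of the last iteration
reversed_step = -rng.step * offset              # walked backwards when reversed

assert [round(forward_base + k * forward_step, 10) for k in range(len(rng))] == [0.2, 0.25, 0.3]
assert [round(reversed_base + k * reversed_step, 10) for k in range(len(rng))] == [0.3, 0.25, 0.2]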
def play_arbitrary_waveform(self, waveform: Union[Waveform,WaveformCollection], + stepped_var_list: Optional[List[Tuple[str,SimpleExpressionStepped]]] = None): + + if self._meas_queue: + meas = self._meas_queue.pop() + if self._reversed_counter%2: + duration = waveform.duration + meas = [ + (name, duration - (begin + length), length) + for name, begin, length in meas + ] + else: + meas = None + + # recognize voltage trafo sweep syntax from a transforming waveform. + # other sweepable things may need different approaches. + if not isinstance(waveform,(TransformingWaveform,WaveformCollection)): + assert stepped_var_list is None + ret = self._stack[-1].append(LinSpaceArbitraryWaveform(waveform=waveform.reversed() if self._reversed_counter%2 else waveform, + channels=waveform.defined_channels,)) + if meas: + self._stack[-1][-1]._measurement_memory.add_measurements(meas) + return ret + + + #should be sufficient to test the first wf, as all should have the same trafo + waveform_propertyextractor = waveform + while isinstance(waveform_propertyextractor,WaveformCollection): + waveform_propertyextractor = waveform_propertyextractor.waveform_collection[0] + + if isinstance(waveform_propertyextractor,TransformingWaveform): + #test for transformations that contain SimpleExpression + wf_transformation = waveform_propertyextractor.transformation + + # chainedTransformation should now have flat hierachy. + collected_trafos, dependent_trafo_vals_flag = collect_scaling_and_offset_per_channel( + waveform_propertyextractor.defined_channels,wf_transformation) + else: + dependent_trafo_vals_flag = False + + #fast track + if not dependent_trafo_vals_flag and not isinstance(waveform,WaveformCollection): + ret = self._stack[-1].append(LinSpaceArbitraryWaveform(waveform=waveform.reversed() if self._reversed_counter%2 else waveform, + channels=waveform.defined_channels,)) + if meas: + self._stack[-1][-1]._measurement_memory.add_measurements(meas) + return ret + + ranges = self._get_ranges() + reversals = self._get_range_reversals() + ranges_list = list(ranges) + index_factors = {} + + if stepped_var_list: + # the index ordering shall be with the last index changing fastest. + # (assuming the WaveformColleciton will be flattened) + # this means increments on last shall be 1, next lower 1*len(fastest), + # next 1*len(fastest)*len(second_fastest),... -> product(higher_reg_range_lens) + # total_reg_len = len(stepped_var_list) + reg_lens = tuple(len(v.rng) for s,v in stepped_var_list) + total_rng_len = np.cumprod(reg_lens)[-1] + reg_incr_values = list(np.cumprod(reg_lens[::-1]))[::-1][1:] + [1,] + + assert isinstance(waveform,WaveformCollection) + + for reg_num,(var_name,value) in enumerate(stepped_var_list): + # this should be given anyway: + assert isinstance(value, SimpleExpressionStepped) + + """ + # by definition, every var_name should be relevant for the waveform/ + # has been included in the nested WaveformCollection. 
+ # so, each time this code is called, a new waveform node containing this is called, + # and one can/must increase the offset by the + + # assert value.base += total_rng_len + """ + + assert value.base == 0 + + offsets = value.offsets + #there can never be more than one key in this + # (nowhere is an evaluation of arithmetics between steppings intended) + assert len(offsets)==1 + assert all(v==1 for v in offsets.values()) + assert set(offsets.keys())=={var_name,} + + # this makes the search through ranges pointless; have tuple of zeros + # except for one inc at the position of the stepvar in the ranges dict + + incs = [0 for v in ranges_list] + incs[ranges_list.index(var_name)] = reg_incr_values[reg_num] + + #needs to be new "channel" each time? should be handled by metaclass + reg_channel = StepRegister(var_name,reg_num) + index_factors[reg_channel] = tuple(incs) + # bases[reg_channel] = value.base + + scale_factors, offset_factors = {}, {} + scale_bases, offset_bases = {}, {} + + if dependent_trafo_vals_flag: + for ch_name,scale_offset_dict in collected_trafos.items(): + for bases,factors,key in zip((scale_bases, offset_bases),(scale_factors, offset_factors),('s','o')): + value = scale_offset_dict[key] + if isinstance(value, float): + bases[ch_name] = value + factors[ch_name] = None + continue + offsets = value.offsets + base = value.base + incs = [] + for rng_name, rng in ranges.items(): + start = 0. + step = 0. + offset = offsets.get(rng_name, None) + if offset: + if reversals[rng_name]: + start += (rng.stop-1) * offset + step += -rng.step * offset + else: + start += rng.start * offset + step += rng.step * offset + base += start + incs.append(step) + factors[ch_name] = tuple(incs) + bases[ch_name] = base + + # assert ba + + ret = self._stack[-1].append(LinSpaceArbitraryWaveformIndexed( + waveform=waveform.reversed() if self._reversed_counter%2 and not stepped_var_list else waveform, + channels=waveform_propertyextractor.defined_channels.union(set(index_factors.keys())), + scale_bases=scale_bases, + scale_factors=scale_factors, + offset_bases=offset_bases, + offset_factors=offset_factors, + index_factors=index_factors, + )) + if meas: + self._stack[-1][-1]._measurement_memory.add_measurements(meas) + return ret def measure(self, measurements: Optional[Sequence[MeasurementWindow]]): - """Ignores measurements""" - pass + # """Ignores measurements""" + + self._meas_queue.append(measurements) def with_repetition(self, repetition_count: RepetitionCount, measurements: Optional[Sequence[MeasurementWindow]] = None) -> Iterable['ProgramBuilder']: @@ -188,32 +852,265 @@ def with_repetition(self, repetition_count: RepetitionCount, yield self blocks = self._stack.pop() if blocks: - self._stack[-1].append(LinSpaceRepeat(body=tuple(blocks), count=repetition_count)) - + repeat = LinSpaceRepeat(body=tuple(blocks), count=repetition_count,) + if measurements: repeat._measurement_memory.add_measurements(measurements) + self._stack[-1].append(repeat) + + @contextlib.contextmanager + def time_reversed(self, measurements: Optional[Sequence[MeasurementWindow]] = None) -> ContextManager['ProgramBuilder']: + if measurements: raise NotImplementedError('TODO') + self._reversed_counter += 1 + self._stack.append([]) + yield self + blocks = self._stack.pop() + if blocks: + for block in blocks: + block._reverse_body() + self._stack[-1].extend(blocks[::-1]) + self._reversed_counter += 1 + @contextlib.contextmanager def with_sequence(self, measurements: Optional[Sequence[MeasurementWindow]] = None) -> 
ContextManager['ProgramBuilder']: + self._stack.append([]) yield self + blocks = self._stack.pop() + if blocks: + + #try to concatenate waveforms that are plain succession of things? + #as to_single_waveform is not implemented yet here / might serve different purpose + if self._concatenate_wfs is not None: + blocks = _get_concatenated_blocks(blocks,*self._concatenate_wfs) + + sequence = LinSpaceSequence(body=tuple(blocks),) + if measurements: sequence._measurement_memory.add_measurements(measurements) + self._stack[-1].append(sequence) def new_subprogram(self, global_transformation: 'Transformation' = None) -> ContextManager['ProgramBuilder']: + + inner_builder = LinSpaceBuilder(self._to_stepping_repeat,self._play_marker_when_constant,self._to_rollout) + yield inner_builder + inner_program = inner_builder.to_program() + + # if inner_program is not None: + + # # measurements = [(name, begin, length) + # # for name, (begins, lengths) in inner_program.get_measurement_windows().items() + # # for begin, length in zip(begins, lengths)] + # # self._top.add_measurements(measurements) + # waveform = to_waveform(inner_program,self._idx_to_name) + # if global_transformation is not None: + # waveform = TransformingWaveform.from_transformation(waveform, global_transformation) + # self.play_arbitrary_waveform(waveform) + raise NotImplementedError('Not implemented yet (postponed)') def with_iteration(self, index_name: str, rng: range, + pt_obj: 'ForLoopPT', measurements: Optional[Sequence[MeasurementWindow]] = None) -> Iterable['ProgramBuilder']: if len(rng) == 0: return + + #semi-hacky: if the body duration expression has variables, + #this must lead to rolling out the pulse with inner mapping + #this effectively disables the efficient t-looping for ConstantPTs, but + #that would probably be rarely used anyway + #(cause measurements not implemented, which is more of a headache than this). 
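# Illustrative sketch (not part of this diff): which loops end up in the rollout branch
# below.  The pulse templates are hypothetical examples built from the regular qupulse PTs.
from qupulse.pulses import ConstantPT, ForLoopPT

swept_duration = ForLoopPT(ConstantPT('t_hold', {'ZI0_A': 0.1}), 't_hold', (64, 256, 64))
swept_voltage = ForLoopPT(ConstantPT(64, {'ZI0_A': '0.1 * i'}), 'i', 3)

# a body duration that depends on the loop index forces unrolling with concrete index values,
assert 't_hold' in swept_duration.body.duration.variables
# while an index-independent duration keeps the efficient LinSpaceIter representation.
assert not swept_voltage.body.duration.variables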
+ if (durvars:=pt_obj.body.duration.variables) or index_name in self._to_rollout: + if any(d not in self._current_rollout_vars for d in durvars) or index_name in self._to_rollout: + # assert index_name in durvars, 'expected iteration index ot be in duration expression of body' #doesn't have to be the case if nested loops + with self.with_sequence(measurements): + for value in rng: + self._current_rollout_vars[index_name] = value + yield self + self._current_rollout_vars.pop(index_name) + return + self._stack.append([]) - self._ranges.append((index_name, rng)) + self._ranges.append((index_name, rng, self._reversed_counter%2)) yield self cmds = self._stack.pop() self._ranges.pop() if cmds: - self._stack[-1].append(LinSpaceIter(body=tuple(cmds), length=len(rng))) - - def to_program(self) -> Optional[Sequence[LinSpaceNode]]: + stepped = False + if pt_obj in self._to_stepping_repeat or pt_obj.identifier in self._to_stepping_repeat \ + or pt_obj.loop_index in self._to_stepping_repeat: + stepped = True + iteration = LinSpaceIter(body=tuple(cmds), length=len(rng), to_be_stepped=stepped) + if measurements: iteration._measurement_memory.add_measurements(measurements) + self._stack[-1].append(iteration) + + + def evaluate_nested_stepping(self, scope: Scope, parameter_names: set[str]) -> bool: + + stepped_vals = {k:v for k,v in scope.items() if isinstance(v,SimpleExpressionStepped)} + #when overlap, then the PT is part of the stepped progression + if stepped_vals.keys() & parameter_names: + return True + return False + + def dispatch_to_stepped_wf_or_hold(self, + build_func: Callable[[Mapping[str, Real],Dict[ChannelID, Optional[ChannelID]]],Optional[Waveform]], + build_parameters: Scope, + parameter_names: set[str], + channel_mapping: Dict[ChannelID, Optional[ChannelID]], + #measurements tbd + global_transformation: Optional["Transformation"], + _pow_2_divisor: int + ) -> None: + + stepped_vals = {k:v for k,v in build_parameters.items() + if isinstance(v,SimpleExpressionStepped) and k in parameter_names} + sorted_steps = list(sorted(stepped_vals.items(), key=lambda item: item[1].step_nesting_level)) + + def build_nested_wf_colls(remaining_ranges: List[Tuple], fixed_elements: List[Tuple] = []): + + if len(remaining_ranges) == 0: + inner_scope = build_parameters.overwrite(dict(fixed_elements)) + #by now, no SimpleExpressionStepped should remain here that is relevant for the current loop. + assert not any(isinstance(v,SimpleExpressionStepped) for k,v in inner_scope.items() if k in parameter_names) + waveform = build_func(inner_scope,channel_mapping=channel_mapping) + if global_transformation: + waveform = TransformingWaveform.from_transformation(waveform, global_transformation) + + #hacky + waveform._pow_2_divisor = _pow_2_divisor + + #this case should not happen, should have been caught beforehand: + # or maybe not, if e.g. amp is zero for some reason + # assert waveform.constant_value_dict() is None + return waveform.reversed() if self._reversed_counter%2 else waveform + else: + if remaining_ranges[0][1].reverse: + direction = -1 + else: + direction = 1 + return WaveformCollection( + tuple(build_nested_wf_colls(remaining_ranges[1:], + fixed_elements+[(remaining_ranges[0][0],remaining_ranges[0][1].value({remaining_ranges[0][0]:it})),]) + for it in remaining_ranges[0][1].rng[::direction])) + + + # not completely convinced this works as intended. + # doesn't this - also in pulse_template program creation - lead to complications with ParallelConstantChannelTrafo? 
+ # dirty, quick workaround - if this doesnt work, assume it is also not constant: + try: + potential_waveform = build_func(build_parameters,channel_mapping=channel_mapping) + if global_transformation: + potential_waveform = TransformingWaveform.from_transformation(potential_waveform, global_transformation) + constant_values = potential_waveform.constant_value_dict() + except: + constant_values = None + + if constant_values is None: + wf_coll = build_nested_wf_colls(sorted_steps) + self.play_arbitrary_waveform(wf_coll,sorted_steps) + else: + # in the other case, all dependencies _should_ be on amp and length, which is covered by hold appropriately + # and doesn't need to be stepped? + self.hold_voltage(potential_waveform.duration, constant_values) + + def to_program(self, defined_channels: TypingSet[ChannelID]) -> Optional[Sequence[LinSpaceNode]]: + assert not self._meas_queue if self._root(): - return self._root() - + return LinSpaceTopLevel(body=tuple(self._root()), + _play_marker_when_constant=self._play_marker_when_constant, + _defined_channels=defined_channels) + + +def collect_scaling_and_offset_per_channel(channels: Sequence[ChannelID], + transformation: Transformation) \ + -> Tuple[Dict[ChannelID,Dict[str,Union[NumVal,SimpleExpression]]], bool]: + + ch_trafo_dict = {ch: {'s':1.,'o':0.} for ch in channels} + + # allowed_trafos = {IdentityTransformation,} + if not isinstance(transformation,ChainedTransformation): + transformations = (transformation,) + else: + transformations = transformation.transformations + + is_dependent_flag = [] + + for trafo in transformations: + #first elements of list are applied first in trafos. + assert trafo.is_constant_invariant() + if isinstance(trafo,ParallelChannelTransformation): + for ch,val in trafo._channels.items(): + is_dependent_flag.append(trafo.contains_sweepval) + # assert not ch in ch_trafo_dict.keys() + # the waveform is sampled with these values taken into account, no change needed. + # ch_trafo_dict[ch]['o'] = val + # ch_trafo_dict.setdefault(ch,{'s':1.,'o':val}) + elif isinstance(trafo,ScalingTransformation): + is_dependent_flag.append(trafo.contains_sweepval) + for ch,val in trafo._factors.items(): + try: + ch_trafo_dict[ch]['s'] = reduce_non_swept(ch_trafo_dict[ch]['s']*val) + ch_trafo_dict[ch]['o'] = reduce_non_swept(ch_trafo_dict[ch]['o']*val) + except TypeError as e: + print('Attempting scale sweep of other sweep val') + raise e + elif isinstance(trafo,OffsetTransformation): + is_dependent_flag.append(trafo.contains_sweepval) + for ch,val in trafo._offsets.items(): + ch_trafo_dict[ch]['o'] += val + elif isinstance(trafo,IdentityTransformation): + continue + elif isinstance(trafo,ChainedTransformation): + raise RuntimeError() + else: + raise NotImplementedError() + + return ch_trafo_dict, any(is_dependent_flag) + + +def reduce_non_swept(val: Union[SimpleExpression,NumVal]) -> Union[SimpleExpression,NumVal]: + if isinstance(val,SimpleExpression) and all(v==0 for v in val.offsets.values()): + return val.base + return val + + +def _get_concatenated_blocks(blocks: Sequence[LinSpaceNode], + max_wf_len: TimeType|float, + max_total_len: TimeType|float) -> Sequence[LinSpaceNode]: + concat_blocks = [] + current_blocks, current_len = [], 0. 
+ + def combine_nodes_into_one(nodes: Sequence[LinSpaceArbitraryWaveform]) -> LinSpaceArbitraryWaveform: + assert all(isinstance(n,LinSpaceArbitraryWaveform) for n in nodes), 'can only combine type LinSpaceArbitraryWaveform' + seq_wf = SequenceWaveform.from_sequence([n.waveform for n in nodes]) + combined_labwf = LinSpaceArbitraryWaveform(waveform=seq_wf,channels=seq_wf.defined_channels) + return combined_labwf + + + for node in blocks: + if not isinstance(node,LinSpaceArbitraryWaveform): + if current_blocks: concat_blocks.append(combine_nodes_into_one(current_blocks)) + concat_blocks.append(node) + current_blocks, current_len = [], 0. + else: + append_to_current = True + if len(current_blocks): + #we'll ignore the difficulty of the hacky feature of the reduced sample rate playback + #by starting new subblocks when that quantity changes + if current_blocks[-1].waveform._pow_2_divisor != node.waveform._pow_2_divisor or\ + node.waveform.duration > max_wf_len or\ + current_len+node.waveform.duration > max_total_len: + append_to_current = False + + if append_to_current: + current_blocks.append(node) + current_len += node.waveform.duration + else: + concat_blocks.append(combine_nodes_into_one(current_blocks)) + current_blocks = [node,] + current_len = node.waveform.duration + + #append last + if current_blocks: concat_blocks.append(combine_nodes_into_one(current_blocks)) + return concat_blocks + @dataclass class LoopLabel: @@ -223,22 +1120,35 @@ class LoopLabel: @dataclass class Increment: - channel: int - value: float - dependency_key: DepKey - + channel: Optional[GeneralizedChannel] + value: Union[ResolutionDependentValue,Tuple[ResolutionDependentValue]] + key: DepKey + + def __hash__(self): + return hash((type(self),self.channel,self.value,self.key)) + + def __str__(self): + return "Increment("+",".join([f"{k}="+v.__str__() for k,v in vars(self).items()])+")" @dataclass class Set: - channel: int - value: float - key: DepKey = dataclasses.field(default_factory=lambda: DepKey(())) - - + channel: Optional[GeneralizedChannel] + value: Union[ResolutionDependentValue,Tuple[ResolutionDependentValue]] + key: DepKey = dataclasses.field(default_factory=lambda: DepKey((),DepDomain.NODEP,None)) + + def __hash__(self): + return hash((type(self),self.channel,self.value,self.key)) + + def __str__(self): + return "Set("+",".join([f"{k}="+v.__str__() for k,v in vars(self).items()])+")" + @dataclass class Wait: - duration: TimeType + duration: Optional[TimeType] + key_by_domain: Dict[DepDomain,DepKey] = dataclasses.field(default_factory=lambda: {}) + def __hash__(self): + return hash((type(self),self.duration,frozenset(self.key_by_domain.items()))) @dataclass class LoopJmp: @@ -247,8 +1157,19 @@ class LoopJmp: @dataclass class Play: - waveform: Waveform - channels: Tuple[ChannelID] + waveform: Union[Waveform,WaveformCollection] + play_channels: Tuple[ChannelID] + step_channels: Tuple[StepRegister] = () + #actually did the name + keys_by_domain_by_ch: Dict[ChannelID,Dict[DepDomain,DepKey]] = None + + def __post_init__(self): + if self.keys_by_domain_by_ch is None: + self.keys_by_domain_by_ch = {ch: {} for ch in self.play_channels+self.step_channels} + + def __hash__(self): + return hash((type(self),self.waveform,self.play_channels,self.step_channels, + frozenset((k,frozenset(d.items())) for k,d in self.keys_by_domain_by_ch.items()))) Command = Union[Increment, Set, LoopLabel, LoopJmp, Wait, Play] @@ -259,11 +1180,13 @@ class DepState: base: float iterations: Tuple[int, ...] 
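# Illustrative sketch (not part of this diff): the arithmetic behind required_increment_from
# below, reduced to plain floats (the real method returns a ResolutionDependentValue that
# keeps bases and multiplicities separate).  Take v = 0.5 + 0.1*i + 0.01*j with an outer
# loop i in range(3) and an inner loop j in range(5), i.e. factors (0.1, 0.01) per nesting level.
def plain_increment(previous_iterations, new_iterations, factors):
    inc = 0.0
    for old, new, factor in zip(previous_iterations, new_iterations, factors):
        if old == new:
            continue             # this nesting level did not move
        elif new:
            inc += factor        # one regular step of this level
        else:
            inc -= factor * old  # wrap-around: revert the `old` steps taken on this level
    return inc

# regular step of the inner loop (j advances):
assert plain_increment((0, 0), (0, 4), (0.1, 0.01)) == 0.01
# outer loop advances while the inner loop wraps back to its start:
assert round(plain_increment((0, 4), (2, 0), (0.1, 0.01)), 10) == 0.06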
- def required_increment_from(self, previous: 'DepState', factors: Sequence[float]) -> float: - assert len(self.iterations) == len(previous.iterations) + def required_increment_from(self, previous: 'DepState', + factors: Sequence[float]) -> ResolutionDependentValue: + assert len(self.iterations) == len(previous.iterations) #or (all(self.iterations)==0 and all(previous.iterations)==0) assert len(self.iterations) == len(factors) - increment = self.base - previous.base + # increment = self.base - previous.base + res_bases, res_mults, offset = [], [], self.base - previous.base for old, new, factor in zip(previous.iterations, self.iterations, factors): # By convention there are only two possible values for each integer here: 0 or the last index # The three possible increments are none, regular and jump to next line @@ -276,13 +1199,18 @@ def required_increment_from(self, previous: 'DepState', factors: Sequence[float] assert old == 0 # regular iteration, although the new value will probably be > 1, the resulting increment will be # applied multiple times so only one factor is needed. - increment += factor - + # increment += factor + res_bases.append(factor) + res_mults.append(1) + else: assert new == 0 # we need to jump back. The old value gives us the number of increments to reverse - increment -= factor * old - return increment + # increment -= factor * old + res_bases.append(-factor) + res_mults.append(old) + + return ResolutionDependentValue(res_bases,res_mults,offset) @dataclass @@ -292,119 +1220,297 @@ class _TranslationState: label_num: int = dataclasses.field(default=0) commands: List[Command] = dataclasses.field(default_factory=list) iterations: List[int] = dataclasses.field(default_factory=list) - active_dep: Dict[int, DepKey] = dataclasses.field(default_factory=dict) - dep_states: Dict[int, Dict[DepKey, DepState]] = dataclasses.field(default_factory=dict) - plain_voltage: Dict[int, float] = dataclasses.field(default_factory=dict) + active_dep: Dict[GeneralizedChannel, Dict[DepDomain, DepKey]] = dataclasses.field(default_factory=dict) + dep_states: Dict[GeneralizedChannel, Dict[DepKey, DepState]] = dataclasses.field(default_factory=dict) + plain_value: Dict[GeneralizedChannel, Dict[DepDomain,float]] = dataclasses.field(default_factory=dict) resolution: float = dataclasses.field(default_factory=lambda: DEFAULT_INCREMENT_RESOLUTION) - + resolution_time: float = dataclasses.field(default_factory=lambda: DEFAULT_TIME_RESOLUTION) + nesting_lvl_to_loop_label: List[int] = dataclasses.field(default_factory=lambda: []) + def new_loop(self, count: int): label = LoopLabel(self.label_num, count) jmp = LoopJmp(self.label_num) self.label_num += 1 return label, jmp - def get_dependency_state(self, dependencies: Mapping[int, set]): - return { - self.dep_states.get(ch, {}).get(DepKey.from_voltages(dep, self.resolution), None) - for ch, deps in dependencies.items() - for dep in deps - } - - def set_voltage(self, channel: int, value: float): - key = DepKey(()) - if self.active_dep.get(channel, None) != key or self.plain_voltage.get(channel, None) != value: - self.commands.append(Set(channel, value, key)) - self.active_dep[channel] = key - self.plain_voltage[channel] = value - - def _add_repetition_node(self, node: LinSpaceRepeat): + def get_dependency_state(self, dependencies: Mapping[DepDomain, Mapping[GeneralizedChannel, set]]): + dom_to_ch_to_depstates = {} + + for dom, ch_to_deps in dependencies.items(): + dom_to_ch_to_depstates.setdefault(dom,{}) + for ch, deps in ch_to_deps.items(): + 
                dom_to_ch_to_depstates[dom].setdefault(ch,set())
+                for dep in deps:
+                    dom_to_ch_to_depstates[dom][ch].add(self.dep_states.get(ch, {}).get(
+                        DepKey.from_domain(dep, self.resolution, dom, None),None))
+
+        return dom_to_ch_to_depstates
+        # return {
+        #     dom: self.dep_states.get(ch, {}).get(DepKey.from_domain(dep, self.resolution, dom),
+        #                                          None)
+        #     for dom, ch_to_deps in dependencies.items()
+        #     for ch, deps in ch_to_deps.items()
+        #     for dep in deps
+        #     }
+
+    def compare_ignoring_post_trailing_zeros(self,
+                                             pre_state: Mapping[DepDomain, Mapping[GeneralizedChannel, set]],
+                                             post_state: Mapping[DepDomain, Mapping[GeneralizedChannel, set]]) -> bool:
+
+        def reduced_or_none(dep_state: DepState) -> Union[DepState,None]:
+            new_iterations = tuple(dropwhile(lambda x: x == 0, reversed(dep_state.iterations)))[::-1]
+            return DepState(dep_state.base, new_iterations) if len(new_iterations)>0 else None
+
+        has_changed = False
+        dom_keys = set(pre_state.keys()).union(post_state.keys())
+        for dom_key in dom_keys:
+            pre_state_dom, post_state_dom = pre_state.get(dom_key,{}), post_state.get(dom_key,{})
+            ch_keys = set(pre_state_dom.keys()).union(post_state_dom.keys())
+            for ch_key in ch_keys:
+                pre_state_dom_ch, post_state_dom_ch = pre_state_dom.get(ch_key,set()), post_state_dom.get(ch_key,set())
+                # reduce the depStates to the ones which do not just contain zeros
+                reduced_pre_set = set(reduced_or_none(dep_state) for dep_state in pre_state_dom_ch
+                                      if dep_state is not None) - {None}
+                reduced_post_set = set(reduced_or_none(dep_state) for dep_state in post_state_dom_ch
+                                       if dep_state is not None) - {None}
+
+                if not reduced_post_set <= reduced_pre_set:
+                    has_changed = True
+
+        return has_changed
+
+    def set_voltage(self, channel: ChannelID, value: float):
+        self.set_non_indexed_value(channel, value, domain=DepDomain.VOLTAGE, always_emit_set=True)
+
+    def set_wf_scale(self, channel: ChannelID, value: float):
+        self.set_non_indexed_value(channel, value, domain=DepDomain.WF_SCALE)
+
+    def set_wf_offset(self, channel: ChannelID, value: float):
+        self.set_non_indexed_value(channel, value, domain=DepDomain.WF_OFFSET)
+
+    def set_non_indexed_value(self, channel: GeneralizedChannel, value: float,
+                              domain: DepDomain, always_emit_set: bool=False):
+        key = DepKey((),domain,None)
+        # I do not completely get why it would have to be set again if not in active dep.
+        # if not key != self.active_dep.get(channel, None) or
+        if self.plain_value.get(channel, {}).get(domain, None) != value or always_emit_set:
+            self.commands.append(Set(channel, ResolutionDependentValue((),(),offset=value), key))
+            # there has to be no active dep when the value is not indexed?
+            # self.active_dep.setdefault(channel,{})[DepDomain.NODEP] = key
+            self.plain_value.setdefault(channel,{})
+            self.plain_value[channel][domain] = value
+
+    # def _add_repetition_node(self, node: LinSpaceRepeat):
+    #     pre_dep_state = self.get_dependency_state(node.dependencies())
+    #     label, jmp = self.new_loop(node.count)
+    #     initial_position = len(self.commands)
+    #     self.commands.append(label)
+    #     self.add_node(node.body)
+    #     post_dep_state = self.get_dependency_state(node.dependencies())
+    #     # the last index in the iterations may not be initialized in pre_dep_state if the outer loop only sets an index
+    #     # after this loop is in the sequence of the current level,
+    #     # meaning that a trailing 0 at the end of iterations of each depState in the post_dep_state
+    #     # should be ignored when comparing.
+ # # zeros also should only mean a "Set" command, which is not harmful if executed multiple times. + # # if pre_dep_state != post_dep_state: + # if self.compare_ignoring_post_trailing_zeros(pre_dep_state,post_dep_state): + # # hackedy + # self.commands.pop(initial_position) + # self.commands.append(label) + # label.count -= 1 + # self.add_node(node.body) + # self.commands.append(jmp) + + + def _add_repetition_node(self, node: LinSpaceRepeat, + safe_rep_possible: bool = False, + ): pre_dep_state = self.get_dependency_state(node.dependencies()) label, jmp = self.new_loop(node.count) initial_position = len(self.commands) self.commands.append(label) - self.add_node(node.body) + self.add_node(node.body,safe_rep_possible) post_dep_state = self.get_dependency_state(node.dependencies()) - if pre_dep_state != post_dep_state: + # the last index in the iterations may not be initialized in pre_dep_state if the outer loop only sets an index + # after this loop is in the sequence of the current level, + # meaning that a trailing 0 at the end of iterations of each depState in the post_dep_state + # should be ignored when comparing. + # zeros also should only mean a "Set" command, which is not harmful if executed multiple times. + # if pre_dep_state != post_dep_state: + #EDIT: even this is not enough it seems; if a dependency from an outer + # loop is present that the repetition does not know about, this is still necessary. + # why not always in the first place? + # if self.compare_ignoring_post_trailing_zeros(pre_dep_state,post_dep_state): + if not safe_rep_possible: # hackedy self.commands.pop(initial_position) self.commands.append(label) label.count -= 1 self.add_node(node.body) self.commands.append(jmp) - + def _add_iteration_node(self, node: LinSpaceIter): + self.iterations.append(0) + if node.length > 1: + label, jmp = self.new_loop(node.length - 1) + self.nesting_lvl_to_loop_label.append(label.idx) self.add_node(node.body) if node.length > 1: - self.iterations[-1] = node.length - label, jmp = self.new_loop(node.length - 1) + self.iterations[-1] = node.length - 1 self.commands.append(label) self.add_node(node.body) self.commands.append(jmp) + self._free_registers(label.idx) self.iterations.pop() - - def _set_indexed_voltage(self, channel: int, base: float, factors: Sequence[float]): - dep_key = DepKey.from_voltages(voltages=factors, resolution=self.resolution) + self.nesting_lvl_to_loop_label.pop() + + def _free_registers(self,label:int): + for ch,dep_state_dict in self.dep_states.items(): + for depkey,dep_state in list(dep_state_dict.items()): + # print(f'LOOKING AT {depkey._free_upon_loop_exit=},{label=}') + # the _free_upon_loop_exit are not compared/hashed, meaning the + # first entry made, stemming from the lowest nesting level, + # will have the lowest number, meaning it should only be removed + # once this level is reached again. 
+ if depkey._free_upon_loop_exit==label: + # print(f'RESETTING {depkey=}') + # dep_state_dict[depkey] = 0 + dep_state_dict.pop(depkey) + + + def _set_indexed_voltage(self, channel: ChannelID, base: float, factors: Sequence[float]): + free_upon_nesting_exit = next((i for i, x in enumerate(factors) if x != 0), len(factors)) + key = DepKey.from_voltages(voltages=factors, resolution=self.resolution, free_upon_loop_exit=self.nesting_lvl_to_loop_label[free_upon_nesting_exit]) + self.set_indexed_value(key, channel, base, factors, domain=DepDomain.VOLTAGE, always_emit_incr=True) + + def _set_indexed_lin_time(self, base: TimeType, factors: Sequence[TimeType]): + key = DepKey.from_lin_times(times=factors, resolution=self.resolution) + self.set_indexed_value(key, DepDomain.TIME_LIN, base, factors, domain=DepDomain.TIME_LIN) + + def set_indexed_value(self, dep_key: DepKey, channel: GeneralizedChannel, + base: Union[float,TimeType], factors: Sequence[Union[float,TimeType]], + domain: DepDomain, always_emit_incr: bool = False): new_dep_state = DepState( base, iterations=tuple(self.iterations) ) current_dep_state = self.dep_states.setdefault(channel, {}).get(dep_key, None) - if current_dep_state is None: - assert all(it == 0 for it in self.iterations) - self.commands.append(Set(channel, base, dep_key)) - self.active_dep[channel] = dep_key + + if current_dep_state is None or current_dep_state==0: + # if not all(it == 0 for it in self.iterations): + #this is not valid anymore, more intricate check would be necessary + #with check againstall iterations 0 for depkey compared without loop property + # if current_dep_state is None: + # assert all(it == 0 for it in self.iterations), self.iterations + self.commands.append(Set(channel, ResolutionDependentValue((),(),offset=base), dep_key)) + self.active_dep.setdefault(channel,{})[dep_key.domain] = dep_key else: + # print(self.label_num) + # print(self.iterations) inc = new_dep_state.required_increment_from(previous=current_dep_state, factors=factors) # we insert all inc here (also inc == 0) because it signals to activate this amplitude register - if inc or self.active_dep.get(channel, None) != dep_key: + # -> since this is not necessary for other domains, make it stricter and bypass if necessary for voltage. + if ((inc or self.active_dep.get(channel, {}).get(dep_key.domain) != dep_key) + and new_dep_state != current_dep_state)\ + or always_emit_incr: + # if always_emit_incr and new_dep_state == current_dep_state, inc should be zero. + #this is not always the case, e.g. if multiple sequenced pts with different + #dependencies on param exist and some with same are chained? very complicated case + #and probably not handled correctly + # if always_emit_incr and new_dep_state == current_dep_state: + # assert inc==0. 
self.commands.append(Increment(channel, inc, dep_key)) - self.active_dep[channel] = dep_key + self.active_dep.setdefault(channel,{})[dep_key.domain] = dep_key self.dep_states[channel][dep_key] = new_dep_state - + def _add_hold_node(self, node: LinSpaceHold): - if node.duration_factors: - raise NotImplementedError("TODO") - for ch, (base, factors) in enumerate(zip(node.bases, node.factors)): - if factors is None: - self.set_voltage(ch, base) + for ch in node.play_channels: + if node.factors[ch] is None: + self.set_voltage(ch, node.bases[ch]) continue - else: - self._set_indexed_voltage(ch, base, factors) - - self.commands.append(Wait(node.duration_base)) - - def add_node(self, node: Union[LinSpaceNode, Sequence[LinSpaceNode]]): + self._set_indexed_voltage(ch, node.bases[ch], node.factors[ch]) + + if node.duration_factors: + self._set_indexed_lin_time(node.duration_base,node.duration_factors) + # raise NotImplementedError("TODO") + self.commands.append(Wait(None, {DepDomain.TIME_LIN: self.active_dep[DepDomain.TIME_LIN][DepDomain.TIME_LIN]})) + else: + self.commands.append(Wait(node.duration_base)) + + def _add_indexed_play_node(self, node: LinSpaceArbitraryWaveformIndexed): + + #assume this as criterion: + if len(node.scale_bases) and len(node.offset_bases): + for ch in node.play_channels: + for base,factors,domain in zip((node.scale_bases[ch], node.offset_bases[ch]), + (node.scale_factors[ch], node.offset_factors[ch]), + (DepDomain.WF_SCALE,DepDomain.WF_OFFSET)): + if factors is None: + continue + # assume here that the waveform will have the correct settings the TransformingWaveform, + # where no SimpleExpression is replaced now. + # will yield the correct trafo already without having to make adjustments + # self.set_non_indexed_value(ch, base, domain) + else: + key = DepKey.from_domain(factors, resolution=self.resolution, domain=domain, free_upon_loop_exit=self.label_num) + self.set_indexed_value(key, ch, base, factors, key.domain) + + for st_ch, st_factors in node.index_factors.items(): + #this should not happen: + assert st_factors is not None + key = DepKey.from_domain(st_factors, resolution=self.resolution, domain=DepDomain.STEP_INDEX, free_upon_loop_exit=None) + self.set_indexed_value(key, st_ch, 0, st_factors, key.domain) + + + self.commands.append(Play(node.waveform, node.channels, step_channels=node.step_channels, + keys_by_domain_by_ch={c: self.active_dep.get(c,{}) for c in node.channels})) + + + def add_node(self, node: Union[LinSpaceNode, Sequence[LinSpaceNode]], + safe_rep_possible: bool = False, + ): """Translate a (sequence of) linspace node(s) to commands and add it to the internal command list.""" + if isinstance(node, Sequence): for lin_node in node: - self.add_node(lin_node) - + self.add_node(lin_node,safe_rep_possible if len(node)==1 else False) + + elif isinstance(node, LinSpaceSequence): + for node in node.body: + self.add_node(node) + elif isinstance(node, LinSpaceRepeat): - self._add_repetition_node(node) + self._add_repetition_node(node,safe_rep_possible) elif isinstance(node, LinSpaceIter): self._add_iteration_node(node) elif isinstance(node, LinSpaceHold): self._add_hold_node(node) - + + elif isinstance(node, LinSpaceArbitraryWaveformIndexed): + self._add_indexed_play_node(node) + elif isinstance(node, LinSpaceArbitraryWaveform): - self.commands.append(Play(node.waveform, node.channels)) + self.commands.append(Play(node.waveform, node.play_channels)) else: raise TypeError("The node type is not handled", type(node), node) -def to_increment_commands(linspace_nodes: 
Sequence[LinSpaceNode]) -> List[Command]: +def to_increment_commands(linspace_nodes: LinSpaceTopLevel, + # resolution: float = DEFAULT_INCREMENT_RESOLUTION + ) -> List[Command]: """translate the given linspace node tree to a minimal sequence of set and increment commands as well as loops.""" + # if resolution: raise NotImplementedError('wrongly assumed resolution. need to fix') state = _TranslationState() - state.add_node(linspace_nodes) + state.add_node(linspace_nodes.body,safe_rep_possible=True) return state.commands diff --git a/qupulse/program/loop.py b/qupulse/program/loop.py index 0f0356531..5ef8124cb 100644 --- a/qupulse/program/loop.py +++ b/qupulse/program/loop.py @@ -20,6 +20,7 @@ from qupulse.utils.numeric import smallest_factor_ge from qupulse.utils.tree import Node from qupulse.utils.types import TimeType, MeasurementWindow +from qupulse import ChannelID __all__ = ['Loop', 'make_compatible', 'MakeCompatibleWarning', 'to_waveform'] @@ -516,6 +517,9 @@ def reverse_inplace(self): (name, duration - (begin + length), length) for name, begin, length in self._measurements ] + + def get_defined_channels(self) -> Set[ChannelID]: + return next(self.get_depth_first_iterator()).waveform.defined_channels def to_waveform(program: Loop) -> Waveform: @@ -771,7 +775,7 @@ def __init__(self): self._stack: List[StackFrame] = [StackFrame(self._root, None)] - def inner_scope(self, scope: Scope) -> Scope: + def inner_scope(self, scope: Scope, pt_obj: 'ForLoopPT') -> Scope: local_vars = self._stack[-1].iterating if local_vars is None: return scope @@ -806,13 +810,23 @@ def with_repetition(self, repetition_count: RepetitionCount, self._try_append(repetition_loop, measurements) def with_iteration(self, index_name: str, rng: range, + pt_obj: 'ForLoopPT', measurements: Optional[Sequence[MeasurementWindow]] = None) -> Iterable['ProgramBuilder']: with self.with_sequence(measurements): top_frame = self._stack[-1] for value in rng: top_frame.iterating = (index_name, value) yield self - + + @contextmanager + def time_reversed(self) -> ContextManager['LoopBuilder']: + inner_builder = LoopBuilder() + yield inner_builder + inner_program = inner_builder.to_program() + if inner_program: + inner_program.reverse_inplace() + self._try_append(inner_program, None) + @contextmanager def with_sequence(self, measurements: Optional[Sequence[MeasurementWindow]] = None) -> ContextManager['ProgramBuilder']: top_frame = StackFrame(LoopGuard(self._top, measurements), None) @@ -842,7 +856,8 @@ def new_subprogram(self, global_transformation: Transformation = None) -> Contex waveform = TransformingWaveform.from_transformation(waveform, global_transformation) self.play_arbitrary_waveform(waveform) - def to_program(self) -> Optional[Loop]: + def to_program(self, defined_channels: Set[ChannelID]={}) -> Optional[Loop]: + #defined channels ignored as can be inferred from depth_iterator anyway if len(self._stack) != 1: warnings.warn("Creating program with active build stack.") if self._root.waveform or len(self._root.children) != 0: diff --git a/qupulse/program/transformation.py b/qupulse/program/transformation.py index 1d3c86879..21e437725 100644 --- a/qupulse/program/transformation.py +++ b/qupulse/program/transformation.py @@ -8,9 +8,9 @@ from qupulse.comparable import Comparable from qupulse.utils.types import SingletonABCMeta, frozendict from qupulse.expressions import ExpressionScalar +from qupulse.expressions.simple import SimpleExpression - -_TrafoValue = Union[Real, ExpressionScalar] +_TrafoValue = Union[Real, ExpressionScalar, 
SimpleExpression] __all__ = ['Transformation', 'IdentityTransformation', 'LinearTransformation', 'ScalingTransformation', @@ -88,7 +88,16 @@ def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) - class ChainedTransformation(Transformation): def __init__(self, *transformations: Transformation): - self._transformations = transformations + #avoid nesting also here in init to ensure always flat hierachy? + parsed = [] + for t in transformations: + if t is IdentityTransformation() or t is None: + pass + elif isinstance(t,ChainedTransformation): + parsed.extend(t.transformations) + else: + parsed.append(t) + self._transformations = tuple(parsed) @property def transformations(self) -> Tuple[Transformation, ...]: @@ -231,7 +240,7 @@ def __init__(self, offsets: Mapping[ChannelID, _TrafoValue]): def __call__(self, time: Union[np.ndarray, float], data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: - offsets = _instantiate_expression_dict(time, self._offsets) + offsets = _instantiate_expression_dict(time, self._offsets, default_sweepval = 0.) return {channel: channel_values + offsets[channel] if channel in offsets else channel_values for channel, channel_values in data.items()} @@ -254,6 +263,10 @@ def is_constant_invariant(self): def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return _get_constant_output_channels(self._offsets, input_channels) + + @property + def contains_sweepval(self) -> bool: + return any(isinstance(o,SimpleExpression) for o in self._offsets.values()) class ScalingTransformation(Transformation): @@ -263,7 +276,7 @@ def __init__(self, factors: Mapping[ChannelID, _TrafoValue]): def __call__(self, time: Union[np.ndarray, float], data: Mapping[ChannelID, Union[np.ndarray, float]]) -> Mapping[ChannelID, Union[np.ndarray, float]]: - factors = _instantiate_expression_dict(time, self._factors) + factors = _instantiate_expression_dict(time, self._factors, default_sweepval = 1.) 
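# Illustrative sketch (not part of this diff): the effect of default_sweepval, mimicked with
# a plain stand-in for SimpleExpression.  During sampling a swept factor acts as the neutral
# element (1. for scales, 0. for offsets); the actual sweep values are emitted later by the
# LinSpace program builder as WF_SCALE / WF_OFFSET commands instead of re-sampling the waveform.
class _SweptStandIn:  # hypothetical placeholder for a SimpleExpression-valued factor
    pass

def _instantiate(value, default_sweepval):
    return default_sweepval if isinstance(value, _SweptStandIn) else value

assert _instantiate(_SweptStandIn(), default_sweepval=1.) == 1.   # swept scale -> identity
assert _instantiate(0.25, default_sweepval=1.) == 0.25            # plain value untouched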
return {channel: channel_values * factors[channel] if channel in factors else channel_values for channel, channel_values in data.items()} @@ -287,6 +300,10 @@ def is_constant_invariant(self): def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) -> AbstractSet[ChannelID]: return _get_constant_output_channels(self._factors, input_channels) + @property + def contains_sweepval(self) -> bool: + return any(isinstance(o,SimpleExpression) for o in self._factors.values()) + try: if TYPE_CHECKING: @@ -359,6 +376,10 @@ def get_constant_output_channels(self, input_channels: AbstractSet[ChannelID]) - output_channels.add(ch) return output_channels + + @property + def contains_sweepval(self) -> bool: + return any(isinstance(o,SimpleExpression) for o in self._channels.values()) def chain_transformations(*transformations: Transformation) -> Transformation: @@ -378,12 +399,17 @@ def chain_transformations(*transformations: Transformation) -> Transformation: return ChainedTransformation(*parsed_transformations) -def _instantiate_expression_dict(time, expressions: Mapping[str, _TrafoValue]) -> Mapping[str, Union[Real, np.ndarray]]: +def _instantiate_expression_dict(time, expressions: Mapping[str, _TrafoValue], + default_sweepval: float) -> Mapping[str, Union[Real, np.ndarray]]: scope = {'t': time} modified_expressions = {} for name, value in expressions.items(): if hasattr(value, 'evaluate_in_scope'): modified_expressions[name] = value.evaluate_in_scope(scope) + if isinstance(value, SimpleExpression): + # it is assumed that swept parameters will be handled by the ProgramBuilder accordingly + # such that here only an "identity" trafo is to be applied and the trafos are set in the program internally. + modified_expressions[name] = default_sweepval if modified_expressions: return {**expressions, **modified_expressions} else: diff --git a/qupulse/program/waveforms.py b/qupulse/program/waveforms.py index 94c395532..b544ef0b6 100644 --- a/qupulse/program/waveforms.py +++ b/qupulse/program/waveforms.py @@ -18,15 +18,13 @@ from qupulse import ChannelID from qupulse.program.transformation import Transformation -from qupulse.utils import checked_int_cast, isclose -from qupulse.utils.types import TimeType, time_from_float from qupulse.utils.performance import is_monotonic from qupulse.comparable import Comparable from qupulse.expressions import ExpressionScalar +from qupulse.expressions.simple import SimpleExpression from qupulse.pulses.interpolation import InterpolationStrategy from qupulse.utils import checked_int_cast, isclose -from qupulse.utils.types import TimeType, time_from_float, FrozenDict -from qupulse.program.transformation import Transformation +from qupulse.utils.types import TimeType, time_from_float from qupulse.utils import pairwise class ConstantFunctionPulseTemplateWarning(UserWarning): @@ -51,6 +49,13 @@ def _to_time_type(duration: Real) -> TimeType: else: return time_from_float(float(duration), absolute_error=PULSE_TO_WAVEFORM_ERROR) +def _to_hardware_time(duration: Union[Real, SimpleExpression]) -> Union[TimeType, SimpleExpression[TimeType]]: + if isinstance(duration, SimpleExpression): + return SimpleExpression[TimeType](_to_time_type(duration.base), + {name:_to_time_type(value) for name,value in duration.offsets.items()}) + else: + return _to_time_type(duration) + class Waveform(Comparable, metaclass=ABCMeta): """Represents an instantiated PulseTemplate which can be sampled to retrieve arrays of voltage @@ -58,10 +63,11 @@ class Waveform(Comparable, metaclass=ABCMeta): 
__sampled_cache = WeakValueDictionary() - __slots__ = ('_duration',) + __slots__ = ('_duration','_pow_2_divisor') - def __init__(self, duration: TimeType): + def __init__(self, duration: TimeType, _pow_2_divisor: int = 0): self._duration = duration + self._pow_2_divisor = _pow_2_divisor @property def duration(self) -> TimeType: @@ -215,7 +221,17 @@ def reversed(self) -> 'Waveform': """Returns a reversed version of this waveform.""" # We don't check for constness here because const waveforms are supposed to override this method return ReversedWaveform(self) - + + @abstractmethod + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Hashable: + """key for hashing *without* channel reference. Don't call directly, + only via _hash_only_subset. + """ + + def _hash_only_subset(self, channel_subset: Set[ChannelID]) -> int: + """Return a hash value of this Comparable object.""" + return hash(self.get_subset_for_channels(channel_subset)._compare_subset_key(channel_subset)) + class TableWaveformEntry(NamedTuple('TableWaveformEntry', [('t', Real), ('v', float), @@ -248,7 +264,7 @@ def __init__(self, category=DeprecationWarning) waveform_table = self._validate_input(waveform_table) - super().__init__(duration=_to_time_type(waveform_table[-1].t)) + super().__init__(duration=_to_hardware_time(waveform_table[-1].t)) self._table = waveform_table self._channel_id = channel @@ -353,7 +369,11 @@ def from_table(cls, channel: ChannelID, table: Sequence[EntryInInit]) -> Union[' @property def compare_key(self) -> Any: return self._channel_id, self._table - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Any: + assert self.defined_channels == channel_subset + return self._table + def unsafe_sample(self, channel: ChannelID, sample_times: np.ndarray, @@ -397,10 +417,7 @@ class ConstantWaveform(Waveform): def __init__(self, duration: Real, amplitude: Any, channel: ChannelID): """ Create a qupulse waveform corresponding to a ConstantPulseTemplate """ - super().__init__(duration=_to_time_type(duration)) - if hasattr(amplitude, 'shape'): - amplitude = amplitude[()] - hash(amplitude) + super().__init__(duration=_to_hardware_time(duration)) self._amplitude = amplitude self._channel = channel @@ -409,7 +426,7 @@ def from_mapping(cls, duration: Real, constant_values: Mapping[ChannelID, float] 'MultiChannelWaveform']: """Construct a ConstantWaveform or a MultiChannelWaveform of ConstantWaveforms with given duration and values""" assert constant_values - duration = _to_time_type(duration) + duration = _to_hardware_time(duration) if len(constant_values) == 1: (channel, amplitude), = constant_values.items() return cls(duration, amplitude=amplitude, channel=channel) @@ -437,7 +454,11 @@ def defined_channels(self) -> AbstractSet[ChannelID]: @property def compare_key(self) -> Tuple[Any, ...]: return self._duration, self._amplitude, self._channel - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Tuple[Any, ...]: + assert self.defined_channels == channel_subset + return self._duration, self._amplitude + def unsafe_sample(self, channel: ChannelID, sample_times: np.ndarray, @@ -483,7 +504,7 @@ def __init__(self, expression: ExpressionScalar, elif not expression.variables: warnings.warn("Constant FunctionWaveform is not recommended as the constant propagation will be suboptimal", category=ConstantFunctionPulseTemplateWarning) - super().__init__(duration=_to_time_type(duration)) + super().__init__(duration=_to_hardware_time(duration)) self._expression = expression self._channel_id 
= channel @@ -509,7 +530,11 @@ def defined_channels(self) -> AbstractSet[ChannelID]: @property def compare_key(self) -> Any: return self._channel_id, self._expression, self._duration - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Any: + assert self.defined_channels == channel_subset + return self._expression, self._duration + @property def duration(self) -> TimeType: return self._duration @@ -639,7 +664,10 @@ def unsafe_sample(self, @property def compare_key(self) -> Tuple[Waveform]: return self._sequenced_waveforms - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Tuple[Any]: + return tuple(wf._compare_subset_key(channel_subset) for wf in self._sequenced_waveforms) + @property def duration(self) -> TimeType: return self._duration @@ -788,7 +816,15 @@ def defined_channels(self) -> AbstractSet[ChannelID]: def compare_key(self) -> Any: # sort with channels return self._sub_waveforms - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Any: + if len(channel_subset) == 0: return + if channel_subset != self.defined_channels: #also catches channel_subset >= self.defined_channels + # print(self.defined_channels) + # print(channel_subset) + return self.get_subset_for_channels(channel_subset)._compare_subset_key(channel_subset) + return tuple(self[channel]._compare_subset_key({channel}) for channel in channel_subset) + def unsafe_sample(self, channel: ChannelID, sample_times: np.ndarray, @@ -856,7 +892,10 @@ def unsafe_sample(self, @property def compare_key(self) -> Tuple[Any, int]: return self._body.compare_key, self._repetition_count - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Tuple[Any, int]: + return self._body._compare_subset_key(channel_subset), self._repetition_count + def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> Waveform: return RepetitionWaveform.from_repetition_count( body=self._body.unsafe_get_subset_for_channels(channels), @@ -931,7 +970,11 @@ def defined_channels(self) -> AbstractSet[ChannelID]: @property def compare_key(self) -> Tuple[Waveform, Transformation]: return self.inner_waveform, self.transformation - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Tuple[Any, Transformation]: + remaining_channels = self.transformation.get_input_channels(channel_subset) + return self.inner_waveform._compare_subset_key(remaining_channels), self.transformation + def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> 'SubsetWaveform': return SubsetWaveform(self, channel_subset=channels) @@ -980,7 +1023,12 @@ def defined_channels(self) -> FrozenSet[ChannelID]: @property def compare_key(self) -> Tuple[frozenset, Waveform]: return self.defined_channels, self.inner_waveform - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Any: + #creating another subset from inner_waveform may run into recursive loops? + #so pipe through until MultiChannelWF is reached basically? 
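
The intent behind the subset keys above is channel-name-independent hashing: two waveforms with identical per-channel content should hash equally even if that content lives on differently named channels. A minimal standalone sketch with toy classes (not the qupulse types) mirroring the behaviour the new tests check:

from typing import Hashable, Set

class ToyConstant:
    """Toy stand-in for a single-channel constant waveform."""
    def __init__(self, duration: float, amplitude: float, channel: str):
        self.duration, self.amplitude, self.channel = duration, amplitude, channel

    def _compare_subset_key(self, channel_subset: Set[str]) -> Hashable:
        assert channel_subset == {self.channel}
        return self.duration, self.amplitude   # channel id deliberately left out

    def _hash_only_subset(self, channel_subset: Set[str]) -> int:
        return hash(self._compare_subset_key(channel_subset))

# same content on different channels hashes equal, different content does not
assert ToyConstant(10, 1., 'A')._hash_only_subset({'A'}) == ToyConstant(10, 1., 'B')._hash_only_subset({'B'})
assert ToyConstant(10, 1., 'A')._hash_only_subset({'A'}) != ToyConstant(10, 2., 'A')._hash_only_subset({'A'})
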
+ return self._inner_waveform._compare_subset_key(channel_subset) + def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform: return self.inner_waveform.get_subset_for_channels(channels) @@ -1131,7 +1179,10 @@ def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform: @property def compare_key(self) -> Tuple[str, Waveform, Waveform]: return self._arithmetic_operator, self._lhs, self._rhs - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Tuple[str, Any, Any]: + return self._arithmetic_operator, self._lhs._compare_subset_key(channel_subset), self._rhs._compare_subset_key(channel_subset) + class FunctorWaveform(Waveform): # TODO: Use Protocol to enforce that it accepts second argument has the keyword out @@ -1191,7 +1242,10 @@ def unsafe_get_subset_for_channels(self, channels: Set[ChannelID]) -> Waveform: @property def compare_key(self) -> Tuple[Waveform, FrozenSet]: return self._inner_waveform, frozenset(self._functor.items()) - + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Tuple[Any, FrozenSet]: + return self._inner_waveform._compare_subset_key(channel_subset), frozenset(self._functor.items()) + class ReversedWaveform(Waveform): """Reverses the inner waveform in time.""" @@ -1217,7 +1271,7 @@ def unsafe_sample(self, channel: ChannelID, sample_times: np.ndarray, else: inner_output_array = output_array[::-1] inner_output_array = self._inner.unsafe_sample(channel, inner_sample_times, output_array=inner_output_array) - if inner_output_array.base not in (output_array, output_array.base): + if id(inner_output_array.base) not in (id(output_array), id(output_array.base)): # TODO: is there a guarantee by numpy we never end up here? output_array[:] = inner_output_array[::-1] return output_array @@ -1232,6 +1286,54 @@ def unsafe_get_subset_for_channels(self, channels: AbstractSet[ChannelID]) -> 'W @property def compare_key(self) -> Hashable: return self._inner.compare_key + + def _compare_subset_key(self, channel_subset: Set[ChannelID]) -> Any: + return (self._inner._compare_subset_key(channel_subset),'-') def reversed(self) -> 'Waveform': return self._inner + + + +class WaveformCollection(): + + def __init__(self, waveform_collection: Tuple[Union[Waveform,"WaveformCollection"]]): + + self._waveform_collection = tuple(waveform_collection) + + @property + def duration(self) -> TimeType: + lens = [wf.duration for wf in self.flatten()] + assert np.all(np.isclose([float(l) for l in lens], float(lens[0]))) + return lens[0] + + @property + def waveform_collection(self): + return self._waveform_collection + + @property + def nesting_level(self): + #assume it is balanced for now. 
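
The nesting-level and flatten logic just below can be exercised with plain tuples; a minimal standalone sketch (toy data instead of qupulse waveforms, under the same balanced-nesting assumption):

from typing import Tuple, Union

Nested = Union[str, Tuple['Nested', ...]]

def nesting_level(collection: Tuple[Nested, ...]) -> int:
    # assumes balanced nesting: only the first element is inspected
    first = collection[0]
    return nesting_level(first) + 1 if isinstance(first, tuple) else 0

def flatten(collection: Tuple[Nested, ...]) -> Tuple[str, ...]:
    def _walk(items):
        for item in items:
            if isinstance(item, tuple):
                yield from _walk(item)
            else:
                yield item
    return tuple(_walk(collection))

nested = (('wf_a', 'wf_b'), ('wf_c', 'wf_d'))
assert nesting_level(nested) == 1
assert flatten(nested) == ('wf_a', 'wf_b', 'wf_c', 'wf_d')
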
+ if isinstance(self.waveform_collection[0],type(self)): + return self.waveform_collection[0].nesting_level+1 + return 0 + + def flatten(self) -> Tuple[Waveform]: + def flatten_tuple(nested_tuple): + for item in nested_tuple: + if isinstance(item, type(self)): + yield from flatten_tuple(item.waveform_collection) + else: + yield item + return tuple(flatten_tuple(self.waveform_collection)) + + def reversed(self) -> 'WaveformCollection': + """Returns a reversed version of this waveformcollection.""" + rev = tuple(w.reversed() for w in self._waveform_collection[::-1]) + return WaveformCollection(rev) + + @property + def _pow_2_divisor(self) -> int: + divs = set(wf._pow_2_divisor for wf in self.flatten()) + assert len(divs)==1 + return divs.pop() \ No newline at end of file diff --git a/qupulse/pulses/__init__.py b/qupulse/pulses/__init__.py index 4a8e1016f..d49833d7a 100644 --- a/qupulse/pulses/__init__.py +++ b/qupulse/pulses/__init__.py @@ -17,7 +17,8 @@ from qupulse.pulses.point_pulse_template import PointPulseTemplate as PointPT from qupulse.pulses.arithmetic_pulse_template import ArithmeticPulseTemplate as ArithmeticPT,\ ArithmeticAtomicPulseTemplate as ArithmeticAtomicPT -from qupulse.pulses.time_reversal_pulse_template import TimeReversalPulseTemplate as TimeReversalPT +from qupulse.pulses.time_reversal_pulse_template import TimeReversalPulseTemplate as TimeReversalPT,\ + AtomicTimeReversalPulseTemplate as AtomicTimeReversalPT import warnings with warnings.catch_warnings(): @@ -31,4 +32,4 @@ __all__ = ["FunctionPT", "ForLoopPT", "AtomicMultiChannelPT", "MappingPT", "RepetitionPT", "SequencePT", "TablePT", "PointPT", "ConstantPT", "AbstractPT", "ParallelConstantChannelPT", "ArithmeticPT", "ArithmeticAtomicPT", - "TimeReversalPT", "ParallelChannelPT"] + "TimeReversalPT", "ParallelChannelPT", "AtomicTimeReversalPT"] diff --git a/qupulse/pulses/arithmetic_pulse_template.py b/qupulse/pulses/arithmetic_pulse_template.py index eedfce20b..349a66667 100644 --- a/qupulse/pulses/arithmetic_pulse_template.py +++ b/qupulse/pulses/arithmetic_pulse_template.py @@ -6,7 +6,7 @@ import sympy -from qupulse.expressions import ExpressionScalar, ExpressionLike +from qupulse.expressions import ExpressionScalar, ExpressionLike, Expression from qupulse.serialization import Serializer, PulseRegistryType from qupulse.parameter_scope import Scope @@ -116,7 +116,7 @@ def _apply_operation(self, lhs: Mapping[str, Any], rhs: Mapping[str, Any]) -> Di operator_both=operator_both, rhs_only=rhs_only) - @property + @cached_property def integral(self) -> Dict[ChannelID, ExpressionScalar]: # this is a guard for possible future changes assert self._arithmetic_operator in ('+', '-'), \ @@ -126,11 +126,11 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: return self._apply_operation(self.lhs._as_expression(), self.rhs._as_expression()) - @property + @cached_property def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: return self._apply_operation(self.lhs.initial_values, self.rhs.initial_values) - @property + @cached_property def final_values(self) -> Dict[ChannelID, ExpressionScalar]: return self._apply_operation(self.lhs.final_values, self.rhs.final_values) @@ -448,7 +448,7 @@ def _scalar_as_dict(self) -> Dict[ChannelID, ExpressionScalar]: else: return dict(self._scalar) - @property + @cached_property def integral(self) -> Dict[ChannelID, ExpressionScalar]: if _is_time_dependent(self._scalar): # use the superclass implementation that relies on 
_as_expression @@ -493,14 +493,14 @@ def _apply_operation_to_channel_dict(self, return _apply_operation_to_channel_dict(lhs, rhs, operator_both=operator_both, rhs_only=rhs_only) - @property + @cached_property def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: return self._apply_operation_to_channel_dict( self._pulse_template.initial_values, self._scalar_as_dict() ) - @property + @cached_property def final_values(self) -> Dict[ChannelID, ExpressionScalar]: return self._apply_operation_to_channel_dict( self._pulse_template.final_values, @@ -538,7 +538,12 @@ def get_measurement_windows(self, def _is_atomic(self): return self._pulse_template._is_atomic() - + + def pad_all_atomic_subtemplates_to(self, + to_new_duration: Callable[[Expression], ExpressionLike]) -> 'PulseTemplate': + + self._pulse_template = self._pulse_template.pad_all_atomic_subtemplates_to(to_new_duration) + def try_operation(lhs: Union[PulseTemplate, ExpressionLike, Mapping[ChannelID, ExpressionLike]], op: str, diff --git a/qupulse/pulses/constant_pulse_template.py b/qupulse/pulses/constant_pulse_template.py index 192ee82b1..40eb2efb9 100644 --- a/qupulse/pulses/constant_pulse_template.py +++ b/qupulse/pulses/constant_pulse_template.py @@ -134,6 +134,6 @@ def build_waveform(self, def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: return {ch: ExpressionScalar(val) for ch, val in self._amplitude_dict.items()} - @property + @cached_property def final_values(self) -> Dict[ChannelID, ExpressionScalar]: return {ch: ExpressionScalar(val) for ch, val in self._amplitude_dict.items()} diff --git a/qupulse/pulses/function_pulse_template.py b/qupulse/pulses/function_pulse_template.py index 24d98fbe2..66e516c94 100644 --- a/qupulse/pulses/function_pulse_template.py +++ b/qupulse/pulses/function_pulse_template.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Set, Optional, Union import numbers +from functools import cached_property import sympy @@ -141,7 +142,7 @@ def deserialize(cls, del kwargs['measurement_declarations'] return super().deserialize(None, **kwargs) - @property + @cached_property def integral(self) -> Dict[ChannelID, ExpressionScalar]: return {self.__channel: ExpressionScalar( sympy.integrate(self.__expression.sympified_expression, ('t', 0, self.duration.sympified_expression)) @@ -151,12 +152,12 @@ def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: expr = ExpressionScalar.make(self.__expression.underlying_expression.subs({'t': self._AS_EXPRESSION_TIME})) return {self.__channel: expr} - @property + @cached_property def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: expr = ExpressionScalar.make(self.__expression.underlying_expression.subs('t', 0)) return {self.__channel: expr} - @property + @cached_property def final_values(self) -> Dict[ChannelID, ExpressionScalar]: expr = ExpressionScalar.make(self.__expression.underlying_expression.subs('t', self.__duration_expression.underlying_expression)) return {self.__channel: expr} diff --git a/qupulse/pulses/loop_pulse_template.py b/qupulse/pulses/loop_pulse_template.py index 0f458c687..053a1f0e4 100644 --- a/qupulse/pulses/loop_pulse_template.py +++ b/qupulse/pulses/loop_pulse_template.py @@ -4,9 +4,10 @@ import functools import itertools from abc import ABC -from typing import Dict, Set, Optional, Any, Union, Tuple, Iterator, Sequence, cast, Mapping +from typing import Dict, Set, Optional, Any, Union, Tuple, Iterator, Sequence, cast, Mapping, Callable import warnings from numbers import Number +from functools import 
cached_property import sympy @@ -16,7 +17,7 @@ from qupulse.program import ProgramBuilder -from qupulse.expressions import ExpressionScalar, ExpressionVariableMissingException, Expression +from qupulse.expressions import ExpressionScalar, ExpressionVariableMissingException, Expression, ExpressionLike from qupulse.utils import checked_int_cast, cached_property from qupulse.pulses.parameters import InvalidParameterNameException, ParameterConstrainer, ParameterNotProvidedException from qupulse.pulses.pulse_template import PulseTemplate, ChannelID, AtomicPulseTemplate @@ -159,16 +160,18 @@ def _internal_create_program(self, *, measurements = self.get_measurement_windows(scope, measurement_mapping) for iteration_program_builder in program_builder.with_iteration(loop_index_name, loop_range, - measurements=measurements): - self.body._create_program(scope=iteration_program_builder.inner_scope(scope), + measurements=measurements, + pt_obj=self): + self.body._create_program(scope=iteration_program_builder.inner_scope(scope,pt_obj=self), measurement_mapping=measurement_mapping, channel_mapping=channel_mapping, global_transformation=global_transformation, to_single_waveform=to_single_waveform, program_builder=iteration_program_builder) - def build_waveform(self, parameter_scope: Scope) -> ForLoopWaveform: - return ForLoopWaveform([self.body.build_waveform(local_scope) + def build_waveform(self, parameter_scope: Scope, + channel_mapping: Dict[ChannelID, Optional[ChannelID]]) -> ForLoopWaveform: + return ForLoopWaveform([self.body.build_waveform(local_scope,channel_mapping) for local_scope in self._body_scope_generator(parameter_scope, forward=True)]) def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[str, Any]: @@ -196,7 +199,7 @@ def deserialize(cls, serializer: Optional[Serializer]=None, **kwargs) -> 'ForLoo kwargs['body'] = cast(PulseTemplate, serializer.deserialize(kwargs['body'])) return super().deserialize(None, **kwargs) - @property + @cached_property def integral(self) -> Dict[ChannelID, ExpressionScalar]: step_size = self._loop_range.step.sympified_expression @@ -222,7 +225,7 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: return body_integrals - @property + @cached_property def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: values = self.body.initial_values initial_idx = self._loop_range.start @@ -230,7 +233,7 @@ def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: values[ch] = ExpressionScalar(value.underlying_expression.subs(self._loop_index, initial_idx)) return values - @property + @cached_property def final_values(self) -> Dict[ChannelID, ExpressionScalar]: values = self.body.final_values start, step, stop = self._loop_range.start.sympified_expression, self._loop_range.step.sympified_expression, self._loop_range.stop.sympified_expression @@ -239,7 +242,13 @@ def final_values(self) -> Dict[ChannelID, ExpressionScalar]: for ch, value in values.items(): values[ch] = ExpressionScalar(value.underlying_expression.subs(self._loop_index, final_idx)) return values - + + def pad_all_atomic_subtemplates_to(self, + to_new_duration: Callable[[Expression], ExpressionLike]) -> 'ForLoopPulseTemplate': + self.__body = self.body.pad_all_atomic_subtemplates_to(to_new_duration) + self.__dict__.pop('duration', None) + return self + class LoopIndexNotUsedException(Exception): def __init__(self, loop_index: str, body_parameter_names: Set[str]): diff --git a/qupulse/pulses/mapping_pulse_template.py b/qupulse/pulses/mapping_pulse_template.py index 
07e7d1024..004e43076 100644 --- a/qupulse/pulses/mapping_pulse_template.py +++ b/qupulse/pulses/mapping_pulse_template.py @@ -1,10 +1,11 @@ -from typing import Optional, Set, Dict, Union, List, Any, Tuple, Mapping +from typing import Optional, Set, Dict, Union, List, Any, Tuple, Mapping, Callable import itertools import numbers import collections +from functools import cached_property from qupulse.utils.types import ChannelID, FrozenDict, FrozenMapping -from qupulse.expressions import Expression, ExpressionScalar +from qupulse.expressions import Expression, ExpressionScalar, ExpressionLike from qupulse.parameter_scope import Scope, MappedScope from qupulse.pulses.pulse_template import PulseTemplate, MappingTuple from qupulse.pulses.parameters import ParameterNotProvidedException, ParameterConstrainer @@ -338,21 +339,26 @@ def _apply_mapping_to_inner_channel_dict(self, to_map: Dict[ChannelID, Expressio if self.__channel_mapping.get(ch, ch) is not None } - @property + @cached_property def integral(self) -> Dict[ChannelID, ExpressionScalar]: return self._apply_mapping_to_inner_channel_dict(self.__template.integral) def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: return self._apply_mapping_to_inner_channel_dict(self.__template._as_expression()) - @property + @cached_property def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: return self._apply_mapping_to_inner_channel_dict(self.__template.initial_values) - @property + @cached_property def final_values(self) -> Dict[ChannelID, ExpressionScalar]: return self._apply_mapping_to_inner_channel_dict(self.__template.final_values) - + + def pad_all_atomic_subtemplates_to(self, + to_new_duration: Callable[[Expression], ExpressionLike]) -> 'PulseTemplate': + + self.__template = self.template.pad_all_atomic_subtemplates_to(to_new_duration) + return self class MissingMappingException(Exception): """Indicates that no mapping was specified for some parameter declaration of a diff --git a/qupulse/pulses/multi_channel_pulse_template.py b/qupulse/pulses/multi_channel_pulse_template.py index 6b76bb49f..b78f49c07 100644 --- a/qupulse/pulses/multi_channel_pulse_template.py +++ b/qupulse/pulses/multi_channel_pulse_template.py @@ -7,9 +7,10 @@ - ParallelChannelPulseTemplate: A pulse template to add channels to an existing pulse template. 
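
Several of these properties are switched to functools.cached_property; a minimal standalone sketch of what that buys and what it requires of callers (computed once per instance, so the returned dict should be copied before mutation because every access sees the same object):

from functools import cached_property

class Template:
    calls = 0

    @cached_property
    def integral(self) -> dict:
        type(self).calls += 1
        return {'a': 1.0}

t = Template()
assert t.integral is t.integral      # same cached dict returned on every access
assert Template.calls == 1           # the body ran only once
safe_copy = dict(t.integral)         # copy before updating to keep the cache pristine
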
""" -from typing import Dict, List, Optional, Any, AbstractSet, Union, Set, Sequence, Mapping +from typing import Dict, List, Optional, Any, AbstractSet, Union, Set, Sequence, Mapping, Callable import numbers import warnings +from functools import cached_property from qupulse.serialization import Serializer, PulseRegistryType from qupulse.parameter_scope import Scope @@ -215,7 +216,14 @@ def final_values(self) -> Dict[ChannelID, ExpressionScalar]: for subtemplate in self._subtemplates: values.update(subtemplate.final_values) return values - + + def pad_all_atomic_subtemplates_to(self, + to_new_duration: Callable[[Expression], ExpressionLike]) -> 'PulseTemplate': + + for i,subtemplate in enumerate(self._subtemplates): + self._subtemplates[i] = subtemplate.pad_all_atomic_subtemplates_to(to_new_duration) + return self + class ParallelChannelPulseTemplate(PulseTemplate): def __init__(self, @@ -311,7 +319,7 @@ def parameter_names(self): def duration(self) -> ExpressionScalar: return self.template.duration - @property + @cached_property def integral(self) -> Dict[ChannelID, ExpressionScalar]: integral = self._template.integral @@ -320,13 +328,13 @@ def integral(self) -> Dict[ChannelID, ExpressionScalar]: integral[channel] = value * duration return integral - @property + @cached_property def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: values = self._template.initial_values values.update(self._overwritten_channels) return values - @property + @cached_property def final_values(self) -> Dict[ChannelID, ExpressionScalar]: values = self._template.final_values values.update(self._overwritten_channels) @@ -352,7 +360,13 @@ def with_parallel_channels(self, values: Mapping[ChannelID, ExpressionLike]) -> def _is_atomic(self) -> bool: return self._template._is_atomic() - + + def pad_all_atomic_subtemplates_to(self, + to_new_duration: Callable[[Expression], ExpressionLike]) -> 'PulseTemplate': + + self._template = self.template.pad_all_atomic_subtemplates_to(to_new_duration) + return self + ParallelConstantChannelPulseTemplate = ParallelChannelPulseTemplate diff --git a/qupulse/pulses/point_pulse_template.py b/qupulse/pulses/point_pulse_template.py index 0075a98dc..bc9c3f063 100644 --- a/qupulse/pulses/point_pulse_template.py +++ b/qupulse/pulses/point_pulse_template.py @@ -2,6 +2,7 @@ from numbers import Real import itertools import numbers +from functools import cached_property import sympy import numpy as np @@ -131,7 +132,7 @@ def point_parameters(self) -> Set[str]: def parameter_names(self) -> Set[str]: return self.point_parameters | self.measurement_parameters | self.constrained_parameters - @property + @cached_property def integral(self) -> Dict[ChannelID, ExpressionScalar]: expressions = {} shape = (len(self.defined_channels),) @@ -168,7 +169,7 @@ def value_trafo(v): expressions[channel] = pw return expressions - @property + @cached_property def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: shape = (len(self._channels),) return { @@ -176,7 +177,7 @@ def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: for ch_idx, ch in enumerate(self._channels) } - @property + @cached_property def final_values(self) -> Dict[ChannelID, ExpressionScalar]: shape = (len(self._channels),) return { diff --git a/qupulse/pulses/pulse_template.py b/qupulse/pulses/pulse_template.py index 97dd5cda4..3b1573683 100644 --- a/qupulse/pulses/pulse_template.py +++ b/qupulse/pulses/pulse_template.py @@ -12,6 +12,7 @@ import itertools import collections from numbers import Real, Number +from 
functools import cached_property import numpy import sympy @@ -28,6 +29,9 @@ from qupulse.parameter_scope import Scope, DictScope from qupulse.program import ProgramBuilder, default_program_builder, Program +from qupulse.program.linspace import LinSpaceBuilder + +from qupulse.expressions.simple import SimpleExpressionStepped __all__ = ["PulseTemplate", "AtomicPulseTemplate", "DoubleParameterNameException", "MappingTuple", "UnknownVolatileParameter"] @@ -59,7 +63,9 @@ def __init__(self, *, identifier: Optional[str]) -> None: super().__init__(identifier=identifier) self.__cached_hash_value = None - + self._pow_2_divisor: int = 0 + # self.__cached_serialization_data = None + @property @abstractmethod def parameter_names(self) -> Set[str]: @@ -207,7 +213,7 @@ def create_program(self, *, to_single_waveform=to_single_waveform, program_builder=program_builder) - return program_builder.to_program() + return program_builder.to_program(set(complete_channel_mapping.values())) @abstractmethod def _internal_create_program(self, *, @@ -373,6 +379,7 @@ def with_appended(self, *appended: 'PulseTemplate'): return self def pad_to(self, to_new_duration: Union[ExpressionLike, Callable[[Expression], ExpressionLike]], + as_single_wf: bool = False, pt_kwargs: Mapping[str, Any] = None) -> 'PulseTemplate': """Pad this pulse template to the given duration. The target duration can be numeric, symbolic or a callable that returns a new duration from the current @@ -392,27 +399,59 @@ def pad_to(self, to_new_duration: Union[ExpressionLike, Callable[[Expression], E >>> padded_4 = my_pt.pad_to(to_next_multiple(1, 16)) Args: to_new_duration: Duration or callable that maps the current duration to the new duration + as_single_wf: + - if PT is supposed ot be single element in memory, e.g. + to conform to waveform-granularities, select True. + (for nested PTs, pad_to on sub-PTs or pad_all_atomic recommened) + - if used just as a mean to elongate last value, select False. pt_kwargs: Keyword arguments for the newly created sequence pulse template. Returns: - A pulse template that has the duration given by ``to_new_duration``. It can be ``self`` if the duration is - already as required. It is never ``self`` if ``pt_kwargs`` is non-empty. + A pulse template that has the duration given by ``to_new_duration``. 
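
A usage sketch of the padding helpers, assuming this branch's new pad_all_atomic_subtemplates_to; to_next_multiple and pad_to are the pre-existing entry points:

from qupulse.pulses import FunctionPT, ConstantPT
from qupulse.utils import to_next_multiple

step = ConstantPT(100, {'a': 0.1})
ramp = FunctionPT('sin(t/100)', 192, channel='a')
sequence = step @ ramp

# pre-existing: pad the whole sequence to the next multiple of 16 samples at 1 GS/s
padded = sequence.pad_to(to_next_multiple(sample_rate=1, quantum=16))

# assumed addition of this branch: pad every atomic leaf individually so that each
# resulting waveform fulfils the granularity on its own
padded_leaves = sequence.pad_all_atomic_subtemplates_to(to_next_multiple(sample_rate=1, quantum=16))
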
+ XXX# self if ConstantPT, + else SingleWFTimeExtensionPulseTemplate if as_single_wf, + else SequencePT """ from qupulse.pulses import ConstantPT, SequencePT + from qupulse.pulses.time_extension_pulse_template import SingleWFTimeExtensionPulseTemplate current_duration = self.duration if callable(to_new_duration): new_duration = to_new_duration(current_duration) else: new_duration = ExpressionScalar(to_new_duration) pad_duration = new_duration - current_duration - if not pt_kwargs and pad_duration == 0: - return self - pad_pt = ConstantPT(pad_duration, self.final_values) - if pt_kwargs: - return SequencePT(self, pad_pt, **pt_kwargs) + + #maybe leads to inconsistencies if self may be returned + # #shortcut + # if isinstance(self,ConstantPT): + # if pt_kwargs: + # raise NotImplementedError() + # self._duration = new_duration + # return self + + pt_kwargs = pt_kwargs or {} + + if as_single_wf: + return SingleWFTimeExtensionPulseTemplate(self, new_duration, **pt_kwargs) + else: - return self @ pad_pt - + if not pt_kwargs and pad_duration == 0: + return self + pad_pt = ConstantPT(pad_duration, self.final_values) + if pt_kwargs: + return SequencePT(self, pad_pt, **pt_kwargs) + else: + return self @ pad_pt + + # @abstractmethod + def pad_all_atomic_subtemplates_to(self, + to_new_duration: Callable[[Expression], ExpressionLike]) -> 'PulseTemplate': + """pad ll atomic subtemplates to a new duration determiend from callable + to_new_duration, e.g. from qupulse.utils.to_next_multiple for waveform + granularity. + """ + raise NotImplementedError() + def __format__(self, format_spec: str): if format_spec == '': format_spec = self._DEFAULT_FORMAT_SPEC @@ -428,14 +467,29 @@ def __format__(self, format_spec: str): def __str__(self): return format(self) - - def __repr__(self): + + # @cached_property + # def _cached_repr(self) -> str: + # type_name = type(self).__name__ + # kwargs = ','.join('%s=%r' % (key, value) + # for key, value in self.__cached_serialization_data.items() + # if key.isidentifier() and value is not None) + # return '{type_name}({kwargs})'.format(type_name=type_name, kwargs=kwargs) + + # def __repr__(self): + # ser = self.get_serialization_data() + # if self.__cached_serialization_data != ser: + # self.__cached_serialization_data = ser + # self.__dict__.pop("_cached_repr", None) + # return self._cached_repr + + def __repr__(self) -> str: type_name = type(self).__name__ kwargs = ','.join('%s=%r' % (key, value) for key, value in self.get_serialization_data().items() if key.isidentifier() and value is not None) return '{type_name}({kwargs})'.format(type_name=type_name, kwargs=kwargs) - + def __add__(self, other: ExpressionLike): from qupulse.pulses.arithmetic_pulse_template import try_operation return try_operation(self, '+', other) @@ -513,7 +567,24 @@ def _internal_create_program(self, *, ### current behavior (same as previously): only adds EXEC Loop and measurements if a waveform exists. 
### measurements are directly added to parent_loop (to reflect behavior of Sequencer + MultiChannelProgram) assert not scope.get_volatile_parameters().keys() & self.parameter_names, "AtomicPT cannot be volatile" - + + + # "hackedy": + if program_builder.evaluate_nested_stepping(scope,self.parameter_names): + measurements = self.get_measurement_windows(parameters=scope, + measurement_mapping=measurement_mapping) + program_builder.measure(measurements) + + program_builder.dispatch_to_stepped_wf_or_hold(build_func=self.build_waveform, + build_parameters=scope, + parameter_names=self.parameter_names, + channel_mapping=channel_mapping, + #measurements + global_transformation=global_transformation, + _pow_2_divisor=self._pow_2_divisor + ) + return + waveform = self.build_waveform(parameters=scope, channel_mapping=channel_mapping) if waveform: @@ -523,7 +594,9 @@ def _internal_create_program(self, *, if global_transformation: waveform = TransformingWaveform.from_transformation(waveform, global_transformation) - + + waveform._pow_2_divisor = self._pow_2_divisor + constant_values = waveform.constant_value_dict() if constant_values is None: program_builder.play_arbitrary_waveform(waveform) @@ -555,27 +628,35 @@ def _as_expression(self) -> Dict[ChannelID, ExpressionScalar]: raise NotImplementedError(f"_as_expression is not implemented for {type(self)} " f"which means it cannot be truncated and integrated over.") - @property + @cached_property def integral(self) -> Dict[ChannelID, ExpressionScalar]: # this default implementation uses _as_expression return {ch: ExpressionScalar(sympy.integrate(expr.sympified_expression, (self._AS_EXPRESSION_TIME, 0, self.duration.sympified_expression))) for ch, expr in self._as_expression().items()} - @property + @cached_property def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: values = self._as_expression() for ch, value in values.items(): values[ch] = value.evaluate_symbolic({self._AS_EXPRESSION_TIME: 0}) return values - @property + @cached_property def final_values(self) -> Dict[ChannelID, ExpressionScalar]: values = self._as_expression() for ch, value in values.items(): values[ch] = value.evaluate_symbolic({self._AS_EXPRESSION_TIME: self.duration}) return values - + + def pad_to(self, to_new_duration: Union[ExpressionLike, Callable[[Expression], ExpressionLike]], + pt_kwargs: Mapping[str, Any] = {}) -> 'PulseTemplate': + return super().pad_to(to_new_duration,as_single_wf=True,pt_kwargs=pt_kwargs) + + def pad_all_atomic_subtemplates_to(self, + to_new_duration: Callable[[Expression], ExpressionLike]) -> 'PulseTemplate': + return self.pad_to(to_new_duration) + class DoubleParameterNameException(Exception): diff --git a/qupulse/pulses/repetition_pulse_template.py b/qupulse/pulses/repetition_pulse_template.py index ead19c6d9..bae7998fd 100644 --- a/qupulse/pulses/repetition_pulse_template.py +++ b/qupulse/pulses/repetition_pulse_template.py @@ -1,9 +1,10 @@ """This module defines RepetitionPulseTemplate, a higher-order hierarchical pulse template that represents the n-times repetition of another PulseTemplate.""" -from typing import Dict, List, AbstractSet, Optional, Union, Any, Mapping, cast +from typing import Dict, List, AbstractSet, Optional, Union, Any, Mapping, cast, Callable from numbers import Real from warnings import warn +from functools import cached_property import numpy as np @@ -13,7 +14,7 @@ from qupulse.parameter_scope import Scope from qupulse.utils.types import ChannelID -from qupulse.expressions import ExpressionScalar +from 
qupulse.expressions import ExpressionScalar, Expression, ExpressionLike from qupulse.utils import checked_int_cast from qupulse.pulses.pulse_template import PulseTemplate from qupulse.pulses.loop_pulse_template import LoopPulseTemplate @@ -135,7 +136,9 @@ def _internal_create_program(self, *, for repetition_program_builder in program_builder.with_repetition(repetition_definition, measurements=measurements): - self.body._create_program(scope=repetition_program_builder.inner_scope(scope), + self.body._create_program( + #scope=repetition_program_builder.inner_scope(scope, pt_obj=self), + scope=scope, #there should not be any replacements with SimpleExpression here, so this is unnecessary measurement_mapping=measurement_mapping, channel_mapping=channel_mapping, global_transformation=global_transformation, @@ -166,7 +169,7 @@ def deserialize(cls, serializer: Optional[Serializer]=None, **kwargs) -> 'Repeti return super().deserialize(**kwargs) - @property + @cached_property def integral(self) -> Dict[ChannelID, ExpressionScalar]: body_integral = self.body.integral return {channel: self.repetition_count * value for channel, value in body_integral.items()} @@ -178,7 +181,13 @@ def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: @property def final_values(self) -> Dict[ChannelID, ExpressionScalar]: return self.body.final_values - + + def pad_all_atomic_subtemplates_to(self, + to_new_duration: Callable[[Expression], ExpressionLike]) -> 'PulseTemplate': + + self.__body = self.body.pad_all_atomic_subtemplates_to(to_new_duration) + return self + class ParameterNotIntegerException(Exception): """Indicates that the value of the parameter given as repetition count was not an integer.""" diff --git a/qupulse/pulses/sequence_pulse_template.py b/qupulse/pulses/sequence_pulse_template.py index 5107bb104..6aac50f35 100644 --- a/qupulse/pulses/sequence_pulse_template.py +++ b/qupulse/pulses/sequence_pulse_template.py @@ -17,7 +17,7 @@ from qupulse.pulses.mapping_pulse_template import MappingPulseTemplate, MappingTuple from qupulse.program.waveforms import SequenceWaveform from qupulse.pulses.measurement import MeasurementDeclaration, MeasurementDefiner -from qupulse.expressions import Expression, ExpressionScalar +from qupulse.expressions import Expression, ExpressionScalar, ExpressionLike __all__ = ["SequencePulseTemplate"] @@ -177,7 +177,7 @@ def deserialize(cls, def defined_channels(self) -> Set[ChannelID]: return self.__subtemplates[0].defined_channels - @property + @cached_property def integral(self) -> Dict[ChannelID, ExpressionScalar]: expressions = {channel: 0 for channel in self.defined_channels} @@ -194,3 +194,9 @@ def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: def final_values(self) -> Dict[ChannelID, ExpressionScalar]: return self.__subtemplates[-1].final_values + def pad_all_atomic_subtemplates_to(self, + to_new_duration: Callable[[Expression], ExpressionLike]) -> 'PulseTemplate': + + for i,sub in enumerate(self.__subtemplates): + self.__subtemplates[i] = sub.pad_all_atomic_subtemplates_to(to_new_duration) + return self \ No newline at end of file diff --git a/qupulse/pulses/table_pulse_template.py b/qupulse/pulses/table_pulse_template.py index f8b631add..053e12afc 100644 --- a/qupulse/pulses/table_pulse_template.py +++ b/qupulse/pulses/table_pulse_template.py @@ -11,6 +11,7 @@ import numbers import itertools import warnings +from functools import cached_property import numpy as np import sympy @@ -412,7 +413,7 @@ def is_valid_interpolation_strategy(inter): return 
TablePulseTemplate(parsed, **kwargs) - @property + @cached_property def integral(self) -> Dict[ChannelID, ExpressionScalar]: expressions = dict() for channel, channel_entries in self._entries.items(): diff --git a/qupulse/pulses/time_extension_pulse_template.py b/qupulse/pulses/time_extension_pulse_template.py new file mode 100644 index 000000000..672fcc322 --- /dev/null +++ b/qupulse/pulses/time_extension_pulse_template.py @@ -0,0 +1,103 @@ +from numbers import Real +from typing import Dict, Optional, Set, Union, List, Iterable, Any +from functools import cached_property + +from qupulse import ChannelID +from qupulse.parameter_scope import Scope +from qupulse.pulses.pulse_template import PulseTemplate, AtomicPulseTemplate +from qupulse.pulses.constant_pulse_template import ConstantPulseTemplate as ConstantPT +from qupulse.expressions import ExpressionLike, ExpressionScalar +from qupulse._program.waveforms import ConstantWaveform +from qupulse.program import ProgramBuilder +from qupulse.pulses.parameters import ConstraintLike +from qupulse.pulses.measurement import MeasurementDeclaration +from qupulse.serialization import Serializer, PulseRegistryType +from qupulse.program.waveforms import SequenceWaveform + + +def _evaluate_expression_dict(expression_dict: Dict[str, ExpressionScalar], scope: Scope) -> Dict[str, float]: + return {ch: value.evaluate_in_scope(scope) + for ch, value in expression_dict.items()} + + +class SingleWFTimeExtensionPulseTemplate(AtomicPulseTemplate): + """Extend the given pulse template with a constant suffix. + """ + + def __init__(self, + main_pt: PulseTemplate, + new_duration: Union[str, ExpressionScalar], + identifier: Optional[str] = None, + *, + measurements: Optional[List[MeasurementDeclaration]]=None, + registry: PulseRegistryType=None) -> None: + + AtomicPulseTemplate.__init__(self, identifier=identifier, measurements=measurements) + + self.__main_pt = main_pt + self.__pad_pt = ConstantPT(new_duration-main_pt.duration, self.final_values) + self._duration = ExpressionScalar.make(new_duration) + + self._register(registry=registry) + + @property + def parameter_names(self) -> Set[str]: + return self.__main_pt.parameter_names + + @property + def duration(self) -> ExpressionScalar: + """An expression for the duration of this PulseTemplate.""" + return self._duration + + @property + def defined_channels(self) -> Set[ChannelID]: + return self.__main_pt.defined_channels + + @cached_property + def integral(self) -> Dict[ChannelID, ExpressionScalar]: + + unextended = self.__main_pt.integral + + return {ch: unextended_ch + (self.duration-self.__main_pt.duration)*self.__main_pt.final_values[ch] \ + for ch,unextended_ch in unextended.items()} + + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self.__main_pt.initial_values + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self.__main_pt.final_values + + def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[str, Any]: + if serializer is not None: + raise NotImplementedError("SingleWFTimeExtensionPulseTemplate does not implement legacy serialization.") + data = super().get_serialization_data(serializer) + data['main_pt'] = self.__main_pt + data['new_duration'] = self.duration + data['measurements']: self.measurement_declarations + + return data + + @classmethod + def deserialize(cls, + serializer: Optional[Serializer]=None, # compatibility to old serialization routines, deprecated + **kwargs) -> 
'SingleWFTimeExtensionPulseTemplate': + main_pt = kwargs['main_pt'] + new_duration = kwargs['new_duration'] + del kwargs['main_pt'] + del kwargs['new_duration'] + + if serializer: # compatibility to old serialization routines, deprecated + raise NotImplementedError() + + return cls(main_pt,new_duration,**kwargs) + + def build_waveform(self, + parameters: Dict[str, Real], + channel_mapping: Dict[ChannelID, ChannelID]) -> SequenceWaveform: + return SequenceWaveform.from_sequence( + [wf for sub_template in [self.__main_pt,self.__pad_pt] + if (wf:=sub_template.build_waveform(parameters, channel_mapping=channel_mapping)) is not None]) + + \ No newline at end of file diff --git a/qupulse/pulses/time_reversal_pulse_template.py b/qupulse/pulses/time_reversal_pulse_template.py index 35a1884be..359d2d866 100644 --- a/qupulse/pulses/time_reversal_pulse_template.py +++ b/qupulse/pulses/time_reversal_pulse_template.py @@ -1,14 +1,86 @@ -from typing import Optional, Set, Dict, Union +from typing import Optional, Set, Dict, Union, Callable, Any from qupulse import ChannelID from qupulse.program.loop import Loop from qupulse.program.waveforms import Waveform from qupulse.serialization import PulseRegistryType -from qupulse.expressions import ExpressionScalar +from qupulse.expressions import ExpressionScalar, Expression, ExpressionLike +from qupulse.parameter_scope import Scope +from qupulse.program import ProgramBuilder +from qupulse.pulses.pulse_template import PulseTemplate, AtomicPulseTemplate +from qupulse.serialization import Serializer, PulseRegistryType -from qupulse.pulses.pulse_template import PulseTemplate +class AtomicTimeReversalPulseTemplate(AtomicPulseTemplate): + """Extend the given pulse template with a constant suffix. + """ + + def __init__(self, inner: PulseTemplate, + identifier: Optional[str] = None, + registry: PulseRegistryType = None): + + assert isinstance(inner, AtomicPulseTemplate) + AtomicPulseTemplate.__init__(self, identifier=identifier,measurements=None) + + self._inner = inner + self._register(registry=registry) + + @property + def parameter_names(self) -> Set[str]: + return self._inner.parameter_names + + @property + def duration(self) -> ExpressionScalar: + """An expression for the duration of this PulseTemplate.""" + return self._inner.duration + + @property + def defined_channels(self) -> Set[ChannelID]: + return self._inner.defined_channels + + @property + def integral(self) -> Dict[ChannelID, ExpressionScalar]: + return self._inner.integral + + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self._inner.final_values + + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self._inner.initial_values + + def get_serialization_data(self, serializer: Optional[Serializer]=None) -> Dict[str, Any]: + if serializer is not None: + raise NotImplementedError("AtomicTimeReversalPulseTemplate does not implement legacy serialization.") + data = super().get_serialization_data(serializer) + data['inner'] = self._inner + + return data + + @classmethod + def deserialize(cls, + serializer: Optional[Serializer]=None, # compatibility to old serialization routines, deprecated + **kwargs) -> 'AtomicTimeReversalPulseTemplate': + main_pt = kwargs['main_pt'] + new_duration = kwargs['new_duration'] + del kwargs['main_pt'] + del kwargs['new_duration'] + + if serializer: # compatibility to old serialization routines, deprecated + raise NotImplementedError() + + return cls(main_pt,new_duration,**kwargs) + + def build_waveform(self, + 
*args, **kwargs) -> Optional[Waveform]: + wf = self._inner.build_waveform(*args, **kwargs) + if wf is not None: + return wf.reversed() + + + class TimeReversalPulseTemplate(PulseTemplate): """This pulse template reverses the inner pulse template in time.""" @@ -45,14 +117,19 @@ def defined_channels(self) -> Set['ChannelID']: @property def integral(self) -> Dict[ChannelID, ExpressionScalar]: return self._inner.integral + + @property + def initial_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self._inner.final_values - def _internal_create_program(self, *, parent_loop: Loop, **kwargs) -> None: - inner_loop = Loop() - self._inner._internal_create_program(parent_loop=inner_loop, **kwargs) - inner_loop.reverse_inplace() - - parent_loop.append_child(inner_loop) - + @property + def final_values(self) -> Dict[ChannelID, ExpressionScalar]: + return self._inner.initial_values + + def _internal_create_program(self, *, program_builder: ProgramBuilder, **kwargs) -> None: + with program_builder.time_reversed() as reversed_builder: + self._inner._internal_create_program(program_builder=reversed_builder, **kwargs) + def build_waveform(self, *args, **kwargs) -> Optional[Waveform]: wf = self._inner.build_waveform(*args, **kwargs) @@ -68,3 +145,9 @@ def get_serialization_data(self, serializer=None): def _is_atomic(self) -> bool: return self._inner._is_atomic() + + def pad_all_atomic_subtemplates_to(self, + to_new_duration: Callable[[Expression], ExpressionLike]) -> 'PulseTemplate': + self._inner = self._inner.pad_all_atomic_subtemplates_to(to_new_duration) + return self + \ No newline at end of file diff --git a/qupulse/utils/__init__.py b/qupulse/utils/__init__.py index 326072f4b..e33b1ac4d 100644 --- a/qupulse/utils/__init__.py +++ b/qupulse/utils/__init__.py @@ -7,14 +7,16 @@ from collections import OrderedDict from frozendict import frozendict from qupulse.expressions import ExpressionScalar, ExpressionLike +from qupulse.expressions.simple import SimpleExpression import numpy +import sympy as sp try: - from math import isclose + from math import isclose as math_isclose except ImportError: # py version < 3.5 - isclose = None + math_isclose = None try: from functools import cached_property @@ -51,8 +53,17 @@ def _fallback_is_close(a, b, *, rel_tol=1e-09, abs_tol=0.0): return abs(a-b) <= max(rel_tol * max(abs(a), abs(b)), abs_tol) # pragma: no cover -if not isclose: - isclose = _fallback_is_close +if not math_isclose: + math_isclose = _fallback_is_close + + +def checked_is_close(a, b, *, rel_tol=1e-09, abs_tol=0.0): + if isinstance(a,SimpleExpression) or isinstance(b,SimpleExpression): + return _fallback_is_close(a, b, rel_tol=rel_tol, abs_tol=abs_tol) + return math_isclose(a,b,rel_tol=rel_tol,abs_tol=abs_tol) + + +isclose = checked_is_close def _fallback_pairwise(iterable: Iterable[_T]) -> Iterator[Tuple[_T, _T]]: @@ -126,7 +137,7 @@ def forced_hash(obj) -> int: def to_next_multiple(sample_rate: ExpressionLike, quantum: int, - min_quanta: Optional[int] = None) -> Callable[[ExpressionLike],ExpressionScalar]: + min_quanta: Optional[int] = None) -> Callable[[ExpressionLike],ExpressionScalar]: """Construct a helper function to expand a duration to one corresponding to valid sample multiples according to the arguments given. Useful e.g. for PulseTemplate.pad_to's 'to_new_duration'-argument. @@ -147,6 +158,22 @@ def to_next_multiple(sample_rate: ExpressionLike, quantum: int, #double negative for ceil division. 
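
Both roundings used here can be sanity-checked in isolation: the double-negation ceil division on the next line, and the sympy Piecewise that replaces the old string expression when min_quanta is given. A minimal standalone sketch (the numeric values are only for illustration):

import sympy as sp

def ceil_div(a, b):
    # ceil division via double negation, as in the lambda below
    return -(-a // b)

assert ceil_div(100, 16) == 7 and ceil_div(96, 16) == 6

d = sp.Symbol('d', positive=True)
rate, quantum, min_quanta = 1, sp.Integer(16), sp.Integer(3)
u = d * rate / quantum                      # duration expressed in quanta
padded = sp.Piecewise((0, sp.Le(d, 0)),
                      (quantum / rate * min_quanta, sp.Le(u, min_quanta)),
                      (quantum / rate * sp.ceiling(u), True))
assert padded.subs(d, 100) == 112           # rounded up to 7 quanta of 16 samples
assert padded.subs(d, 20) == 48             # clamped to the minimum of 3 quanta
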
return lambda duration: -(-(duration*sample_rate)//quantum) * (quantum/sample_rate) else: - #still return 0 if duration==0 - return lambda duration: ExpressionScalar(f'{quantum}/({sample_rate})*Max({min_quanta},-(-{duration}*{sample_rate}//{quantum}))*Max(0, sign({duration}))') - \ No newline at end of file + qI = sp.Integer(quantum) + k = qI / sample_rate # factor to go from #quanta -> duration + mqI = sp.Integer(min_quanta) + + def _build_sym(d): + u = d*sample_rate/qI # "duration in quanta" (real) + ce = sp.ceiling(u) # number of quanta after rounding up + + # Enforce: 0 if d <= 0; else at least mqI quanta. + # max(mqI, ceil(u)) <=> mqI if u <= mqI, else ceil(u) + # do not evaluate right now because parameters could still be variable, + # then it's just overhead. + return sp.Piecewise( + (0, sp.Le(d, 0)), + (k*mqI, sp.Le(u, mqI)), + (k*ce, True) + , evaluate=False) + + return lambda duration: ExpressionScalar(_build_sym(duration)) \ No newline at end of file diff --git a/qupulse/utils/sympy.py b/qupulse/utils/sympy.py index 7a04d53f6..f5076eeda 100644 --- a/qupulse/utils/sympy.py +++ b/qupulse/utils/sympy.py @@ -11,21 +11,32 @@ import numpy try: - from sympy.printing.numpy import NumPyPrinter + import scipy +except ImportError: + scipy = None + +try: + from sympy.printing.numpy import NumPyPrinter, SciPyPrinter except ImportError: # sympy moved NumPyPrinter in release 1.8 from sympy.printing.pycode import NumPyPrinter + SciPyPrinter = None warnings.warn("Please update sympy.", DeprecationWarning) -try: +if scipy: import scipy.special as _special_functions -except ImportError: +else: _special_functions = {fname: numpy.vectorize(fobject) for fname, fobject in math.__dict__.items() if callable(fobject) and not fname.startswith('_') and fname not in numpy.__dict__} warnings.warn('scipy is not installed. 
This reduces the set of available functions to those present in numpy + ' 'manually vectorized functions in math.') +if scipy and SciPyPrinter: + PrinterBase = SciPyPrinter +else: + PrinterBase = NumPyPrinter + __all__ = ["sympify", "substitute_with_eval", "to_numpy", "get_variables", "get_free_symbols", "recursive_substitution", "evaluate_lambdified", "get_most_simple_representation"] @@ -368,8 +379,13 @@ def recursive_substitution(expression: sympy.Expr, _numpy_environment = {**_base_environment, **numpy.__dict__} _sympy_environment = {**_base_environment, **sympy.__dict__} + _lambdify_modules = [{'ceiling': numpy_compatible_ceiling, 'floor': _floor_to_int, - 'Broadcast': numpy.broadcast_to}, 'numpy', _special_functions] + 'Broadcast': numpy.broadcast_to}, 'numpy', _special_functions,] + +if scipy: + # this is required for Integral lambdification + _lambdify_modules.append("scipy") def evaluate_compiled(expression: sympy.Expr, @@ -397,7 +413,7 @@ def evaluate_lambdified(expression: Union[sympy.Expr, numpy.ndarray], return lambdified(**parameters), lambdified -class HighPrecPrinter(NumPyPrinter): +class HighPrecPrinter(PrinterBase): """Custom printer that translates sympy.Rational into TimeType""" def _print_Rational(self, expr): return f'TimeType.from_fraction({expr.p}, {expr.q})' diff --git a/setup.cfg b/setup.cfg index a77b1e909..76b607da4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -71,7 +71,7 @@ qctoolkit = *.pyi [tool:pytest] -testpaths = tests tests/pulses tests/hardware tests/backward_compatibility +testpaths = tests tests/pulses tests/hardware tests/backward_compatibility tests/_program tests/expressions tests/program tests/utils python_files=*_tests.py *_bug.py filterwarnings = # syntax is action:message_regex:category:module_regex:lineno diff --git a/tests/_program/waveforms_tests.py b/tests/_program/waveforms_tests.py index c62fceb3a..f38631e87 100644 --- a/tests/_program/waveforms_tests.py +++ b/tests/_program/waveforms_tests.py @@ -50,7 +50,11 @@ def unsafe_sample(self, def compare_key(self): raise NotImplementedError() - + @property + def _compare_subset_key(self, channel_subset): + raise NotImplementedError() + + class WaveformTest(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -292,8 +296,25 @@ def test_constant_default_impl(self): self.assertIsNone(wf_non_const.constant_value_dict()) self.assertIsNone(wf_mixed.constant_value_dict()) self.assertEqual(wf_mixed.constant_value('C'), 2.2) - - + + def test_hash_subset(self): + dwf_a = DummyWaveform(duration=246.2, defined_channels={'A'}, sample_output={'A': 1*np.ones(3)}) + dwf_b = DummyWaveform(duration=246.2, defined_channels={'B'}, sample_output={'B': 2*np.ones(3)}) + dwf_c = DummyWaveform(duration=246.2, defined_channels={'C'}, sample_output={'C': 3*np.ones(3)}) + waveform_a1 = MultiChannelWaveform([dwf_a, dwf_b, dwf_c]) + waveform_a2 = MultiChannelWaveform([dwf_a, dwf_b]) + waveform_a3 = MultiChannelWaveform([dwf_a, dwf_c]) + + self.assertEqual(waveform_a1._hash_only_subset({'A','B'}), + waveform_a2._hash_only_subset({'A','B'})) + self.assertEqual(waveform_a1._hash_only_subset({'A','C'}), + waveform_a3._hash_only_subset({'A','C'})) + self.assertNotEqual(waveform_a1._hash_only_subset({'A','B'}), + waveform_a3._hash_only_subset({'A','C'})) + + self.assertRaises(KeyError, lambda: waveform_a1._hash_only_subset({'A','B','C','D'})) + + class RepetitionWaveformTest(unittest.TestCase): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -335,6 +356,11 @@ def 
test_compare_key(self): wf = RepetitionWaveform(body_wf, 2) self.assertEqual(wf.compare_key, (body_wf.compare_key, 2)) + def test_compare_subset(self): + body_wf = DummyWaveform(defined_channels={'a'}) + wf = RepetitionWaveform(body_wf, 2) + self.assertEqual(wf._compare_subset_key({'a',}), (body_wf._compare_subset_key({'a',}), 2)) + def test_unsafe_get_subset_for_channels(self): body_wf = DummyWaveform(defined_channels={'a', 'b'}) @@ -492,6 +518,12 @@ def test_repr(self): r = repr(swf) self.assertEqual(swf, eval(r)) + def test_compare_subset(self): + body_wf = DummyWaveform(defined_channels={'a'}) + wf = SequenceWaveform([body_wf, body_wf]) + self.assertEqual(wf.get_subset_for_channels({'a'})._compare_subset_key({'a',}), + tuple(2*[body_wf.get_subset_for_channels({'a'})._compare_subset_key({'a',}),])) + class ConstantWaveformTests(unittest.TestCase): def test_waveform_duration(self): @@ -521,6 +553,14 @@ def test_constness(self): self.assertTrue(waveform.is_constant()) assert_constant_consistent(self, waveform) + def test_hash_subset(self): + wf_1 = ConstantWaveform(10, 1., 'A') + wf_2 = ConstantWaveform(10, 1., 'B') + wf_3 = ConstantWaveform(10, 2., 'A') + + self.assertEqual(wf_1._hash_only_subset({'A',}), wf_2._hash_only_subset({'B',})) + self.assertNotEqual(wf_1._hash_only_subset({'A',}), wf_3._hash_only_subset({'A',})) + class TableWaveformTests(unittest.TestCase): @@ -654,6 +694,25 @@ def test_simple_properties(self): evaled = eval(repr(waveform)) self.assertEqual(evaled, waveform) + def test_hash_subset(self): + + interp = HoldInterpolationStrategy() + entries = (TableWaveformEntry(0, 0, interp), + TableWaveformEntry(2.1, -33.2, interp), + TableWaveformEntry(5.7, 123.4, interp)) + wf_1 = TableWaveform('A', entries) + entries = (TableWaveformEntry(0, 0, interp), + TableWaveformEntry(2.1, -33.2, interp), + TableWaveformEntry(5.7, 123.4, interp)) + wf_2 = TableWaveform('B', entries) + entries = (TableWaveformEntry(0.5, 0, interp), + TableWaveformEntry(2.1, -33.2, interp), + TableWaveformEntry(5.7, 123.4, interp)) + wf_3 = TableWaveform('A', entries) + + self.assertEqual(wf_1._hash_only_subset({'A',}), wf_2._hash_only_subset({'B',})) + self.assertNotEqual(wf_1._hash_only_subset({'A',}), wf_3._hash_only_subset({'A',})) + class WaveformEntryTest(unittest.TestCase): def test_interpolation_exception(self): @@ -796,6 +855,16 @@ def test_const_value(self): with mock.patch.object(inner_wf, 'constant_value', side_effect=inner_const_values.values()) as constant_value: self.assertIsNone(trafo_wf.constant_value('C')) + def test_compare_subset(self): + output_channels = {'c', 'd', 'e'} + input_channels = {'a', 'b'} + trafo = TransformationDummy(output_channels=output_channels,input_channels=input_channels) + inner_wf = DummyWaveform(duration=1.5, defined_channels=input_channels) + trafo_wf = TransformingWaveform(inner_waveform=inner_wf, transformation=trafo) + + self.assertEqual(trafo_wf._compare_subset_key(output_channels), + (inner_wf._compare_subset_key(input_channels), trafo)) + class SubsetWaveformTest(unittest.TestCase): def test_simple_properties(self): @@ -836,6 +905,13 @@ def test_unsafe_sample(self): self.assertIs(expected_data, actual_data) unsafe_sample.assert_called_once_with('g', time, output) + def test_compare_subset(self): + inner_wf = DummyWaveform(duration=1.5, defined_channels={'a', 'b', 'c'}) + subset_wf = SubsetWaveform(inner_wf, {'a', 'c'}) + + self.assertEqual(subset_wf._compare_subset_key({'a', 'b', 'c'}), + inner_wf._compare_subset_key({'a', 'b', 'c'}),) + class 
ArithmeticWaveformTest(unittest.TestCase): def test_from_operator(self): @@ -891,8 +967,22 @@ def test_simple_properties(self): self.assertIs(rhs, arith.rhs) self.assertEqual('-', arith.arithmetic_operator) self.assertEqual(lhs.duration, arith.duration) - + + def test_compare_subset(self): + lhs_1 = DummyWaveform(duration=1.5, defined_channels={'a',}) + lhs_2 = DummyWaveform(duration=1.5, defined_channels={'b',}) + lhs_3 = DummyWaveform(duration=1.5, defined_channels={'c'}) + rhs_1 = DummyWaveform(duration=1.5, defined_channels={'a',}) + rhs_2 = DummyWaveform(duration=1.5, defined_channels={'b',}) + rhs_3 = DummyWaveform(duration=1.5, defined_channels={'d'}) + + lhs = MultiChannelWaveform([lhs_1,lhs_2,lhs_3]) + rhs = MultiChannelWaveform([rhs_1,rhs_2,rhs_3]) + arith = ArithmeticWaveform(lhs, '-', rhs) + self.assertEqual(('-', lhs, rhs), arith.compare_key) + self.assertEqual(('-', lhs._compare_subset_key({'a','b',}),rhs._compare_subset_key({'a','b',})), + arith._compare_subset_key({'a','b',})) def test_unsafe_get_subset_for_channels(self): lhs = DummyWaveform(duration=1.5, defined_channels={'a', 'b', 'c'}) @@ -1007,6 +1097,14 @@ def test_repr(self): r = repr(wf) self.assertEqual(wf, eval(r)) + def test_compare_subset(self): + wf1a = FunctionWaveform(ExpressionScalar('1+2*t'), 3, channel='A') + wf1b = FunctionWaveform(ExpressionScalar('t*2+1'), 3, channel='A') + + self.assertEqual(wf1a._compare_subset_key({'A',}), (wf1a._expression,wf1a._duration)) + self.assertEqual(wf1a._compare_subset_key({'A',}), + wf1b._compare_subset_key({'A',}),) + class FunctorWaveformTests(unittest.TestCase): def test_duration(self): @@ -1075,6 +1173,16 @@ def test_compare_key(self): self.assertNotEqual(wf11, wf21) self.assertNotEqual(wf11, wf22) + def test_compare_subset(self): + inner_wf_1 = DummyWaveform(defined_channels={'A', 'B'}) + functors_1 = dict(A=np.positive, B=np.negative) + + wf11 = FunctorWaveform(inner_wf_1, functors_1) + + self.assertEqual((inner_wf_1._compare_subset_key({'A', 'B'}), + frozenset(functors_1.items())), + wf11._compare_subset_key({'A', 'B'})) + class ReversedWaveformTest(unittest.TestCase): def test_simple_properties(self): @@ -1084,6 +1192,8 @@ def test_simple_properties(self): self.assertEqual(dummy_wf.duration, reversed_wf.duration) self.assertEqual(dummy_wf.defined_channels, reversed_wf.defined_channels) self.assertEqual(dummy_wf.compare_key, reversed_wf.compare_key) + self.assertEqual(reversed_wf._compare_subset_key({'A','B'}), + (dummy_wf._compare_subset_key({'A','B'}),'-')) self.assertNotEqual(reversed_wf, dummy_wf) def test_reversed_sample(self): @@ -1103,4 +1213,4 @@ def test_reversed_sample(self): np.testing.assert_equal(output, sample_output[::-1]) np.testing.assert_equal(dummy_wf.sample_calls, [ ('A', list(1.5 - time_array[::-1]), None), - ('A', list(1.5 - time_array[::-1]), mem[::-1])]) + ('A', list(1.5 - time_array[::-1]), mem[::-1])]) \ No newline at end of file diff --git a/tests/program/linspace_tests.py b/tests/program/linspace_tests.py index 03a5b2971..541afa915 100644 --- a/tests/program/linspace_tests.py +++ b/tests/program/linspace_tests.py @@ -287,3 +287,23 @@ def test_local_trafo_program(self): global_transformation=self.transformation, to_single_waveform={self.pulse_template}) self.assertEqual([self.program], program) + + +class TimeSweepTest(TestCase): + def setUp(self,base_time=1e2,rep_factor=2): + wait = ConstantPT(f'64*{base_time}*1e1*(1+idx_t)', {'a': '-1. 
+ idx_a * 0.01', 'b': '-.5 + idx_b * 0.02'}) + + random_constant = ConstantPT(10 ** 5, {'a': -.4, 'b': -.3}) + meas = ConstantPT(64*base_time, {'a': 0.05, 'b': 0.06}) + + singlet_scan = (random_constant @ wait @ meas).with_iteration('idx_a', rep_factor*10*2)\ + .with_iteration('idx_b', rep_factor*10)\ + .with_iteration('idx_t', 10) + self.pulse_template = singlet_scan + + + def test_singlet_scan_program(self): + program_builder = LinSpaceBuilder(('a', 'b')) + program = self.pulse_template.create_program(program_builder=program_builder) + # so far just a test to see if the program creation works at all. + # self.assertEqual([self.program], program) \ No newline at end of file diff --git a/tests/program/linspace_tests2.py b/tests/program/linspace_tests2.py new file mode 100644 index 000000000..275ed086f --- /dev/null +++ b/tests/program/linspace_tests2.py @@ -0,0 +1,220 @@ +import copy +import unittest +from unittest import TestCase + +from qupulse.pulses import * +from qupulse.program.linspace import * +from qupulse.program.transformation import * + +class SingleRampTest(TestCase): + def setUp(self): + hold = ConstantPT(10 ** 6, {'a': '-1. + idx * 0.01'}) + self.pulse_template = hold.with_iteration('idx', 200) + + self.program = LinSpaceTopLevel(body=(LinSpaceIter( + to_be_stepped=False, + length=200, + body=(LinSpaceHold( + channels=('a',), + bases={'a': -1.0}, + factors={'a': (0.01,)}, + duration_base=TimeType(10**6), + duration_factors=None + ),) + ),)) + + key = DepKey.from_voltages((0.01,), DEFAULT_INCREMENT_RESOLUTION) + + self.commands = [ + Set('a',ResolutionDependentValue((),(),-1.0),key=key), + Wait(TimeType(10 ** 6)), + LoopLabel(0, 199), + Increment('a',ResolutionDependentValue((0.01,),(1,),0.0),key=key), + Wait(TimeType(10 ** 6)), + LoopJmp(0) + ] + + def test_program(self): + program_builder = LinSpaceBuilder() + program = self.pulse_template.create_program(program_builder=program_builder) + self.assertEqual([self.program], program) + + def test_commands(self): + program_builder = LinSpaceBuilder() + program = self.pulse_template.create_program(program_builder=program_builder) + self.commands_to_test = to_increment_commands(program) + self.assertEqual(self.commands, self.commands_to_test) + + +class TimeSweepTest(TestCase): + def setUp(self,base_time=1e2,rep_factor=2): + wait = ConstantPT(f'64*{base_time}*1e1*(1+idx_t)', + {'a': '-1. + idx_a * 0.01', 'b': '-.5 + idx_b * 0.02'}) + + random_constant = ConstantPT(10 ** 5, {'a': -.4, 'b': -.3}) + meas = ConstantPT(64*base_time, {'a': 0.05, 'b': 0.06}) + + singlet_scan = (random_constant @ wait @ meas).with_iteration('idx_a', rep_factor*10*2)\ + .with_iteration('idx_b', rep_factor*10)\ + .with_iteration('idx_t', 10) + self.pulse_template = singlet_scan + + + def test_program(self): + program_builder = LinSpaceBuilder() + self._program_to_test = self.pulse_template.create_program(program_builder=program_builder) + + def test_commands(self): + program_builder = LinSpaceBuilder() + self._program_to_test = self.pulse_template.create_program(program_builder=program_builder) + commands = to_increment_commands(self._program_to_test) + # so far just a test to see if the program creation works at all. + # self.assertEqual([self.program], program) + + +class SequencedIterationTest(TestCase): + def setUp(self,base_time=1e2,rep_factor=2): + wait = AtomicMultiChannelPT( + ConstantPT(f'64*{base_time}*1e1', {'a': '-1. 
+ idx_a * 0.01 + y_gain', }),
+            ConstantPT(f'64*{base_time}*1e1', {'b': '-.5 + idx_b * 0.02'})
+        )
+
+        dependent_constant = AtomicMultiChannelPT(
+            ConstantPT(10 ** 5, {'a': -.3}),
+            ConstantPT(10 ** 5, {'b': 'idx_b*0.02',}),
+        )
+
+        dependent_constant2 = AtomicMultiChannelPT(
+            ConstantPT(2*10 ** 5, {'a': '-.3+idx_b*0.01'}),
+            ConstantPT(2*10 ** 5, {'b': 'idx_b*0.01',}),
+        )
+
+        pt = (dependent_constant @ dependent_constant2 @ (wait.with_iteration('idx_a', rep_factor*10*2)) \
+             @ dependent_constant2).with_iteration('idx_b', rep_factor*10)
+
+        self.pulse_template = MappingPT(pt,parameter_mapping=dict(y_gain=0.3,))
+
+    def test_program(self):
+        program_builder = LinSpaceBuilder()
+        self._program_to_test = self.pulse_template.create_program(program_builder=program_builder)
+
+    def test_commands(self):
+        program_builder = LinSpaceBuilder()
+        self._program_to_test = self.pulse_template.create_program(program_builder=program_builder)
+        commands = to_increment_commands(self._program_to_test)
+        # so far just a test to see if the program creation works at all.
+        # self.assertEqual([self.program], program)
+
+
+class SequencedIterationMeasurementTest(TestCase):
+    def setUp(self,base_time=1e2,rep_factor=2):
+        wait = AtomicMultiChannelPT(
+            ConstantPT(f'64*{base_time}*1e1', {'a': '-1. + idx_a * 0.01 + y_gain', }),
+            ConstantPT(f'64*{base_time}*1e1', {'b': '-.5 + idx_b * 0.02'})
+        )
+
+        dependent_constant = AtomicMultiChannelPT(
+            ConstantPT(10 ** 5, {'a': -.3}),
+            ConstantPT(10 ** 5, {'b': 'idx_b*0.02',}),
+        )
+
+        dependent_constant2 = AtomicMultiChannelPT(
+            ConstantPT(2*10 ** 5, {'a': '-.3+idx_b*0.01'}),
+            ConstantPT(2*10 ** 5, {'b': 'idx_b*0.01',}),
+            measurements=[('c',0,10**5)]
+        )
+
+        pt = (dependent_constant @ dependent_constant2 @ (wait.with_iteration('idx_a', rep_factor*10*2)) \
+             @ dependent_constant2).with_iteration('idx_b', rep_factor*10)
+
+        step_len = 10**5*5 + rep_factor*10*2*64*base_time*1e1
+
+        self.measurements = {
+            'c': (
+                np.repeat(np.arange(0.,(rep_factor*10)*step_len,step_len),2)+np.array([10**5,step_len-2*10**5]*rep_factor*10),
+                np.ones(rep_factor*10*2)*10**5
+            )
+        }
+
+        self.pulse_template = MappingPT(pt,parameter_mapping=dict(y_gain=0.3,))
+
+    def test_measurements(self):
+        program_builder = LinSpaceBuilder()
+        self._program_to_test = self.pulse_template.create_program(program_builder=program_builder)
+        self.assertIsNone(np.testing.assert_array_equal(self._program_to_test[0].get_measurement_windows()['c'], self.measurements['c']))
+
+
+class AmplitudeSweepTest(TestCase):
+    def setUp(self,rep_factor=2):
+
+        normal_pt = FunctionPT("sin(t/100)","t_sin",channel='a')
+        amp_pt = "amp*1/8"*FunctionPT("sin(t/1000)","t_sin",channel='a')
+
+        pt = (normal_pt@amp_pt@normal_pt@amp_pt@amp_pt@normal_pt).with_iteration('amp', rep_factor)
+        self.pulse_template = MappingPT(pt,parameter_mapping=dict(t_sin=64e2,))
+
+    def test_program(self):
+        program_builder = LinSpaceBuilder()
+        self._program_to_test = self.pulse_template.create_program(program_builder=program_builder)
+
+    def test_commands(self):
+        program_builder = LinSpaceBuilder()
+        self._program_to_test = self.pulse_template.create_program(program_builder=program_builder)
+        commands = to_increment_commands(self._program_to_test)
+
+        # so far just a test to see if the program creation works at all.
+ # self.assertEqual([self.program], program) + + +class SteppedRepetitionTest(TestCase): + def setUp(self,base_time=1e2,rep_factor=2): + + wait = ConstantPT(f'64*{base_time}*1e1*(1+idx_t)', {'a': '-0.5 + idx_a * 0.15', 'b': '-.5 + idx_a * 0.3'}) + normal_pt = ParallelConstantChannelPT(FunctionPT("sin(t/1000)","t_sin",channel='a'),{'b':-0.2}) + amp_pt = ParallelConstantChannelPT("amp*1/8"*FunctionPT("sin(t/1000)","t_sin",channel='a'),{'b':-0.5}) + # amp_pt2 = ParallelConstantChannelPT("amp2*1/8"*FunctionPT("sin(t/1000)","t_sin",channel='a'),{'b':-0.5}) + amp_inner = ParallelConstantChannelPT(FunctionPT(f"(1+amp)*1/(2*{rep_factor})*sin(4*pi*t/t_sin)","t_sin",channel='a'),{'b':-0.5}) + amp_inner2 = ParallelConstantChannelPT(FunctionPT(f"(1+amp2)*1/(2*{rep_factor})*sin((1*freq)*4*pi*t/t_sin)+off/(2*{rep_factor})","t_sin",channel='a'),{'b':-0.3}) + + pt = (((normal_pt@amp_inner2).with_iteration('off', rep_factor)@normal_pt@wait)\ + .with_repetition(rep_factor)@amp_inner.with_iteration('amp', rep_factor))\ + .with_iteration('amp2', rep_factor).with_iteration('freq', rep_factor).with_iteration('idx_a',rep_factor) + + self.pulse_template = MappingPT(pt,parameter_mapping=dict(t_sin=64e2,idx_t=1,)) + + def test_program(self): + program_builder = LinSpaceBuilder(to_stepping_repeat={'amp','amp2','off','freq'},) + self._program_to_test = self.pulse_template.create_program(program_builder=program_builder) + + def test_commands(self): + program_builder = LinSpaceBuilder(to_stepping_repeat={'amp','amp2','off','freq'},) + self._program_to_test = self.pulse_template.create_program(program_builder=program_builder) + commands = to_increment_commands(self._program_to_test) + # so far just a test to see if the program creation works at all. + # self.assertEqual([self.program], program) + + +class CombinedSweepTest(TestCase): + def setUp(self,base_time=1e2,rep_factor=2): + + wait = ConstantPT(f'64*{base_time}*1e1*(1+idx_t)', {'a': f'-1. + idx_a * 0.5/{rep_factor}', 'b': f'-.5 + idx_a * 0.8/{rep_factor}'}) + normal_pt = ParallelConstantChannelPT(FunctionPT("sin(t/2000)","t_sin",channel='a'),{'b':-0.2}) + amp_pt = ParallelConstantChannelPT(f"amp*1/1.5 * 1/{rep_factor}"*FunctionPT("sin(t/2000)","t_sin",channel='a'),{'b':-0.5}) + + pt = (normal_pt@amp_pt@normal_pt@wait@amp_pt@amp_pt@normal_pt)\ + .with_iteration('amp', rep_factor)\ + .with_iteration('idx_a', rep_factor)\ + .with_iteration('idx_t', rep_factor) + + self.pulse_template = MappingPT(pt,parameter_mapping=dict(t_sin=64e2,)) + + def test_program(self): + program_builder = LinSpaceBuilder() + self._program_to_test = self.pulse_template.create_program(program_builder=program_builder) + + def test_commands(self): + program_builder = LinSpaceBuilder() + self._program_to_test = self.pulse_template.create_program(program_builder=program_builder) + commands = to_increment_commands(self._program_to_test) + # so far just a test to see if the program creation works at all. + # self.assertEqual([self.program], program) \ No newline at end of file diff --git a/tests/program/pulse_metadata_test.py b/tests/program/pulse_metadata_test.py new file mode 100644 index 000000000..b344b28bc --- /dev/null +++ b/tests/program/pulse_metadata_test.py @@ -0,0 +1,31 @@ +import copy +import unittest +from unittest import TestCase + +from qupulse.pulses import * +from qupulse.program.linspace import * +from qupulse.program.transformation import * + + +class PulseMetaDataTest(TestCase): + def setUp(self): + hold_a = ConstantPT(10 ** 6, {'a': '-1. 
+ idx * 0.01'})
+        hold_b = ConstantPT(10 ** 6, {'a': '-0.2 + idx * 0.005'})
+        hold_combined = SequencePT(hold_a,hold_b,identifier='hold_pt')
+        hold_2 = ConstantPT(10 ** 6, {'a': '-0.5'})
+        play_arbitrary = FunctionPT("tanh(a*t**2 + b*t + c) * sin(b*t + c) + cos(a*t)/2",192*1e5,channel="a")
+
+        self.pulse_template = (hold_combined @ play_arbitrary @ hold_2).with_iteration('idx', 200)
+
+        self.pulse_metadata = {
+            'hold_pt': PTMetaData(True,2.0),
+            hold_2: PTMetaData(to_single_waveform=False,minimal_sample_rate=1e-10,),
+            play_arbitrary: PTMetaData(False,1e-3)
+        }
+
+    def test_program(self):
+        program_builder = LinSpaceBuilder(('a',))
+        program = self.pulse_template.create_program(
+            program_builder=program_builder,
+            metadata=self.pulse_metadata
+        )
diff --git a/tests/pulses/sequencing_dummies.py b/tests/pulses/sequencing_dummies.py
index 21c3c7e62..3d2ccaba4 100644
--- a/tests/pulses/sequencing_dummies.py
+++ b/tests/pulses/sequencing_dummies.py
@@ -43,7 +43,7 @@ def __init__(self, duration: Union[float, TimeType]=0, sample_output: Union[nump
             defined_channels = set(sample_output.keys())
         else:
             defined_channels = {'A'}
-        self.defined_channels_ = defined_channels
+        self.defined_channels_ = self.channels = defined_channels
         self.sample_calls = []
 
     @property
@@ -59,7 +59,23 @@ def compare_key(self) -> Any:
             )
         else:
             return id(self)
-
+
+    def _compare_subset_key(self, channel_subset) -> Any:
+        assert self.channels==channel_subset
+        if self.sample_output is not None:
+            try:
+                if isinstance(self.sample_output,dict):
+                    return hash(self.sample_output.values().tobytes())
+                return hash(self.sample_output.tobytes())
+            except AttributeError:
+                pass
+            return hash(
+                tuple(sorted((getattr(output, 'tobytes', lambda: output)(),)
+                             for output in self.sample_output.values()))
+            )
+        else:
+            return id(self)
+
     @property
     def measurement_windows(self):
         return []
diff --git a/tests/utils/utils_tests.py b/tests/utils/utils_tests.py
index 83e1a26aa..8c7d5c046 100644
--- a/tests/utils/utils_tests.py
+++ b/tests/utils/utils_tests.py
@@ -108,6 +108,7 @@ class ToNextMultipleTests(unittest.TestCase):
     def test_to_next_multiple(self):
         from qupulse.utils.types import TimeType
         from qupulse.expressions import ExpressionScalar
+        precision_digits = 12
 
         duration = TimeType.from_float(47.1415926535)
         evaluated = to_next_multiple(sample_rate=TimeType.from_float(2.4),quantum=16)(duration)
@@ -120,24 +121,32 @@ def test_to_next_multiple(self):
         self.assertEqual(evaluated, expected)
 
         duration = 6185240.0000001
-        evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=13)(duration)
+        evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=13)(duration).evaluate_numeric()
         expected = 6185248
-        self.assertEqual(evaluated, expected)
+        self.assertAlmostEqual(evaluated, expected, precision_digits)
+
+        duration = 63.99
+        evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=4)(duration).evaluate_numeric()
+        expected = 64
+        self.assertAlmostEqual(evaluated, expected, precision_digits)
+
+        duration = 64.01
+        evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=4)(duration).evaluate_numeric()
+        expected = 80
+        self.assertAlmostEqual(evaluated, expected, precision_digits)
 
         duration = 0.
-        evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=13)(duration)
+        evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=13)(duration).evaluate_numeric()
         expected = 0.
-        self.assertEqual(evaluated, expected)
+        self.assertAlmostEqual(evaluated, expected, precision_digits)
 
         duration = ExpressionScalar('abc')
         evaluated = to_next_multiple(sample_rate=1.0,quantum=16,min_quanta=13)(duration).evaluate_in_scope(dict(abc=0.))
         expected = 0.
-        self.assertEqual(evaluated, expected)
+        self.assertAlmostEqual(evaluated, expected, precision_digits)
 
         duration = ExpressionScalar('q')
         evaluated = to_next_multiple(sample_rate=ExpressionScalar('w'),quantum=16,min_quanta=1)(duration).evaluate_in_scope(
             dict(q=3.14159,w=1.0))
         expected = 16.
-        self.assertEqual(evaluated, expected)
-
-
\ No newline at end of file
+        self.assertAlmostEqual(evaluated, expected, precision_digits)
\ No newline at end of file
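
Note on the to_next_multiple changes in tests/utils/utils_tests.py: the helper now yields an expression even for plain numeric durations, which is why the updated assertions evaluate the result (evaluate_numeric / evaluate_in_scope) and compare with assertAlmostEqual to precision_digits places. The snippet below is a minimal standalone sketch of the rounding rule those expected values encode; the function name and signature are invented for illustration, and the real qupulse helper builds a symbolic expression and also accepts TimeType and ExpressionScalar inputs.

import math

def to_next_multiple_numeric(duration, sample_rate, quantum, min_quanta=1):
    # Hypothetical numeric re-implementation for illustration only.
    # Rounds `duration` (with `sample_rate` samples per time unit) up to the
    # next multiple of `quantum` samples, but never below `min_quanta` quanta;
    # a zero duration stays zero, as in the test above.
    if duration == 0:
        return 0.0
    samples = duration * sample_rate
    quanta = max(math.ceil(samples / quantum), min_quanta)
    return quanta * quantum / sample_rate

# Reproduces the expected values asserted in the updated test:
assert to_next_multiple_numeric(63.99, 1.0, 16, min_quanta=4) == 64
assert to_next_multiple_numeric(64.01, 1.0, 16, min_quanta=4) == 80
assert to_next_multiple_numeric(6185240.0000001, 1.0, 16, min_quanta=13) == 6185248
assert to_next_multiple_numeric(3.14159, 1.0, 16, min_quanta=1) == 16
assert to_next_multiple_numeric(0., 1.0, 16, min_quanta=13) == 0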