|
25 | 25 |
|
26 | 26 | namespace cmdstan { |
27 | 27 |
|
28 | | - int generate_quantities(CLI::App& app, |
29 | | - SharedOptions& shared_options, |
30 | | - GenerateQuantitiesOptions& gq_options) { |
31 | | - static int hmc_fixed_cols = 7; // hmc sampler outputs columns __lp + 6 |
| 28 | +int generate_quantities(CLI::App& app, SharedOptions& shared_options, |
| 29 | + GenerateQuantitiesOptions& gq_options) { |
| 30 | + static int hmc_fixed_cols = 7; // hmc sampler outputs columns __lp + 6 |
32 | 31 |
|
33 | | - stan::callbacks::stream_writer info(std::cout); |
34 | | - stan::callbacks::stream_writer err(std::cout); |
35 | | - stan::callbacks::stream_logger logger(std::cout, std::cout, std::cout, |
36 | | - std::cerr, std::cerr); |
| 32 | + stan::callbacks::stream_writer info(std::cout); |
| 33 | + stan::callbacks::stream_writer err(std::cout); |
| 34 | + stan::callbacks::stream_logger logger(std::cout, std::cout, std::cout, |
| 35 | + std::cerr, std::cerr); |
37 | 36 |
|
38 | | - // Read arguments |
39 | | - write_parallel_info(info); |
40 | | - write_opencl_device(info); |
41 | | - info(); |
| 37 | + // Read arguments |
| 38 | + write_parallel_info(info); |
| 39 | + write_opencl_device(info); |
| 40 | + info(); |
42 | 41 |
|
43 | | - if (gq_options.fitted_params == shared_options.output_file) { |
44 | | - std::stringstream msg; |
45 | | - msg << "Filename conflict, fitted_params file " |
46 | | - << gq_options.fitted_params |
47 | | - << " and output file have same name, must be different." |
48 | | - << std::endl; |
49 | | - throw std::invalid_argument(msg.str()); |
50 | | - } |
| 42 | + if (gq_options.fitted_params == shared_options.output_file) { |
| 43 | + std::stringstream msg; |
| 44 | + msg << "Filename conflict, fitted_params file " << gq_options.fitted_params |
| 45 | + << " and output file have same name, must be different." << std::endl; |
| 46 | + throw std::invalid_argument(msg.str()); |
| 47 | + } |
51 | 48 |
|
52 | | - stan::callbacks::interrupt interrupt; |
| 49 | + stan::callbacks::interrupt interrupt; |
53 | 50 |
|
54 | | - std::fstream output_stream(shared_options.output_file.c_str(), |
55 | | - std::fstream::out); |
56 | | - stan::callbacks::stream_writer sample_writer(output_stream, "# "); |
| 51 | + std::fstream output_stream(shared_options.output_file.c_str(), |
| 52 | + std::fstream::out); |
| 53 | + stan::callbacks::stream_writer sample_writer(output_stream, "# "); |
57 | 54 |
|
58 | | - ////////////////////////////////////////////////// |
59 | | - // Initialize Model // |
60 | | - ////////////////////////////////////////////////// |
| 55 | + ////////////////////////////////////////////////// |
| 56 | + // Initialize Model // |
| 57 | + ////////////////////////////////////////////////// |
61 | 58 |
|
62 | | - std::shared_ptr<stan::io::var_context> var_context |
| 59 | + std::shared_ptr<stan::io::var_context> var_context |
63 | 60 | = get_var_context(shared_options.data_file); |
64 | 61 |
|
65 | | - stan::model::model_base &model |
| 62 | + stan::model::model_base& model |
66 | 63 | = new_model(*var_context, shared_options.seed, &std::cout); |
67 | 64 |
|
68 | | - write_stan(sample_writer); |
69 | | - write_model(sample_writer, model.model_name()); |
70 | | - print_old_command_header(app, shared_options, gq_options, sample_writer); |
71 | | - write_parallel_info(sample_writer); |
72 | | - write_opencl_device(sample_writer); |
73 | | - |
74 | | - std::ifstream stream(gq_options.fitted_params.c_str()); |
75 | | - if (stream.rdstate() & std::ifstream::failbit) { |
76 | | - std::stringstream msg; |
77 | | - msg << "Can't open specified file, \"" |
78 | | - << gq_options.fitted_params << "\"" |
79 | | - << std::endl; |
80 | | - throw std::invalid_argument(msg.str()); |
81 | | - } |
| 65 | + write_stan(sample_writer); |
| 66 | + write_model(sample_writer, model.model_name()); |
| 67 | + print_old_command_header(app, shared_options, gq_options, sample_writer); |
| 68 | + write_parallel_info(sample_writer); |
| 69 | + write_opencl_device(sample_writer); |
82 | 70 |
|
83 | | - stan::io::stan_csv fitted_params; |
| 71 | + std::ifstream stream(gq_options.fitted_params.c_str()); |
| 72 | + if (stream.rdstate() & std::ifstream::failbit) { |
84 | 73 | std::stringstream msg; |
85 | | - stan::io::stan_csv_reader |
86 | | - ::read_metadata(stream, fitted_params.metadata, &msg); |
87 | | - if (!stan::io::stan_csv_reader |
88 | | - ::read_header(stream, fitted_params.header, |
89 | | - &msg, false)) { |
90 | | - msg << "Error reading fitted param names from sample csv file \"" |
91 | | - << gq_options.fitted_params << "\"" << std::endl; |
92 | | - throw std::invalid_argument(msg.str()); |
93 | | - } |
94 | | - stan::io::stan_csv_reader |
95 | | - ::read_adaptation(stream, fitted_params.adaptation, &msg); |
96 | | - fitted_params.timing.warmup = 0; |
97 | | - fitted_params.timing.sampling = 0; |
98 | | - stan::io::stan_csv_reader::read_samples(stream, |
99 | | - fitted_params.samples, |
100 | | - fitted_params.timing, &msg); |
101 | | - stream.close(); |
| 74 | + msg << "Can't open specified file, \"" << gq_options.fitted_params << "\"" |
| 75 | + << std::endl; |
| 76 | + throw std::invalid_argument(msg.str()); |
| 77 | + } |
| 78 | + |
| 79 | + stan::io::stan_csv fitted_params; |
| 80 | + std::stringstream msg; |
| 81 | + stan::io::stan_csv_reader ::read_metadata(stream, fitted_params.metadata, |
| 82 | + &msg); |
| 83 | + if (!stan::io::stan_csv_reader ::read_header(stream, fitted_params.header, |
| 84 | + &msg, false)) { |
| 85 | + msg << "Error reading fitted param names from sample csv file \"" |
| 86 | + << gq_options.fitted_params << "\"" << std::endl; |
| 87 | + throw std::invalid_argument(msg.str()); |
| 88 | + } |
| 89 | + stan::io::stan_csv_reader ::read_adaptation(stream, fitted_params.adaptation, |
| 90 | + &msg); |
| 91 | + fitted_params.timing.warmup = 0; |
| 92 | + fitted_params.timing.sampling = 0; |
| 93 | + stan::io::stan_csv_reader::read_samples(stream, fitted_params.samples, |
| 94 | + fitted_params.timing, &msg); |
| 95 | + stream.close(); |
102 | 96 |
|
103 | | - std::vector<std::string> param_names; |
104 | | - model.constrained_param_names(param_names, false, false); |
105 | | - size_t num_cols = param_names.size(); |
106 | | - size_t num_rows = fitted_params.metadata.num_samples; |
107 | | - // check that all parameter names are in sample, in order |
108 | | - if (num_cols + hmc_fixed_cols > fitted_params.header.size()) { |
| 97 | + std::vector<std::string> param_names; |
| 98 | + model.constrained_param_names(param_names, false, false); |
| 99 | + size_t num_cols = param_names.size(); |
| 100 | + size_t num_rows = fitted_params.metadata.num_samples; |
| 101 | + // check that all parameter names are in sample, in order |
| 102 | + if (num_cols + hmc_fixed_cols > fitted_params.header.size()) { |
| 103 | + std::stringstream msg; |
| 104 | + msg << "Mismatch between model and fitted_parameters csv file \"" |
| 105 | + << gq_options.fitted_params << "\"" << std::endl; |
| 106 | + throw std::invalid_argument(msg.str()); |
| 107 | + } |
| 108 | + for (size_t i = 0; i < num_cols; ++i) { |
| 109 | + if (param_names[i].compare(fitted_params.header[i + hmc_fixed_cols]) != 0) { |
109 | 110 | std::stringstream msg; |
110 | 111 | msg << "Mismatch between model and fitted_parameters csv file \"" |
111 | | - << gq_options.fitted_params << "\"" << std::endl; |
| 112 | + << gq_options.fitted_params << "\"" << std::endl; |
112 | 113 | throw std::invalid_argument(msg.str()); |
113 | 114 | } |
114 | | - for (size_t i = 0; i < num_cols; ++i) { |
115 | | - if (param_names[i].compare(fitted_params.header[i + hmc_fixed_cols]) |
116 | | - != 0) { |
117 | | - std::stringstream msg; |
118 | | - msg << "Mismatch between model and fitted_parameters csv file \"" |
119 | | - << gq_options.fitted_params << "\"" << std::endl; |
120 | | - throw std::invalid_argument(msg.str()); |
121 | | - } |
122 | | - } |
123 | | - return stan::services |
124 | | - ::standalone_generate(model, |
125 | | - fitted_params.samples.block(0, hmc_fixed_cols, num_rows, num_cols), |
126 | | - shared_options.seed, interrupt, logger, sample_writer); |
127 | | - return stan::services::error_codes::CONFIG; |
128 | 115 | } |
| 116 | + return stan::services ::standalone_generate( |
| 117 | + model, fitted_params.samples.block(0, hmc_fixed_cols, num_rows, num_cols), |
| 118 | + shared_options.seed, interrupt, logger, sample_writer); |
| 119 | + return stan::services::error_codes::CONFIG; |
| 120 | +} |
129 | 121 | } // namespace cmdstan |
130 | 122 | #endif |
0 commit comments