opm-common
ml_model.hpp
/*
 Copyright (c) 2016 Robert W. Rose
 Copyright (c) 2018 Paul Maevskikh
 Copyright (c) 2024 NORCE

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
 in the Software without restriction, including without limitation the rights
 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 copies of the Software, and to permit persons to whom the Software is
 furnished to do so, subject to the following conditions:

 The above copyright notice and this permission notice shall be included in all
 copies or substantial portions of the Software.

 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.

 Note: This file is based on kerasify/keras_model.hh
*/

#ifndef ML_MODEL_H_
#define ML_MODEL_H_

#include <fmt/format.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <fstream>
#include <functional>
#include <memory>
#include <numeric>
#include <opm/common/ErrorMacros.hpp>
#include <string>
#include <vector>

namespace Opm
{

namespace ML
{

    // Mathematical tensor, max 4-D
    // ---------------------
    template <class T>
    class Tensor
    {
    public:
        Tensor()
        {
        }

        explicit Tensor(int i)
        {
            resizeI<std::vector<int>>({i});
        }

        Tensor(int i, int j)
        {
            resizeI<std::vector<int>>({i, j});
        }

        Tensor(int i, int j, int k)
        {
            resizeI<std::vector<int>>({i, j, k});
        }

        Tensor(int i, int j, int k, int l)
        {
            resizeI<std::vector<int>>({i, j, k, l});
        }

        template <typename Sizes>
        void resizeI(const Sizes& sizes)
        {
            if (sizes.size() == 1)
                dims_ = {static_cast<int>(sizes[0])};
            if (sizes.size() == 2)
                dims_ = {static_cast<int>(sizes[0]), static_cast<int>(sizes[1])};
            if (sizes.size() == 3)
                dims_ = {static_cast<int>(sizes[0]),
                         static_cast<int>(sizes[1]),
                         static_cast<int>(sizes[2])};
            if (sizes.size() == 4)
                dims_ = {static_cast<int>(sizes[0]),
                         static_cast<int>(sizes[1]),
                         static_cast<int>(sizes[2]),
                         static_cast<int>(sizes[3])};

            // Total element count; use an integer initial value so the
            // product is computed exactly rather than in floating point.
            data_.resize(std::accumulate(
                begin(dims_), end(dims_), std::size_t{1}, std::multiplies<>()));
        }

        void flatten()
        {
            OPM_ERROR_IF(dims_.empty(), "Invalid tensor");

            int elements = dims_[0];
            for (unsigned int i = 1; i < dims_.size(); i++) {
                elements *= dims_[i];
            }
            dims_ = {elements};
        }

        T& operator()(int i)
        {
            OPM_ERROR_IF(dims_.size() != 1, "Invalid indexing for tensor");
            OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
                         fmt::format("Invalid i: {} max: {}", i, dims_[0]));

            return data_[i];
        }

        T& operator()(int i, int j)
        {
            OPM_ERROR_IF(dims_.size() != 2, "Invalid indexing for tensor");
            OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
                         fmt::format("Invalid i: {} max: {}", i, dims_[0]));
            OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
                         fmt::format("Invalid j: {} max: {}", j, dims_[1]));

            return data_[dims_[1] * i + j];
        }

        const T& operator()(int i, int j) const
        {
            OPM_ERROR_IF(dims_.size() != 2, "Invalid indexing for tensor");
            OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
                         fmt::format("Invalid i: {} max: {}", i, dims_[0]));
            OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
                         fmt::format("Invalid j: {} max: {}", j, dims_[1]));

            return data_[dims_[1] * i + j];
        }

        T& operator()(int i, int j, int k)
        {
            OPM_ERROR_IF(dims_.size() != 3, "Invalid indexing for tensor");
            OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
                         fmt::format("Invalid i: {} max: {}", i, dims_[0]));
            OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
                         fmt::format("Invalid j: {} max: {}", j, dims_[1]));
            OPM_ERROR_IF(!(k < dims_[2] && k >= 0),
                         fmt::format("Invalid k: {} max: {}", k, dims_[2]));

            return data_[dims_[2] * (dims_[1] * i + j) + k];
        }

        const T& operator()(int i, int j, int k) const
        {
            OPM_ERROR_IF(dims_.size() != 3, "Invalid indexing for tensor");
            OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
                         fmt::format("Invalid i: {} max: {}", i, dims_[0]));
            OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
                         fmt::format("Invalid j: {} max: {}", j, dims_[1]));
            OPM_ERROR_IF(!(k < dims_[2] && k >= 0),
                         fmt::format("Invalid k: {} max: {}", k, dims_[2]));

            return data_[dims_[2] * (dims_[1] * i + j) + k];
        }

        T& operator()(int i, int j, int k, int l)
        {
            OPM_ERROR_IF(dims_.size() != 4, "Invalid indexing for tensor");
            OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
                         fmt::format("Invalid i: {} max: {}", i, dims_[0]));
            OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
                         fmt::format("Invalid j: {} max: {}", j, dims_[1]));
            OPM_ERROR_IF(!(k < dims_[2] && k >= 0),
                         fmt::format("Invalid k: {} max: {}", k, dims_[2]));
            OPM_ERROR_IF(!(l < dims_[3] && l >= 0),
                         fmt::format("Invalid l: {} max: {}", l, dims_[3]));

            return data_[dims_[3] * (dims_[2] * (dims_[1] * i + j) + k) + l];
        }

        const T& operator()(int i, int j, int k, int l) const
        {
            OPM_ERROR_IF(dims_.size() != 4, "Invalid indexing for tensor");
            OPM_ERROR_IF(!(i < dims_[0] && i >= 0),
                         fmt::format("Invalid i: {} max: {}", i, dims_[0]));
            OPM_ERROR_IF(!(j < dims_[1] && j >= 0),
                         fmt::format("Invalid j: {} max: {}", j, dims_[1]));
            OPM_ERROR_IF(!(k < dims_[2] && k >= 0),
                         fmt::format("Invalid k: {} max: {}", k, dims_[2]));
            OPM_ERROR_IF(!(l < dims_[3] && l >= 0),
                         fmt::format("Invalid l: {} max: {}", l, dims_[3]));

            return data_[dims_[3] * (dims_[2] * (dims_[1] * i + j) + k) + l];
        }

        void fill(const T& value)
        {
            std::fill(data_.begin(), data_.end(), value);
        }

        // Elementwise tensor addition
        Tensor operator+(const Tensor& other) const
        {
            OPM_ERROR_IF(dims_ != other.dims_,
                         "Cannot add tensors with different dimensions");

            Tensor result;
            result.dims_ = dims_;
            result.data_.resize(data_.size());

            std::transform(data_.begin(),
                           data_.end(),
                           other.data_.begin(),
                           result.data_.begin(),
                           [](const T& x, const T& y) { return x + y; });

            return result;
        }

        // Elementwise (Hadamard) multiplication
        Tensor multiply(const Tensor& other) const
        {
            OPM_ERROR_IF(dims_ != other.dims_,
                         "Cannot multiply elements with different dimensions");

            Tensor result;
            result.dims_ = dims_;
            result.data_.resize(data_.size());

            std::transform(data_.begin(),
                           data_.end(),
                           other.data_.begin(),
                           result.data_.begin(),
                           [](const T& x, const T& y) { return x * y; });

            return result;
        }

        // Matrix product of two 2-D tensors
        Tensor dot(const Tensor& other) const
        {
            OPM_ERROR_IF(dims_.size() != 2, "Invalid tensor dimensions");
            OPM_ERROR_IF(other.dims_.size() != 2, "Invalid tensor dimensions");

            OPM_ERROR_IF(dims_[1] != other.dims_[0],
                         "Cannot multiply with different inner dimensions");

            Tensor tmp(dims_[0], other.dims_[1]);

            for (int i = 0; i < dims_[0]; i++) {
                for (int j = 0; j < other.dims_[1]; j++) {
                    for (int k = 0; k < dims_[1]; k++) {
                        tmp(i, j) += (*this)(i, k) * other(k, j);
                    }
                }
            }

            return tmp;
        }

        void swap(Tensor& other)
        {
            dims_.swap(other.dims_);
            data_.swap(other.data_);
        }

        std::vector<int> dims_;
        std::vector<T> data_;
    };
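
    // Example usage (an illustrative sketch, not part of the header).
    // Storage is dense and row-major, so for a 2-D tensor t, t(i, j) reads
    // data_[dims_[1] * i + j]:
    //
    //     Tensor<float> a(2, 3);
    //     Tensor<float> b(3, 2);
    //     a.fill(1.0f);
    //     b.fill(2.0f);
    //     Tensor<float> c = a.dot(b); // 2x2 matrix product; every entry is 6.0f
    //     float v = c(0, 1);          // bounds-checked element access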

    // NN layer
    // ---------------------
    template <class Evaluation>
    class NNLayer
    {
    public:
        NNLayer()
        {
        }

        virtual ~NNLayer()
        {
        }

        // Loads the layer parameters from the trained model file; returns true on success
        virtual bool loadLayer(std::ifstream& file) = 0;
        // Applies the layer to the input tensor and writes the result to out
        virtual bool apply(const Tensor<Evaluation>& in, Tensor<Evaluation>& out) = 0;
    };

    template <class Evaluation>
    class NNLayerActivation : public NNLayer<Evaluation>
    {
    public:
        enum class ActivationType {
            kLinear = 1,
            kRelu = 2,
            kSoftPlus = 3,
            kSigmoid = 4,
            kTanh = 5,
            kHardSigmoid = 6
        };

        NNLayerActivation()
            : activation_type_(ActivationType::kLinear)
        {
        }

        bool loadLayer(std::ifstream& file) override;

        bool apply(const Tensor<Evaluation>& in, Tensor<Evaluation>& out) override;

    private:
        ActivationType activation_type_;
    };
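
    // Note: apply() evaluates the selected activation elementwise. The enum
    // names follow the usual conventions, e.g. kRelu is max(0, x), kSoftPlus
    // is log(1 + exp(x)) and kSigmoid is 1 / (1 + exp(-x)); kHardSigmoid is
    // commonly the piecewise-linear clip(0.2 * x + 0.5, 0, 1). The exact
    // definitions live in the implementation file, not in this header.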

    template <class Evaluation>
    class NNLayerScaling : public NNLayer<Evaluation>
    {
    public:
        NNLayerScaling()
            : data_min(1.0f)
            , data_max(1.0f)
            , feat_inf(1.0f)
            , feat_sup(1.0f)
        {
        }

        bool loadLayer(std::ifstream& file) override;

        bool apply(const Tensor<Evaluation>& in, Tensor<Evaluation>& out) override;

    private:
        Tensor<float> weights_;
        Tensor<float> biases_;
        float data_min;
        float data_max;
        float feat_inf;
        float feat_sup;
    };
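
    // The member names suggest a min-max transform in the style of
    // scikit-learn's MinMaxScaler (an assumption; the exact formula is in
    // the implementation file):
    //
    //     y = (x - data_min) / (data_max - data_min)
    //           * (feat_sup - feat_inf) + feat_inf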

    template <class Evaluation>
    class NNLayerUnScaling : public NNLayer<Evaluation>
    {
    public:
        NNLayerUnScaling()
            : data_min(1.0f)
            , data_max(1.0f)
            , feat_inf(1.0f)
            , feat_sup(1.0f)
        {
        }

        bool loadLayer(std::ifstream& file) override;

        bool apply(const Tensor<Evaluation>& in, Tensor<Evaluation>& out) override;

    private:
        Tensor<float> weights_;
        Tensor<float> biases_;
        float data_min;
        float data_max;
        float feat_inf;
        float feat_sup;
    };
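
    // Presumably the inverse of NNLayerScaling, mapping network outputs in
    // [feat_inf, feat_sup] back to physical values in [data_min, data_max].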

    template <class Evaluation>
    class NNLayerDense : public NNLayer<Evaluation>
    {
    public:
        bool loadLayer(std::ifstream& file) override;

        bool apply(const Tensor<Evaluation>& in, Tensor<Evaluation>& out) override;

    private:
        Tensor<float> weights_;
        Tensor<float> biases_;
    };
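
    // A fully connected layer: given the members above, apply() plausibly
    // computes out = in.dot(weights_) + biases_ (the actual computation is
    // defined in the implementation file).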

    template <class Evaluation>
    class NNLayerEmbedding : public NNLayer<Evaluation>
    {
    public:
        bool loadLayer(std::ifstream& file) override;

        bool apply(const Tensor<Evaluation>& in, Tensor<Evaluation>& out) override;

    private:
        Tensor<float> weights_;
    };
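
    // An embedding layer in the Keras sense: each integer input index selects
    // a row of weights_ as its dense representation (inferred from the member
    // layout rather than from this header).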

    template <class Evaluation>
    class NNModel
    {
    public:
        enum class LayerType { kScaling = 1, kUnScaling = 2, kDense = 3, kActivation = 4 };

        virtual ~NNModel() = default;

        // Loads models (.model files) generated by Kerasify
        virtual bool loadModel(const std::string& filename);

        virtual bool apply(const Tensor<Evaluation>& in, Tensor<Evaluation>& out);

    private:
        std::vector<std::unique_ptr<NNLayer<Evaluation>>> layers_;
    };
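
    // Example usage (an illustrative sketch; "my_net.model" is a placeholder
    // file name, and the input shape must match the trained network):
    //
    //     NNModel<double> model;
    //     if (!model.loadModel("my_net.model")) {
    //         // handle the error
    //     }
    //     Tensor<double> in(4);
    //     in.fill(0.5);
    //     Tensor<double> out;
    //     model.apply(in, out);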

    // Simple wall-clock timer (seconds)
    class Timer
    {
    public:
        void start()
        {
            start_ = std::chrono::high_resolution_clock::now();
        }

        float stop()
        {
            std::chrono::time_point<std::chrono::high_resolution_clock> now
                = std::chrono::high_resolution_clock::now();

            std::chrono::duration<double> diff = now - start_;

            return static_cast<float>(diff.count());
        }

    private:
        std::chrono::time_point<std::chrono::high_resolution_clock> start_;
    };
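
    // Example usage (illustrative):
    //
    //     Timer t;
    //     t.start();
    //     // ... work to be timed ...
    //     float seconds = t.stop();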

} // namespace ML

} // namespace Opm

#endif // ML_MODEL_H_