ONNX-Modelle in Klassen packen
Einführung
Im vorangegangenen Artikel haben wir zwei ONNX-Modelle verwendet, um den Abstimmungsklassifikator zu erstellen. Der gesamte Quelltext wurde als eine einzige MQ5-Datei organisiert. Der gesamte Code wurde in Funktionen unterteilt. Was aber, wenn wir versuchen, die Modelle zu tauschen? Oder ein weiteres Modell hinzufügen? Der ursprüngliche Text wird dann noch größer. Versuchen wir es mit dem objektorientierten Ansatz.
1. Welche Modelle werden wir verwenden?
Im vorherigen Voting Classifier haben wir ein Klassifizierungsmodell und ein Regressionsmodell verwendet. Im Regressionsmodell erhalten wir anstelle der vorhergesagten Preisbewegung (nach unten, nach oben, keine Veränderung) den vorhergesagten Preis, der zur Berechnung der Klasse verwendet wird. In diesem Fall verfügen wir jedoch nicht über eine Wahrscheinlichkeitsverteilung nach Klassen, sodass die so genannte „weiche Abstimmung“ nicht möglich ist.
Wir haben 3 Klassifizierungsmodelle vorbereitet. Zwei Modelle wurden bereits in dem Artikel „Ein Beispiel für die Zusammenstellung von ONNX-Modellen in MQL5“ verwendet. Das erste Modell (Regression) wurde in ein Klassifikationsmodell umgewandelt. Das Training wurde anhand einer Reihe von 10 OHLC-Preisen durchgeführt. Das zweite Modell ist das Klassifizierungsmodell. Das Training wurde an einer Reihe von 63 Schlusskursen durchgeführt.
Schließlich gibt es noch ein weiteres Modell. Das Klassifizierungsmodell wurde anhand einer Reihe von 30 Schlusskursen und einer Reihe von einfachen gleitenden Durchschnitten mit Mittelungslängen von 21 und 34 trainiert. Wir haben keine Annahmen über den Schnittpunkt der gleitenden Durchschnitte mit dem Close-Chart und untereinander gemacht - alle Muster werden vom Netzwerk in Form von Koeffizientenmatrizen zwischen den Schichten berechnet und gespeichert.
Alle Modelle wurden auf den Daten des MetaQuotes-Demo-Servers trainiert, EURUSD D1 von 2010.01.01 bis 2023.01.01. Die Trainingsskripte für alle drei Modelle wurden in Python geschrieben und sind diesem Artikel beigefügt. Wir werden ihre Quellcodes hier nicht angeben, um die Aufmerksamkeit des Lesers nicht vom Hauptthema unseres Artikels abzulenken.
2. Eine Basisklasse für alle Modelle ist erforderlich
Es gibt drei Modelle. Sie unterscheiden sich von den anderen durch den Umfang und die Aufbereitung der Eingabedaten. Alle Modelle verfügen über die gleiche Schnittstelle. Die Klassen aller Modelle müssen von der gleichen Basisklasse abgeleitet werden.
Lassen Sie uns versuchen, die Basisklasse darzustellen.
//+------------------------------------------------------------------+ //| ModelSymbolPeriod.mqh | //| Copyright 2023, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ //--- price movement prediction #define PRICE_UP 0 #define PRICE_SAME 1 #define PRICE_DOWN 2 //+------------------------------------------------------------------+ //| Base class for models based on trained symbol and period | //+------------------------------------------------------------------+ class CModelSymbolPeriod { protected: long m_handle; // created model session handle string m_symbol; // symbol of trained data ENUM_TIMEFRAMES m_period; // timeframe of trained data datetime m_next_bar; // time of next bar (we work at bar begin only) double m_class_delta; // delta to recognize "price the same" in regression models public: //+------------------------------------------------------------------+ //| Constructor | //+------------------------------------------------------------------+ CModelSymbolPeriod(const string symbol,const ENUM_TIMEFRAMES period,const double class_delta=0.0001) { m_handle=INVALID_HANDLE; m_symbol=symbol; m_period=period; m_next_bar=0; m_class_delta=class_delta; } //+------------------------------------------------------------------+ //| Destructor | //+------------------------------------------------------------------+ ~CModelSymbolPeriod(void) { Shutdown(); } //+------------------------------------------------------------------+ //| virtual stub for Init | //+------------------------------------------------------------------+ virtual bool Init(const string symbol,const ENUM_TIMEFRAMES period) { return(false); } //+------------------------------------------------------------------+ //| Check for initialization, create model | //+------------------------------------------------------------------+ bool CheckInit(const string symbol,const ENUM_TIMEFRAMES period,const uchar& model[]) { //--- check symbol, period if(symbol!=m_symbol || period!=m_period) { PrintFormat("Model must work with %s,%s",m_symbol,EnumToString(m_period)); return(false); } //--- create a model from static buffer m_handle=OnnxCreateFromBuffer(model,ONNX_DEFAULT); if(m_handle==INVALID_HANDLE) { Print("OnnxCreateFromBuffer error ",GetLastError()); return(false); } //--- ok return(true); } //+------------------------------------------------------------------+ //| Release ONNX session | //+------------------------------------------------------------------+ void Shutdown(void) { if(m_handle!=INVALID_HANDLE) { OnnxRelease(m_handle); m_handle=INVALID_HANDLE; } } //+------------------------------------------------------------------+ //| Check for continue OnTick | //+------------------------------------------------------------------+ virtual bool CheckOnTick(void) { //--- check new bar if(TimeCurrent()<m_next_bar) return(false); //--- set next bar time m_next_bar=TimeCurrent(); m_next_bar-=m_next_bar%PeriodSeconds(m_period); m_next_bar+=PeriodSeconds(m_period); //--- work on new day bar return(true); } //+------------------------------------------------------------------+ //| virtual stub for PredictPrice (regression model) | //+------------------------------------------------------------------+ virtual double PredictPrice(void) { return(DBL_MAX); } //+------------------------------------------------------------------+ //| Predict class (regression -> classification) | //+------------------------------------------------------------------+ virtual int PredictClass(void) { double predicted_price=PredictPrice(); if(predicted_price==DBL_MAX) return(-1); int predicted_class=-1; double last_close=iClose(m_symbol,m_period,1); //--- classify predicted price movement double delta=last_close-predicted_price; if(fabs(delta)<=m_class_delta) predicted_class=PRICE_SAME; else { if(delta<0) predicted_class=PRICE_UP; else predicted_class=PRICE_DOWN; } //--- return predicted class return(predicted_class); } }; //+------------------------------------------------------------------+
Die Basisklasse kann sowohl für Regressions- als auch für Klassifikationsmodelle verwendet werden. Wir müssen nur die entsprechende Methode in der Nachfolgeklasse — PredictPrice oder PredictClass — implementieren.
Die Basisklasse legt den Zeitraum für das Symbol fest, mit dem das Modell arbeiten soll (die Daten, mit denen das Modell trainiert wurde). Die Basisklasse prüft auch, ob der EA, der das Modell verwendet, mit dem benötigten Zeitraum des Symbols arbeitet, und erstellt eine ONNX-Sitzung zur Ausführung des Modells. Die Basisklasse bietet nur zu Beginn eines neuen Taktes Arbeit.
3. Erste Modellklasse
Unser erstes Modell heißt model.eurusd.D1.10.class.onnx. Es handelt sich um ein Klassifizierungsmodell, das auf EURUSD D1 anhand einer Serie von 10 OHLC-Kursen trainiert wurde.
//+------------------------------------------------------------------+ //| ModelEurusdD1_10Class.mqh | //| Copyright 2023, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #include "ModelSymbolPeriod.mqh" #resource "Python/model.eurusd.D1.10.class.onnx" as uchar model_eurusd_D1_10_class[] //+------------------------------------------------------------------+ //| ONNX-model wrapper class | //+------------------------------------------------------------------+ class CModelEurusdD1_10Class : public CModelSymbolPeriod { private: int m_sample_size; public: //+------------------------------------------------------------------+ //| Constructor | //+------------------------------------------------------------------+ CModelEurusdD1_10Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1) { m_sample_size=10; } //+------------------------------------------------------------------+ //| ONNX-model initialization | //+------------------------------------------------------------------+ virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period) { //--- check symbol, period, create model if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_10_class)) { Print("model_eurusd_D1_10_class : initialization error"); return(false); } //--- since not all sizes defined in the input tensor we must set them explicitly //--- first index - batch size, second index - series size, third index - number of series (OHLC) const long input_shape[] = {1,m_sample_size,4}; if(!OnnxSetInputShape(m_handle,0,input_shape)) { Print("model_eurusd_D1_10_class : OnnxSetInputShape error ",GetLastError()); return(false); } //--- since not all sizes defined in the output tensor we must set them explicitly //--- first index - batch size, must match the batch size of the input tensor //--- second index - number of classes (up, same or down) const long output_shape[] = {1,3}; if(!OnnxSetOutputShape(m_handle,0,output_shape)) { Print("model_eurusd_D1_10_class : OnnxSetOutputShape error ",GetLastError()); return(false); } //--- ok return(true); } //+------------------------------------------------------------------+ //| Predict class | //+------------------------------------------------------------------+ virtual int PredictClass(void) { static matrixf input_data(m_sample_size,4); // matrix for prepared input data static vectorf output_data(3); // vector to get result static matrix mm(m_sample_size,4); // matrix of horizontal vectors Mean static matrix ms(m_sample_size,4); // matrix of horizontal vectors Std static matrix x_norm(m_sample_size,4); // matrix for prices normalize //--- prepare input data matrix rates; //--- request last bars if(!rates.CopyRates(m_symbol,m_period,COPY_RATES_OHLC,1,m_sample_size)) return(-1); //--- get series Mean vector m=rates.Mean(1); //--- get series Std vector s=rates.Std(1); //--- prepare matrices for prices normalization for(int i=0; i<m_sample_size; i++) { mm.Row(m,i); ms.Row(s,i); } //--- the input of the model must be a set of vertical OHLC vectors x_norm=rates.Transpose(); //--- normalize prices x_norm-=mm; x_norm/=ms; //--- run the inference input_data.Assign(x_norm); if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data)) return(-1); //--- evaluate prediction return(int(output_data.ArgMax())); } }; //+------------------------------------------------------------------+
Wie bereits oben erwähnt: „Es gibt drei Modelle. Jedes unterscheidet sich von den anderen durch den Umfang und die Aufbereitung der Eingabedaten.“ Wir haben nur zwei Methoden neu definiert — Init und PredictClass. Die gleichen Methoden werden in den beiden anderen Klassen für die beiden anderen Modelle neu definiert.
Die Init-Methode ruft die Methode der Basisklasse CheckInit auf, in der eine Sitzung für unser ONNX-Modell erstellt wird und die Größen der Eingabe- und Ausgabetensoren explizit festgelegt werden. Hier gibt es mehr Kommentare als Code.
Die Methode PredictClass bietet genau die gleiche Vorbereitung der Eingabedaten wie beim Training des Modells. Die Eingabe ist eine Matrix von normalisierten OHLC-Preisen.
4. Sehen wir uns an, wie es funktioniert
Um die Leistungsfähigkeit unserer Klasse zu testen, wurde ein sehr kompakter Expert Advisor erstellt.
//+------------------------------------------------------------------+ //| ONNX.eurusd.D1.Prediction.mq5 | //| Copyright 2023, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #property copyright "Copyright 2023, MetaQuotes Ltd." #property link "https://www.mql5.com" #property version "1.00" #include "ModelEurusdD1_10Class.mqh" #include <Trade\Trade.mqh> input double InpLots = 1.0; // Lots amount to open position CModelEurusdD1_10Class ExtModel; CTrade ExtTrade; //+------------------------------------------------------------------+ //| Expert initialization function | //+------------------------------------------------------------------+ int OnInit() { if(!ExtModel.Init(_Symbol,_Period)) return(INIT_FAILED); //--- return(INIT_SUCCEEDED); } //+------------------------------------------------------------------+ //| Expert deinitialization function | //+------------------------------------------------------------------+ void OnDeinit(const int reason) { ExtModel.Shutdown(); } //+------------------------------------------------------------------+ //| Expert tick function | //+------------------------------------------------------------------+ void OnTick() { if(!ExtModel.CheckOnTick()) return; //--- predict next price movement int predicted_class=ExtModel.PredictClass(); //--- check trading according to prediction if(predicted_class>=0) if(PositionSelect(_Symbol)) CheckForClose(predicted_class); else CheckForOpen(predicted_class); } //+------------------------------------------------------------------+ //| Check for open position conditions | //+------------------------------------------------------------------+ void CheckForOpen(const int predicted_class) { ENUM_ORDER_TYPE signal=WRONG_VALUE; //--- check signals if(predicted_class==PRICE_DOWN) signal=ORDER_TYPE_SELL; // sell condition else { if(predicted_class==PRICE_UP) signal=ORDER_TYPE_BUY; // buy condition } //--- open position if possible according to signal if(signal!=WRONG_VALUE && TerminalInfoInteger(TERMINAL_TRADE_ALLOWED)) { double price=SymbolInfoDouble(_Symbol,(signal==ORDER_TYPE_SELL) ? SYMBOL_BID : SYMBOL_ASK); ExtTrade.PositionOpen(_Symbol,signal,InpLots,price,0,0); } } //+------------------------------------------------------------------+ //| Check for close position conditions | //+------------------------------------------------------------------+ void CheckForClose(const int predicted_class) { bool bsignal=false; //--- position already selected before long type=PositionGetInteger(POSITION_TYPE); //--- check signals if(type==POSITION_TYPE_BUY && predicted_class==PRICE_DOWN) bsignal=true; if(type==POSITION_TYPE_SELL && predicted_class==PRICE_UP) bsignal=true; //--- close position if possible if(bsignal && TerminalInfoInteger(TERMINAL_TRADE_ALLOWED)) { ExtTrade.PositionClose(_Symbol,3); //--- open opposite CheckForOpen(predicted_class); } } //+------------------------------------------------------------------+
Da das Modell auf Preisdaten bis 2023 trainiert wurde, starten wir den Test am 1. Januar 2023.
Das Ergebnis ist unten zu sehen:
Wie wir sehen können, ist das Modell voll funktionsfähig.
5. Zweite Modellklasse
Das zweite Modell heißt model.eurusd.D1.30.class.onnx. Das auf EURUSD D1 trainierte Klassifizierungsmodell basiert auf einer Serie von 30 Schlusskursen und zwei einfachen gleitenden Durchschnitten mit Mittelungslängen von 21 und 34.
//+------------------------------------------------------------------+ //| ModelEurusdD1_30Class.mqh | //| Copyright 2023, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #include "ModelSymbolPeriod.mqh" #resource "Python/model.eurusd.D1.30.class.onnx" as uchar model_eurusd_D1_30_class[] //+------------------------------------------------------------------+ //| ONNX-model wrapper class | //+------------------------------------------------------------------+ class CModelEurusdD1_30Class : public CModelSymbolPeriod { private: int m_sample_size; int m_fast_period; int m_slow_period; int m_sma_fast; int m_sma_slow; public: //+------------------------------------------------------------------+ //| Constructor | //+------------------------------------------------------------------+ CModelEurusdD1_30Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1) { m_sample_size=30; m_fast_period=21; m_slow_period=34; m_sma_fast=INVALID_HANDLE; m_sma_slow=INVALID_HANDLE; } //+------------------------------------------------------------------+ //| ONNX-model initialization | //+------------------------------------------------------------------+ virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period) { //--- check symbol, period, create model if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_30_class)) { Print("model_eurusd_D1_30_class : initialization error"); return(false); } //--- since not all sizes defined in the input tensor we must set them explicitly //--- first index - batch size, second index - series size, third index - number of series (Close, MA fast, MA slow) const long input_shape[] = {1,m_sample_size,3}; if(!OnnxSetInputShape(m_handle,0,input_shape)) { Print("model_eurusd_D1_30_class : OnnxSetInputShape error ",GetLastError()); return(false); } //--- since not all sizes defined in the output tensor we must set them explicitly //--- first index - batch size, must match the batch size of the input tensor //--- second index - number of classes (up, same or down) const long output_shape[] = {1,3}; if(!OnnxSetOutputShape(m_handle,0,output_shape)) { Print("model_eurusd_D1_30_class : OnnxSetOutputShape error ",GetLastError()); return(false); } //--- indicators m_sma_fast=iMA(m_symbol,m_period,m_fast_period,0,MODE_SMA,PRICE_CLOSE); m_sma_slow=iMA(m_symbol,m_period,m_slow_period,0,MODE_SMA,PRICE_CLOSE); if(m_sma_fast==INVALID_HANDLE || m_sma_slow==INVALID_HANDLE) { Print("model_eurusd_D1_30_class : cannot create indicator"); return(false); } //--- ok return(true); } //+------------------------------------------------------------------+ //| Predict class | //+------------------------------------------------------------------+ virtual int PredictClass(void) { static matrixf input_data(m_sample_size,3); // matrix for prepared input data static vectorf output_data(3); // vector to get result static matrix x_norm(m_sample_size,3); // matrix for prices normalize static vector vtemp(m_sample_size); static double ma_buffer[]; //--- request last bars if(!vtemp.CopyRates(m_symbol,m_period,COPY_RATES_CLOSE,1,m_sample_size)) return(-1); //--- get series Mean double m=vtemp.Mean(); //--- get series Std double s=vtemp.Std(); //--- normalize vtemp-=m; vtemp/=s; x_norm.Col(vtemp,0); //--- fast sma if(CopyBuffer(m_sma_fast,0,1,m_sample_size,ma_buffer)!=m_sample_size) return(-1); vtemp.Assign(ma_buffer); m=vtemp.Mean(); s=vtemp.Std(); vtemp-=m; vtemp/=s; x_norm.Col(vtemp,1); //--- slow sma if(CopyBuffer(m_sma_slow,0,1,m_sample_size,ma_buffer)!=m_sample_size) return(-1); vtemp.Assign(ma_buffer); m=vtemp.Mean(); s=vtemp.Std(); vtemp-=m; vtemp/=s; x_norm.Col(vtemp,2); //--- run the inference input_data.Assign(x_norm); if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data)) return(-1); //--- evaluate prediction return(int(output_data.ArgMax())); } }; //+------------------------------------------------------------------+
Wie in der vorherigen Klasse wird die Methode der Basisklasse CheckInit in der Init-Methode aufgerufen. In der Basisklassenmethode wird eine Sitzung für das ONNX-Modell erstellt und die Größen der Eingabe- und Ausgabetensoren werden explizit festgelegt.
Die Methode PredictClass liefert eine Reihe von 30 vorherigen Schlusskursen und berechneten gleitenden Durchschnitten. Die Daten werden auf die gleiche Weise normalisiert wie beim Training.
Schauen wir uns an, wie dieses Modell funktioniert. Zu diesem Zweck ändern wir nur zwei Zeilen des Test-EA
#include "ModelEurusdD1_30Class.mqh" #include <Trade\Trade.mqh> input double InpLots = 1.0; // Lots amount to open position CModelEurusdD1_30Class ExtModel; CTrade ExtTrade;
Die Testparameter sind die gleichen.
Wir sehen, dass das Modell funktioniert.
6. Die dritte Modellklasse
Das letzte Modell heißt model.eurusd.D1.63.class.onnx. Das Klassifizierungsmodell wurde für EURUSD D1 anhand einer Serie von 63 Schlusskursen trainiert.
//+------------------------------------------------------------------+ //| ModelEurusdD1_63.mqh | //| Copyright 2023, MetaQuotes Ltd. | //| https://www.mql5.com | //+------------------------------------------------------------------+ #include "ModelSymbolPeriod.mqh" #resource "Python/model.eurusd.D1.63.class.onnx" as uchar model_eurusd_D1_63_class[] //+------------------------------------------------------------------+ //| ONNX-model wrapper class | //+------------------------------------------------------------------+ class CModelEurusdD1_63Class : public CModelSymbolPeriod { private: int m_sample_size; public: //+------------------------------------------------------------------+ //| Constructor | //+------------------------------------------------------------------+ CModelEurusdD1_63Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1,0.0001) { m_sample_size=63; } //+------------------------------------------------------------------+ //| ONNX-model initialization | //+------------------------------------------------------------------+ virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period) { //--- check symbol, period, create model if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_63_class)) { Print("model_eurusd_D1_63_class : initialization error"); return(false); } //--- since not all sizes defined in the input tensor we must set them explicitly //--- first index - batch size, second index - series size const long input_shape[] = {1,m_sample_size}; if(!OnnxSetInputShape(m_handle,0,input_shape)) { Print("model_eurusd_D1_63_class : OnnxSetInputShape error ",GetLastError()); return(false); } //--- since not all sizes defined in the output tensor we must set them explicitly //--- first index - batch size, must match the batch size of the input tensor //--- second index - number of classes (up, same or down) const long output_shape[] = {1,3}; if(!OnnxSetOutputShape(m_handle,0,output_shape)) { Print("model_eurusd_D1_63_class : OnnxSetOutputShape error ",GetLastError()); return(false); } //--- ok return(true); } //+------------------------------------------------------------------+ //| Predict class | //+------------------------------------------------------------------+ virtual int PredictClass(void) { static vectorf input_data(m_sample_size); // vector for prepared input data static vectorf output_data(3); // vector to get result //--- request last bars if(!input_data.CopyRates(m_symbol,m_period,COPY_RATES_CLOSE,1,m_sample_size)) return(-1); //--- get series Mean float m=input_data.Mean(); //--- get series Std float s=input_data.Std(); //--- normalize prices input_data-=m; input_data/=s; //--- run the inference if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data)) return(-1); //--- evaluate prediction return(int(output_data.ArgMax())); } }; //+------------------------------------------------------------------+
Dies ist das einfachste der drei Modelle. Aus diesem Grund ist der Code für die Methode PredictClass so kompakt.
Ändern wir wieder zwei Zeilen im EA
#include "ModelEurusdD1_63Class.mqh" #include <Trade\Trade.mqh> input double InpLots = 1.0; // Lots amount to open position CModelEurusdD1_63Class ExtModel; CTrade ExtTrade;
Wir starten den Test mit denselben Einstellungen.
Das Modell funktioniert.
7. Sammeln aller Modelle in einem EA. Harte Abstimmung
Alle drei Modelle haben ihre Arbeitsfähigkeit bewiesen. Versuchen wir nun, ihre Anstrengungen zu bündeln. Lassen Sie uns über die Modelle abstimmen.
Forward-Deklarationen und Definitionen
#include "ModelEurusdD1_10Class.mqh" #include "ModelEurusdD1_30Class.mqh" #include "ModelEurusdD1_63Class.mqh" #include <Trade\Trade.mqh> input double InpLots = 1.0; // Lots amount to open position CModelSymbolPeriod *ExtModels[3]; CTrade ExtTrade;
Die Funktion OnInit
int OnInit() { ExtModels[0]=new CModelEurusdD1_10Class; ExtModels[1]=new CModelEurusdD1_30Class; ExtModels[2]=new CModelEurusdD1_63Class; for(long i=0; i<ExtModels.Size(); i++) if(!ExtModels[i].Init(_Symbol,_Period)) return(INIT_FAILED); //--- return(INIT_SUCCEEDED); }
Die Funktion OnTick
void OnTick() { for(long i=0; i<ExtModels.Size(); i++) if(!ExtModels[i].CheckOnTick()) return; //--- predict next price movement int returned[3]={0,0,0}; //--- collect returned classes for(long i=0; i<ExtModels.Size(); i++) { int pred=ExtModels[i].PredictClass(); if(pred>=0) returned[pred]++; } //--- get one prediction for all models int predicted_class=-1; //--- count votes for predictions for(int n=0; n<3; n++) { if(returned[n]>=2) { predicted_class=n; break; } } //--- check trading according to prediction if(predicted_class>=0) if(PositionSelect(_Symbol)) CheckForClose(predicted_class); else CheckForOpen(predicted_class); }
Die Mehrheit der Stimmen wird mit der Formel <Gesamtzahl der Stimmen>/2 + 1 berechnet. Bei einer Gesamtzahl von 3 Stimmen beträgt die Mehrheit 2 Stimmen. Dies ist eine so genannte „harte Abstimmung“.
Das Testergebnis mit den gleichen Einstellungen.
Erinnern wir uns an die Arbeit aller drei Modelle getrennt, nämlich die Anzahl der profitablen und unprofitablen Geschäfte. Erstes Modell — 11 : 3, 2. — 6 : 1, 3. — 16 : 10.
Es scheint, dass wir das Ergebnis mit Hilfe der harten Abstimmung verbessert haben — 16 : 4. Aber natürlich müssen wir uns die Berichte und Testdiagramme insgesamt ansehen.
8. Weiche Abstimmung
Die weiche Abstimmung unterscheidet sich von der harten Abstimmung dadurch, dass nicht die Anzahl der Stimmen berücksichtigt wird, sondern die Summe der Wahrscheinlichkeiten aller drei Klassen aus allen drei Modellen. Es wird die Klasse mit der höchsten Wahrscheinlichkeit gewählt.
Um eine weiche Abstimmung zu gewährleisten, müssen einige Änderungen vorgenommen werden.
In der Basisklasse:
//+------------------------------------------------------------------+ //| Predict class (regression -> classification) | //+------------------------------------------------------------------+ virtual int PredictClass(vector& probabilities) { ... //--- set predicted probability as 1.0 probabilities.Fill(0); if(predicted_class<(int)probabilities.Size()) probabilities[predicted_class]=1; //--- and return predicted class return(predicted_class); }
In den abgeleiteten Klassen:
//+------------------------------------------------------------------+ //| Predict class | //+------------------------------------------------------------------+ virtual int PredictClass(vector& probabilities) { ... //--- evaluate prediction probabilities.Assign(output_data); return(int(output_data.ArgMax())); }
In dem EA:
#include "ModelEurusdD1_10Class.mqh" #include "ModelEurusdD1_30Class.mqh" #include "ModelEurusdD1_63Class.mqh" #include <Trade\Trade.mqh> enum EnVotes { Two=2, // Two votes Three=3, // Three votes Soft=4 // Soft voting }; input double InpLots = 1.0; // Lots amount to open position input EnVotes InpVotes = Two; // Votes to make trade decision CModelSymbolPeriod *ExtModels[3]; CTrade ExtTrade;
void OnTick() { for(long i=0; i<ExtModels.Size(); i++) if(!ExtModels[i].CheckOnTick()) return; //--- predict next price movement int returned[3]={0,0,0}; vector soft=vector::Zeros(3); //--- collect returned classes for(long i=0; i<ExtModels.Size(); i++) { vector prob(3); int pred=ExtModels[i].PredictClass(prob); if(pred>=0) { returned[pred]++; soft+=prob; } } //--- get one prediction for all models int predicted_class=-1; //--- soft or hard voting if(InpVotes==Soft) predicted_class=(int)soft.ArgMax(); else { //--- count votes for predictions for(int n=0; n<3; n++) { if(returned[n]>=InpVotes) { predicted_class=n; break; } } } //--- check trading according to prediction if(predicted_class>=0) if(PositionSelect(_Symbol)) CheckForClose(predicted_class); else CheckForOpen(predicted_class); }
Die Prüfeinstellungen sind dieselben. Wählen Sie in den Eingaben Soft.
Das Ergebnis ist wie folgt.
Gewinnbringende Positionen — 15, nicht gewinnbringende Positionen — 3. Auch in finanzieller Hinsicht erwies sich die harte Abstimmung als besser als die weiche Abstimmung.
Betrachten wir das Ergebnis einer einstimmigen Abstimmung, d.h. mit einer Stimmenzahl von 3.
Sehr konservativer Handel. Die einzige Verlustposition wurde am Ende des Tests geschlossen (vielleicht wäre sie ja gar nicht unrentabel gewesen).
Wichtiger Hinweis: Bitte beachten Sie, dass die in diesem Artikel verwendeten Modelle nur zur Demonstration der Arbeit mit ONNX-Modellen in der Sprache MQL5 dienen. Der Expert Advisor ist nicht für einen Handel auf realen Konten gedacht.
Schlussfolgerung
In diesem Artikel haben wir gezeigt, wie die objektorientierte Programmierung das Schreiben von Programmen erleichtert. Die gesamte Komplexität der Modelle ist in ihren Klassen verborgen (die Modelle können viel komplexer sein als die, die wir als Beispiel vorgestellt haben). Der Rest der „Komplexität“ passt in 45 Zeilen der OnTick-Funktion.
Übersetzt aus dem Russischen von MetaQuotes Ltd.
Originalartikel: https://www.mql5.com/ru/articles/12484
- Freie Handelsapplikationen
- Freie Forex-VPS für 24 Stunden
- Über 8.000 Signale zum Kopieren
- Wirtschaftsnachrichten für die Lage an den Finanzmärkte
Sie stimmen der Website-Richtlinie und den Nutzungsbedingungen zu.